package aliasmgr
import (
"encoding/binary"
"fmt"
"maps"
"slices"
"sync"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
)
// UpdateLinkAliases is a function type for a function that locates the active
// link that matches the given shortID and triggers an update based on the
// latest values of the alias manager.
type UpdateLinkAliases func(shortID lnwire.ShortChannelID) error
// ScidAliasMap is a map from a base short channel ID to a set of alias short
// channel IDs.
type ScidAliasMap map[lnwire.ShortChannelID][]lnwire.ShortChannelID
var (
// aliasBucket stores aliases as keys and their base SCIDs as values.
// This is used to populate the maps that the Manager uses. The keys
// are alias SCIDs and the values are their respective base SCIDs. This
// is used instead of the other way around (base -> alias...) because
// updating an alias would require fetching all the existing aliases,
// adding another one, and then flushing the write to disk. This is
// inefficient compared to N 1:1 mappings at the cost of marginally
// more disk space.
aliasBucket = []byte("alias-bucket")
// confirmedBucket stores whether or not a given base SCID should no
// longer have entries in the ToBase maps. The key is the SCID that is
// confirmed with 6 confirmations and is public, and the value is
// empty.
confirmedBucket = []byte("base-bucket")
// aliasAllocBucket is a root-level bucket that stores the last alias
// that was allocated. It is used to allocate a new alias when
// requested.
aliasAllocBucket = []byte("alias-alloc-bucket")
// lastAliasKey is a key in the aliasAllocBucket whose value is the
// last allocated alias ShortChannelID. This will be updated upon calls
// to RequestAlias.
lastAliasKey = []byte("last-alias-key")
// invoiceAliasBucket is a root-level bucket that stores the alias
// SCIDs that our peers send us in the channel_ready TLV. The keys are
// the ChannelID generated from the FundingOutpoint and the values are
// the remote peer's alias SCID.
invoiceAliasBucket = []byte("invoice-alias-bucket")
// byteOrder denotes the byte order of database (de)-serialization
// operations.
byteOrder = binary.BigEndian
// AliasStartBlockHeight is the starting block height of the alias
// range.
AliasStartBlockHeight uint32 = 16_000_000
// AliasEndBlockHeight is the ending block height of the alias range.
AliasEndBlockHeight uint32 = 16_250_000
// StartingAlias is the first alias ShortChannelID that will get
// assigned by RequestAlias. The starting BlockHeight is chosen so that
// legitimate SCIDs in integration tests aren't mistaken for an alias.
StartingAlias = lnwire.ShortChannelID{
BlockHeight: AliasStartBlockHeight,
TxIndex: 0,
TxPosition: 0,
}
// errNoBase is returned when a base SCID isn't found.
errNoBase = fmt.Errorf("no base found")
// errNoPeerAlias is returned when the peer's alias for a given
// channel is not found.
errNoPeerAlias = fmt.Errorf("no peer alias found")
// ErrAliasNotFound is returned when the alias is not found and can't
// be mapped to a base SCID.
ErrAliasNotFound = fmt.Errorf("alias not found")
)
// Manager is a struct that handles aliases for LND. It has an underlying
// database that can allocate aliases for channels, stores the peer's last
// alias for use in our hop hints, and contains mappings that both the Switch
// and Gossiper use.
type Manager struct {
backend kvdb.Backend
// linkAliasUpdater is a function used by the alias manager to
// facilitate live update of aliases in other subsystems.
linkAliasUpdater UpdateLinkAliases
// baseToSet is a mapping from the "base" SCID to the set of aliases
// for this channel. This mapping includes all channels that
// negotiated the option-scid-alias feature bit.
baseToSet ScidAliasMap
// aliasToBase is a mapping that maps all aliases for a given channel
// to its base SCID. This is only used for channels that have
// negotiated option-scid-alias feature bit.
aliasToBase map[lnwire.ShortChannelID]lnwire.ShortChannelID
// peerAlias is a cache for the alias SCIDs that our peers send us in
// the channel_ready TLV. The keys are the ChannelID generated from
// the FundingOutpoint and the values are the remote peer's alias SCID.
// The values should match the ones stored in the "invoice-alias-bucket"
// bucket.
peerAlias map[lnwire.ChannelID]lnwire.ShortChannelID
sync.RWMutex
}
// NewManager initializes an alias Manager from the passed database backend.
func NewManager(db kvdb.Backend, linkAliasUpdater UpdateLinkAliases) (*Manager,
error) {
m := &Manager{
backend: db,
baseToSet: make(ScidAliasMap),
linkAliasUpdater: linkAliasUpdater,
}
m.aliasToBase = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID)
m.peerAlias = make(map[lnwire.ChannelID]lnwire.ShortChannelID)
err := m.populateMaps()
return m, err
}
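// The following is a hedged, illustrative sketch (exampleNewManager is a
// hypothetical helper, not part of this package's API): it wires a Manager to
// a no-op link updater, which suffices for callers that don't maintain live
// htlcswitch links.
func exampleNewManager(db kvdb.Backend) (*Manager, error) {
	// A no-op updater satisfies the UpdateLinkAliases signature.
	noopUpdater := func(lnwire.ShortChannelID) error { return nil }
	return NewManager(db, noopUpdater)
}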
// populateMaps reads the database state and populates the maps.
func (m *Manager) populateMaps() error {
// This map tracks the base SCIDs that are confirmed and don't need to
// have entries in the *ToBase mappings as they won't be used in the
// gossiper.
baseConfMap := make(map[lnwire.ShortChannelID]struct{})
// This map caches what is found in the database and is used to
// populate the Manager's actual maps.
aliasMap := make(map[lnwire.ShortChannelID]lnwire.ShortChannelID)
// This map caches the ChannelID/alias SCIDs stored in the database and
// is used to populate the Manager's cache.
peerAliasMap := make(map[lnwire.ChannelID]lnwire.ShortChannelID)
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
baseConfBucket, err := tx.CreateTopLevelBucket(confirmedBucket)
if err != nil {
return err
}
err = baseConfBucket.ForEach(func(k, v []byte) error {
			// The key will be the base SCID and the value will be
// empty. Existence in the bucket means the SCID is
// confirmed.
baseScid := lnwire.NewShortChanIDFromInt(
byteOrder.Uint64(k),
)
baseConfMap[baseScid] = struct{}{}
return nil
})
if err != nil {
return err
}
aliasToBaseBucket, err := tx.CreateTopLevelBucket(aliasBucket)
if err != nil {
return err
}
err = aliasToBaseBucket.ForEach(func(k, v []byte) error {
// The key will be the alias SCID and the value will be
// the base SCID.
aliasScid := lnwire.NewShortChanIDFromInt(
byteOrder.Uint64(k),
)
baseScid := lnwire.NewShortChanIDFromInt(
byteOrder.Uint64(v),
)
aliasMap[aliasScid] = baseScid
return nil
})
if err != nil {
return err
}
invAliasBucket, err := tx.CreateTopLevelBucket(
invoiceAliasBucket,
)
if err != nil {
return err
}
err = invAliasBucket.ForEach(func(k, v []byte) error {
var chanID lnwire.ChannelID
copy(chanID[:], k)
alias := lnwire.NewShortChanIDFromInt(
byteOrder.Uint64(v),
)
peerAliasMap[chanID] = alias
return nil
})
return err
}, func() {
baseConfMap = make(map[lnwire.ShortChannelID]struct{})
aliasMap = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID)
peerAliasMap = make(map[lnwire.ChannelID]lnwire.ShortChannelID)
})
if err != nil {
return err
}
	// Populate the baseToSet map regardless of whether the baseSCID is
	// marked as public with 6 confirmations.
for aliasSCID, baseSCID := range aliasMap {
m.baseToSet[baseSCID] = append(m.baseToSet[baseSCID], aliasSCID)
// Skip if baseSCID is in the baseConfMap.
if _, ok := baseConfMap[baseSCID]; ok {
continue
}
m.aliasToBase[aliasSCID] = baseSCID
}
// Populate the peer alias cache.
m.peerAlias = peerAliasMap
return nil
}
// AddLocalAlias adds a database mapping from the passed alias to the passed
// base SCID. The gossip boolean marks whether or not to create a mapping
// that the gossiper will use. It is set to false for the upgrade path where
// the feature-bit is toggled on and there are existing channels. The linkUpdate
// flag is used to signal whether this function should also trigger an update
// on the htlcswitch scid alias maps.
func (m *Manager) AddLocalAlias(alias, baseScid lnwire.ShortChannelID,
gossip, linkUpdate bool) error {
// We need to lock the manager for the whole duration of this method,
// except for the very last part where we call the link updater. In
// order for us to safely use a defer _and_ still be able to manually
// unlock, we use a sync.Once.
m.Lock()
unlockOnce := sync.Once{}
unlock := func() {
unlockOnce.Do(m.Unlock)
}
defer unlock()
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
// If the caller does not want to allow the alias to be used
// for a channel update, we'll mark it in the baseConfBucket.
if !gossip {
var baseGossipBytes [8]byte
byteOrder.PutUint64(
baseGossipBytes[:], baseScid.ToUint64(),
)
confBucket, err := tx.CreateTopLevelBucket(
confirmedBucket,
)
if err != nil {
return err
}
err = confBucket.Put(baseGossipBytes[:], []byte{})
if err != nil {
return err
}
}
aliasToBaseBucket, err := tx.CreateTopLevelBucket(aliasBucket)
if err != nil {
return err
}
var (
aliasBytes [8]byte
baseBytes [8]byte
)
byteOrder.PutUint64(aliasBytes[:], alias.ToUint64())
byteOrder.PutUint64(baseBytes[:], baseScid.ToUint64())
return aliasToBaseBucket.Put(aliasBytes[:], baseBytes[:])
}, func() {})
if err != nil {
return err
}
// Update the aliasToBase and baseToSet maps.
m.baseToSet[baseScid] = append(m.baseToSet[baseScid], alias)
// Only store the gossiper map if gossip is true.
if gossip {
m.aliasToBase[alias] = baseScid
}
// We definitely need to unlock the Manager before calling the link
// updater. If we don't, we'll deadlock. We use a sync.Once to ensure
// that we only unlock once.
unlock()
// Finally, we trigger a htlcswitch update if the flag is set, in order
// for any future htlc that references the added alias to be properly
// routed.
if linkUpdate {
return m.linkAliasUpdater(baseScid)
}
return nil
}
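// A hedged usage sketch (exampleAddAlias is hypothetical): allocate a fresh
// alias from the manager's range and bind it to a base SCID, gossiping the
// mapping and triggering a link update so HTLCs referencing the alias are
// routed correctly.
func exampleAddAlias(m *Manager, base lnwire.ShortChannelID) error {
	alias, err := m.RequestAlias()
	if err != nil {
		return err
	}
	// gossip=true exposes the alias to the gossiper; linkUpdate=true
	// pushes the new mapping to the htlcswitch.
	return m.AddLocalAlias(alias, base, true, true)
}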
// GetAliases fetches the set of aliases stored under a given base SCID from
// write-through caches.
func (m *Manager) GetAliases(
base lnwire.ShortChannelID) []lnwire.ShortChannelID {
m.RLock()
defer m.RUnlock()
aliasSet, ok := m.baseToSet[base]
if ok {
// Copy the found alias slice.
setCopy := make([]lnwire.ShortChannelID, len(aliasSet))
copy(setCopy, aliasSet)
return setCopy
}
return nil
}
// FindBaseSCID finds the base SCID for a given alias. This is used in the
// gossiper to find the correct SCID to lookup in the graph database.
func (m *Manager) FindBaseSCID(
alias lnwire.ShortChannelID) (lnwire.ShortChannelID, error) {
m.RLock()
defer m.RUnlock()
base, ok := m.aliasToBase[alias]
if ok {
return base, nil
}
return lnwire.ShortChannelID{}, errNoBase
}
// DeleteSixConfs removes a mapping for the gossiper once six confirmations
// have been reached and the channel is public. At this point, only the
// confirmed SCID should be used.
func (m *Manager) DeleteSixConfs(baseScid lnwire.ShortChannelID) error {
m.Lock()
defer m.Unlock()
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
baseConfBucket, err := tx.CreateTopLevelBucket(confirmedBucket)
if err != nil {
return err
}
var baseBytes [8]byte
byteOrder.PutUint64(baseBytes[:], baseScid.ToUint64())
return baseConfBucket.Put(baseBytes[:], []byte{})
}, func() {})
if err != nil {
return err
}
// Now that the database state has been updated, we'll delete all of
// the aliasToBase mappings for this SCID.
for alias, base := range m.aliasToBase {
if base.ToUint64() == baseScid.ToUint64() {
delete(m.aliasToBase, alias)
}
}
return nil
}
// DeleteLocalAlias removes a mapping from the database and the Manager's maps.
func (m *Manager) DeleteLocalAlias(alias,
baseScid lnwire.ShortChannelID) error {
// We need to lock the manager for the whole duration of this method,
// except for the very last part where we call the link updater. In
// order for us to safely use a defer _and_ still be able to manually
// unlock, we use a sync.Once.
m.Lock()
unlockOnce := sync.Once{}
unlock := func() {
unlockOnce.Do(m.Unlock)
}
defer unlock()
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
aliasToBaseBucket, err := tx.CreateTopLevelBucket(aliasBucket)
if err != nil {
return err
}
var aliasBytes [8]byte
byteOrder.PutUint64(aliasBytes[:], alias.ToUint64())
// If the user attempts to delete an alias that doesn't exist,
// we'll want to inform them about it and not just do nothing.
if aliasToBaseBucket.Get(aliasBytes[:]) == nil {
return ErrAliasNotFound
}
return aliasToBaseBucket.Delete(aliasBytes[:])
}, func() {})
if err != nil {
return err
}
// Now that the database state has been updated, we'll delete the
// mapping from the Manager's maps.
aliasSet, ok := m.baseToSet[baseScid]
if !ok {
return ErrAliasNotFound
}
// We'll filter the alias set and remove the alias from it.
aliasSet = fn.Filter(aliasSet, func(a lnwire.ShortChannelID) bool {
return a.ToUint64() != alias.ToUint64()
})
// If the alias set is empty, we'll delete the base SCID from the
// baseToSet map.
if len(aliasSet) == 0 {
delete(m.baseToSet, baseScid)
} else {
m.baseToSet[baseScid] = aliasSet
}
// Finally, we'll delete the aliasToBase mapping from the Manager's
// cache (but this is only set if we gossip the alias).
delete(m.aliasToBase, alias)
// We definitely need to unlock the Manager before calling the link
// updater. If we don't, we'll deadlock. We use a sync.Once to ensure
// that we only unlock once.
unlock()
return m.linkAliasUpdater(baseScid)
}
// PutPeerAlias stores the peer's alias SCID once we learn of it in the
// channel_ready message.
func (m *Manager) PutPeerAlias(chanID lnwire.ChannelID,
alias lnwire.ShortChannelID) error {
m.Lock()
defer m.Unlock()
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(invoiceAliasBucket)
if err != nil {
return err
}
var scratch [8]byte
byteOrder.PutUint64(scratch[:], alias.ToUint64())
return bucket.Put(chanID[:], scratch[:])
}, func() {})
if err != nil {
return err
}
// Now that the database state has been updated, we can update it in
// our cache.
m.peerAlias[chanID] = alias
return nil
}
// GetPeerAlias retrieves a peer's alias SCID by the channel's ChanID.
func (m *Manager) GetPeerAlias(chanID lnwire.ChannelID) (lnwire.ShortChannelID,
error) {
m.RLock()
defer m.RUnlock()
alias, ok := m.peerAlias[chanID]
if !ok || alias == hop.Source {
return lnwire.ShortChannelID{}, errNoPeerAlias
}
return alias, nil
}
// RequestAlias returns a new ALIAS ShortChannelID to the caller by allocating
// the next un-allocated ShortChannelID. The starting ShortChannelID is
// 16000000:0:0 and the ending ShortChannelID is 16250000:16777215:65535. This
// gives roughly 2^58 possible ALIAS ShortChannelIDs which ensures this space
// won't get exhausted.
func (m *Manager) RequestAlias() (lnwire.ShortChannelID, error) {
var nextAlias lnwire.ShortChannelID
m.RLock()
defer m.RUnlock()
// haveAlias returns true if the passed alias is already assigned to a
// channel in the baseToSet map.
haveAlias := func(maybeNextAlias lnwire.ShortChannelID) bool {
return fn.Any(
slices.Collect(maps.Values(m.baseToSet)),
func(aliasList []lnwire.ShortChannelID) bool {
return fn.Any(
aliasList,
func(alias lnwire.ShortChannelID) bool {
return alias == maybeNextAlias
},
)
},
)
}
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(aliasAllocBucket)
if err != nil {
return err
}
lastBytes := bucket.Get(lastAliasKey)
if lastBytes == nil {
// If the key does not exist, then we can write the
// StartingAlias to it.
nextAlias = StartingAlias
// If the very first alias is already assigned, we'll
// keep incrementing until we find an unassigned alias.
// This is to avoid collision with custom added SCID
// aliases that fall into the same range as the ones we
// generate here monotonically. Those custom SCIDs are
// stored in a different bucket, but we can just check
// the in-memory map for simplicity.
for {
if !haveAlias(nextAlias) {
break
}
nextAlias = getNextScid(nextAlias)
// Abort if we've reached the end of the range.
if nextAlias.BlockHeight >=
AliasEndBlockHeight {
return fmt.Errorf("range for custom " +
"aliases exhausted")
}
}
var scratch [8]byte
byteOrder.PutUint64(scratch[:], nextAlias.ToUint64())
return bucket.Put(lastAliasKey, scratch[:])
}
// Otherwise the key does exist so we can convert the retrieved
// lastAlias to a ShortChannelID and use it to assign the next
// ShortChannelID. This next ShortChannelID will then be
// persisted in the database.
lastScid := lnwire.NewShortChanIDFromInt(
byteOrder.Uint64(lastBytes),
)
nextAlias = getNextScid(lastScid)
// If the next alias is already assigned, we'll keep
// incrementing until we find an unassigned alias. This is to
// avoid collision with custom added SCID aliases that fall into
// the same range as the ones we generate here monotonically.
// Those custom SCIDs are stored in a different bucket, but we
// can just check the in-memory map for simplicity.
for {
if !haveAlias(nextAlias) {
break
}
nextAlias = getNextScid(nextAlias)
// Abort if we've reached the end of the range.
if nextAlias.BlockHeight >= AliasEndBlockHeight {
return fmt.Errorf("range for custom " +
"aliases exhausted")
}
}
var scratch [8]byte
byteOrder.PutUint64(scratch[:], nextAlias.ToUint64())
return bucket.Put(lastAliasKey, scratch[:])
}, func() {
nextAlias = lnwire.ShortChannelID{}
})
if err != nil {
return nextAlias, err
}
return nextAlias, nil
}
// ListAliases returns a carbon copy of baseToSet. This is used by the rpc
// layer.
func (m *Manager) ListAliases() ScidAliasMap {
m.RLock()
defer m.RUnlock()
baseCopy := make(ScidAliasMap)
for k, v := range m.baseToSet {
setCopy := make([]lnwire.ShortChannelID, len(v))
copy(setCopy, v)
baseCopy[k] = setCopy
}
return baseCopy
}
// getNextScid is a utility function that returns the next SCID for a given
// alias SCID. The BlockHeight ranges from [16000000, 16250000], the TxIndex
// ranges from [0, 16777215], and the TxPosition ranges from [0, 65535].
func getNextScid(last lnwire.ShortChannelID) lnwire.ShortChannelID {
var (
next lnwire.ShortChannelID
incrementIdx bool
incrementHeight bool
)
// If the TxPosition is 65535, then it goes to 0 and we need to
// increment the TxIndex.
if last.TxPosition == 65535 {
incrementIdx = true
}
// If the TxIndex is 16777215 and we need to increment it, then it goes
// to 0 and we need to increment the BlockHeight.
if last.TxIndex == 16777215 && incrementIdx {
incrementIdx = false
incrementHeight = true
}
switch {
// If we increment the TxIndex, then TxPosition goes to 0.
case incrementIdx:
next.BlockHeight = last.BlockHeight
next.TxIndex = last.TxIndex + 1
next.TxPosition = 0
// If we increment the BlockHeight, then the Tx fields go to 0.
case incrementHeight:
next.BlockHeight = last.BlockHeight + 1
next.TxIndex = 0
next.TxPosition = 0
// Otherwise, we only need to increment the TxPosition.
default:
next.BlockHeight = last.BlockHeight
next.TxIndex = last.TxIndex
next.TxPosition = last.TxPosition + 1
}
return next
}
// IsAlias returns true if the passed SCID is an alias. The function determines
// this by looking at the BlockHeight. If the BlockHeight is greater than or
// equal to AliasStartBlockHeight and less than AliasEndBlockHeight, then it is
// an alias assigned by RequestAlias. These bounds only apply to aliases we
// generate.
// Our peers are free to use any range they choose.
func IsAlias(scid lnwire.ShortChannelID) bool {
return scid.BlockHeight >= AliasStartBlockHeight &&
scid.BlockHeight < AliasEndBlockHeight
}
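// An illustrative sketch (exampleScidRollover is hypothetical): getNextScid
// carries overflow from TxPosition into TxIndex, and the result stays within
// the range that IsAlias recognizes.
func exampleScidRollover() {
	last := lnwire.ShortChannelID{
		BlockHeight: AliasStartBlockHeight,
		TxIndex:     7,
		TxPosition:  65535,
	}
	next := getNextScid(last)
	// TxPosition wrapped to 0 and TxIndex was bumped to 8.
	fmt.Printf("%v -> %v, alias=%v\n", last, next, IsAlias(next))
}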
package amp
import (
"crypto/sha256"
"encoding/binary"
"fmt"
"github.com/lightningnetwork/lnd/lntypes"
)
// Share represents an n-of-n sharing of a secret 32-byte value. The secret can
// be recovered by XORing all n shares together.
type Share [32]byte
// Xor stores the byte-wise xor of shares x and y in z.
func (z *Share) Xor(x, y *Share) {
for i := range z {
z[i] = x[i] ^ y[i]
}
}
// ChildDesc contains the information necessary to derive a child hash/preimage
// pair that is attached to a particular HTLC. This information will be known by
// both the sender and receiver in the process of fulfilling an AMP payment.
type ChildDesc struct {
// Share is one of n shares of the root seed. Once all n shares are
// known to the receiver, the Share will also provide entropy to the
// derivation of child hash and preimage.
Share Share
	// Index is a 32-bit value that can be used to derive up to 2^32 child
// hashes and preimages from a single Share. This allows the payment
// hashes sent over the network to be refreshed without needing to
// modify the Share.
Index uint32
}
// Child is a payment hash and preimage pair derived from the root seed. In
// addition to the derived values, a Child carries all information required in
// the derivation apart from the root seed (unless n=1).
type Child struct {
// ChildDesc contains the data required to derive the child hash and
// preimage below.
ChildDesc
// Preimage is the child payment preimage that can be used to settle the
// HTLC carrying Hash.
Preimage lntypes.Preimage
	// Hash is the child payment hash to be carried by the HTLC.
Hash lntypes.Hash
}
// String returns a human-readable description of a Child.
func (c *Child) String() string {
return fmt.Sprintf("share=%x, index=%d -> preimage=%v, hash=%v",
c.Share, c.Index, c.Preimage, c.Hash)
}
// DeriveChild computes the child preimage and child hash for a given (root,
// share, index) tuple. The derivation is defined as:
//
// child_preimage = SHA256(root || share || be32(index)),
// child_hash = SHA256(child_preimage).
func DeriveChild(root Share, desc ChildDesc) *Child {
var (
indexBytes [4]byte
preimage lntypes.Preimage
hash lntypes.Hash
)
// Serialize the child index in big-endian order.
binary.BigEndian.PutUint32(indexBytes[:], desc.Index)
// Compute child_preimage as SHA256(root || share || child_index).
h := sha256.New()
_, _ = h.Write(root[:])
_, _ = h.Write(desc.Share[:])
_, _ = h.Write(indexBytes[:])
copy(preimage[:], h.Sum(nil))
// Compute child_hash as SHA256(child_preimage).
h = sha256.New()
_, _ = h.Write(preimage[:])
copy(hash[:], h.Sum(nil))
return &Child{
ChildDesc: desc,
Preimage: preimage,
Hash: hash,
}
}
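// A minimal sketch (exampleDeriveChild is hypothetical; the zero-valued
// inputs are placeholders): DeriveChild is deterministic, so a sender and
// receiver that agree on (root, share, index) compute the same preimage/hash
// pair.
func exampleDeriveChild() {
	var root, share Share
	desc := ChildDesc{Share: share, Index: 42}
	a := DeriveChild(root, desc)
	b := DeriveChild(root, desc)
	// Prints true: identical inputs yield identical child hashes.
	fmt.Println(a.Hash == b.Hash)
}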
package amp
import (
"crypto/rand"
"encoding/binary"
"fmt"
"sync"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/shards"
)
// Shard is an implementation of the shards.PaymentShards interface specific
// to AMP payments.
type Shard struct {
child *Child
mpp *record.MPP
amp *record.AMP
}
// A compile time check to ensure Shard implements the shards.PaymentShard
// interface.
var _ shards.PaymentShard = (*Shard)(nil)
// Hash returns the hash used for the HTLC representing this AMP shard.
func (s *Shard) Hash() lntypes.Hash {
return s.child.Hash
}
// MPP returns any extra MPP records that should be set for the final hop on
// the route used by this shard.
func (s *Shard) MPP() *record.MPP {
return s.mpp
}
// AMP returns any extra AMP records that should be set for the final hop on
// the route used by this shard.
func (s *Shard) AMP() *record.AMP {
return s.amp
}
// ShardTracker is an implementation of the shards.ShardTracker interface
// that is able to generate payment shards according to the AMP splitting
// algorithm. It can be used to generate new hashes to use for HTLCs, and also
// cancel shares used for failed payment shards.
type ShardTracker struct {
setID [32]byte
paymentAddr [32]byte
totalAmt lnwire.MilliSatoshi
sharer Sharer
shards map[uint64]*Child
sync.Mutex
}
// A compile time check to ensure ShardTracker implements the
// shards.ShardTracker interface.
var _ shards.ShardTracker = (*ShardTracker)(nil)
// NewShardTracker creates a new shard tracker to use for AMP payments. The
// root shard, setID, payment address and total amount must be correctly set in
// order for the TLV options included with each shard to be created
// correctly.
func NewShardTracker(root, setID, payAddr [32]byte,
totalAmt lnwire.MilliSatoshi) *ShardTracker {
// Create a new seed sharer from this root.
rootShare := Share(root)
rootSharer := SeedSharerFromRoot(&rootShare)
return &ShardTracker{
setID: setID,
paymentAddr: payAddr,
totalAmt: totalAmt,
sharer: rootSharer,
shards: make(map[uint64]*Child),
}
}
// NewShard registers a new attempt with the ShardTracker and returns a
// new shard representing this attempt. This attempt's shard should be canceled
// if it ends up not being used by the overall payment, i.e. if the attempt
// fails.
func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard,
error) {
s.Lock()
defer s.Unlock()
// Use a random child index.
var childIndex [4]byte
if _, err := rand.Read(childIndex[:]); err != nil {
return nil, err
}
idx := binary.BigEndian.Uint32(childIndex[:])
// Depending on whether we are requesting the last shard or not, either
// split the current share into two, or get a Child directly from the
// current sharer.
var child *Child
if last {
child = s.sharer.Child(idx)
// If this was the last shard, set the current share to the
// zero share to indicate we cannot split it further.
s.sharer = s.sharer.Zero()
} else {
left, sharer, err := s.sharer.Split()
if err != nil {
return nil, err
}
s.sharer = sharer
child = left.Child(idx)
}
// Track the new child and return the shard.
s.shards[pid] = child
mpp := record.NewMPP(s.totalAmt, s.paymentAddr)
amp := record.NewAMP(
child.ChildDesc.Share, s.setID, child.ChildDesc.Index,
)
return &Shard{
child: child,
mpp: mpp,
amp: amp,
}, nil
}
// CancelShard cancels the shard corresponding to the given attempt ID.
func (s *ShardTracker) CancelShard(pid uint64) error {
s.Lock()
defer s.Unlock()
c, ok := s.shards[pid]
if !ok {
return fmt.Errorf("pid not found")
}
delete(s.shards, pid)
// Now that we are canceling this shard, we XOR the share back into our
// current share.
s.sharer = s.sharer.Merge(c)
return nil
}
// GetHash retrieves the hash used by the shard of the given attempt ID. This
// will return an error if the attempt ID is unknown.
func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) {
s.Lock()
defer s.Unlock()
c, ok := s.shards[pid]
if !ok {
return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+
"not found", pid)
}
return c.Hash, nil
}
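// A hedged usage sketch (exampleShardTracker is hypothetical; the zero-valued
// root/setID/payAddr and the amount are placeholders): register one non-final
// shard, then cancel it, merging its share back into the tracker's sharer.
func exampleShardTracker() error {
	var root, setID, payAddr [32]byte
	t := NewShardTracker(root, setID, payAddr, lnwire.MilliSatoshi(1_000))
	shard, err := t.NewShard(1, false)
	if err != nil {
		return err
	}
	// The shard's hash is what the HTLC for this attempt would carry.
	_ = shard.Hash()
	// Cancelling XORs the share back, preserving the reconstruction
	// invariant for the remaining shards.
	return t.CancelShard(1)
}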
package amp
import (
"crypto/rand"
"fmt"
)
// zeroShare is the all-zero 32-byte share.
var zeroShare = Share{}
// Sharer facilitates dynamic splitting of a root share value and derivation of
// child preimage and hashes for individual HTLCs in an AMP payment. A sharer
// represents a specific node in an abstract binary tree that can generate up to
// 2^32-1 unique child preimage-hash pairs for the same share value. A node can
// also be split into its left and right child in the tree. The Sharer
// guarantees that the share value of the left and right child XOR to the share
// value of the parent. This allows larger HTLCs to be split into smaller
// subpayments, while ensuring that the reconstructed secret will exactly match
// the root seed.
type Sharer interface {
// Root returns the root share of the derivation tree. This is the value
// that will be reconstructed when combining the set of all child
// shares.
Root() Share
// Child derives a child preimage and child hash given a 32-bit index.
// Passing a different index will generate a unique preimage-hash pair
// with high probability, allowing the payment hash carried on HTLCs to
// be refreshed without needing to modify the share value. This would
	// typically be used when a partial payment needs to be retried if it
// encounters routine network failures.
Child(index uint32) *Child
// Split returns a Sharer for the left and right child of the parent
// Sharer. XORing the share values of both sharers always yields the
// share value of the parent. The sender should use this to recursively
// divide payments that are too large into smaller subpayments, knowing
// that the shares of all nodes descending from the parent will XOR to
// the parent's share.
Split() (Sharer, Sharer, error)
// Merge takes the given Child and "merges" it into the Sharer by
// XOR-ing its share with the Sharer's current share.
Merge(*Child) Sharer
	// Zero returns a new "zero Sharer" that has its current share set to
// zero, while keeping the root share. Merging a Child from the
// original Sharer into this zero-Sharer gives back the original
// Sharer.
Zero() Sharer
}
// SeedSharer orchestrates the sharing of the root AMP seed along multiple
// paths. It also supports derivation of the child payment hashes that get
// attached to HTLCs, and the child preimages used by the receiver to settle
// individual HTLCs in the set.
type SeedSharer struct {
root Share
curr Share
}
// NewSeedSharer generates a new SeedSharer instance with a seed drawn at
// random.
func NewSeedSharer() (*SeedSharer, error) {
var root Share
if _, err := rand.Read(root[:]); err != nil {
return nil, err
}
return SeedSharerFromRoot(&root), nil
}
// SeedSharerFromRoot instantiates a SeedSharer with an externally provided
// seed.
func SeedSharerFromRoot(root *Share) *SeedSharer {
return initSeedSharer(root, root)
}
func initSeedSharer(root, curr *Share) *SeedSharer {
return &SeedSharer{
root: *root,
curr: *curr,
}
}
// Root returns the sharer's root share, the primary source of entropy for
// deriving child shares.
func (s *SeedSharer) Root() Share {
return s.root
}
// Split constructs two child Sharers whose shares XOR to the parent's share.
// This allows an HTLC whose payment amount could not be routed to be
// recursively split into smaller subpayments. After splitting a sharer the
// parent share should no longer be used, and the caller should use the Child
// method on each to derive preimage/hash pairs for the HTLCs.
func (s *SeedSharer) Split() (Sharer, Sharer, error) {
// We cannot split the zero-Sharer.
if s.curr == zeroShare {
return nil, nil, fmt.Errorf("cannot split zero-Sharer")
}
shareLeft, shareRight, err := split(&s.curr)
if err != nil {
return nil, nil, err
}
left := initSeedSharer(&s.root, &shareLeft)
right := initSeedSharer(&s.root, &shareRight)
return left, right, nil
}
// Merge takes the given Child and "merges" it into the Sharer by XOR-ing its
// share with the Sharer's current share.
func (s *SeedSharer) Merge(child *Child) Sharer {
var curr Share
curr.Xor(&s.curr, &child.Share)
sharer := initSeedSharer(&s.root, &curr)
return sharer
}
// Zero returns a new "zero Sharer" that has its current share set to zero,
// while keeping the root share. Merging a Child from the original Sharer into
// this zero-Sharer gives back the original Sharer.
func (s *SeedSharer) Zero() Sharer {
return initSeedSharer(&s.root, &zeroShare)
}
// Child derives a preimage/hash pair to be used for an AMP HTLC.
// All children of s will use the same underlying share, but have unique
// preimage and hash. This can be used to rerandomize the preimage/hash pair for
// a given HTLC if a new route is needed.
func (s *SeedSharer) Child(index uint32) *Child {
desc := ChildDesc{
Share: s.curr,
Index: index,
}
return DeriveChild(s.root, desc)
}
// ReconstructChildren derives the set of children hashes and preimages from the
// provided descriptors. The shares from each child descriptor are first used to
// compute the root, afterwards the child hashes and preimages are
// deterministically computed. For child descriptor at index i in the input,
// its derived child will occupy index i of the returned children.
func ReconstructChildren(descs ...ChildDesc) []*Child {
// Recompute the root by XORing the provided shares.
var root Share
for _, desc := range descs {
root.Xor(&root, &desc.Share)
}
// With the root computed, derive the child hashes and preimages from
// the child descriptors.
children := make([]*Child, len(descs))
for i, desc := range descs {
children[i] = DeriveChild(root, desc)
}
return children
}
// split splits a share into two random values that, when XOR'd, reproduce the
// original share. Given a share s, the two shares are derived as:
//
// left <-$- random
// right = parent ^ left.
//
// When reconstructed, we have that:
//
// left ^ right = left ^ parent ^ left
// = parent.
func split(parent *Share) (Share, Share, error) {
// Generate a random share for the left child.
var left Share
if _, err := rand.Read(left[:]); err != nil {
return Share{}, Share{}, err
}
// Compute right = parent ^ left.
var right Share
right.Xor(parent, &left)
return left, right, nil
}
var _ Sharer = (*SeedSharer)(nil)
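// A minimal end-to-end sketch (exampleSplitReconstruct is hypothetical):
// split a fresh root into two shares, derive one child per share, and verify
// that the receiver-side reconstruction reproduces the same hashes.
func exampleSplitReconstruct() error {
	s, err := NewSeedSharer()
	if err != nil {
		return err
	}
	left, right, err := s.Split()
	if err != nil {
		return err
	}
	c0, c1 := left.Child(0), right.Child(1)
	// The two shares XOR back to the root, so reconstruction matches.
	rec := ReconstructChildren(c0.ChildDesc, c1.ChildDesc)
	fmt.Println(rec[0].Hash == c0.Hash && rec[1].Hash == c1.Hash)
	return nil
}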
package batch
import (
"errors"
"sync"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/sqldb"
)
// errSolo is a sentinel error indicating that the requester should re-run the
// operation in isolation.
var errSolo = errors.New(
"batch function returned an error and should be re-run solo",
)
type request struct {
*Request
errChan chan error
}
type batch struct {
db kvdb.Backend
start sync.Once
reqs []*request
clear func(b *batch)
locker sync.Locker
}
// trigger is the entry point for the batch and ensures that run is started at
// most once.
func (b *batch) trigger() {
b.start.Do(b.run)
}
// run executes the current batch of requests. If any individual requests fail
// alongside others, they will be retried individually by the caller.
func (b *batch) run() {
// Clear the batch from its scheduler, ensuring that no new requests are
// added to this batch.
b.clear(b)
	// If a cache lock was provided, hold it until this method returns.
// This is critical for ensuring external consistency of the operation,
// so that caches don't get out of sync with the on disk state.
if b.locker != nil {
b.locker.Lock()
defer b.locker.Unlock()
}
// Apply the batch until a subset succeeds or all of them fail. Requests
// that fail will be retried individually.
for len(b.reqs) > 0 {
var failIdx = -1
err := kvdb.Update(b.db, func(tx kvdb.RwTx) error {
for i, req := range b.reqs {
err := req.Update(tx)
if err != nil {
// If we get a serialization error, we
// want the underlying SQL retry
// mechanism to retry the entire batch.
// Otherwise, we can succeed in an
// sqldb retry and still re-execute the
// failing request individually.
dbErr := sqldb.MapSQLError(err)
if !sqldb.IsSerializationError(dbErr) {
failIdx = i
return err
}
return dbErr
}
}
return nil
}, func() {
for _, req := range b.reqs {
if req.Reset != nil {
req.Reset()
}
}
})
// If a request's Update failed, extract it and re-run the
// batch. The removed request will be retried individually by
// the caller.
if failIdx >= 0 {
req := b.reqs[failIdx]
// It's safe to shorten b.reqs here because the
// scheduler's batch no longer points to us.
b.reqs[failIdx] = b.reqs[len(b.reqs)-1]
b.reqs = b.reqs[:len(b.reqs)-1]
			// Tell the submitter to re-run the request solo, then
			// continue with the rest of the batch.
req.errChan <- errSolo
continue
}
// None of the remaining requests failed, process the errors
// using each request's OnCommit closure and return the error
// to the requester. If no OnCommit closure is provided, simply
// return the error directly.
for _, req := range b.reqs {
if req.OnCommit != nil {
req.errChan <- req.OnCommit(err)
} else {
req.errChan <- err
}
}
return
}
}
package batch
import "github.com/lightningnetwork/lnd/kvdb"
// Request defines an operation that can be batched into a single bbolt
// transaction.
type Request struct {
// Reset is called before each invocation of Update and is used to clear
// any possible modifications to local state as a result of previous
// calls to Update that were not committed due to a concurrent batch
// failure.
//
// NOTE: This field is optional.
Reset func()
// Update is applied alongside other operations in the batch.
//
// NOTE: This method MUST NOT acquire any mutexes.
Update func(tx kvdb.RwTx) error
// OnCommit is called if the batch or a subset of the batch including
// this request all succeeded without failure. The passed error should
// contain the result of the transaction commit, as that can still fail
// even if none of the closures returned an error.
//
// NOTE: This field is optional.
OnCommit func(commitErr error) error
// lazy should be true if we don't have to immediately execute this
// request when it comes in. This means that it can be scheduled later,
// allowing larger batches.
lazy bool
}
// SchedulerOption is a type that can be used to supply options to a scheduled
// request.
type SchedulerOption func(r *Request)
// LazyAdd will make the request be executed lazily, added to the next batch to
// reduce db contention.
func LazyAdd() SchedulerOption {
return func(r *Request) {
r.lazy = true
}
}
// Scheduler abstracts a generic batching engine that accumulates an incoming
// set of Requests, executes them, and returns the error from the operation.
type Scheduler interface {
// Execute schedules a Request for execution with the next available
// batch. This method blocks until the underlying closure has been
// run against the database. The resulting error is returned to the
// caller.
Execute(req *Request) error
}
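// A hedged usage sketch (exampleExecute is hypothetical and the bucket name is
// illustrative): a single write is submitted to a Scheduler, which may batch
// it with concurrent requests into one transaction.
func exampleExecute(s Scheduler) error {
	return s.Execute(&Request{
		Update: func(tx kvdb.RwTx) error {
			bucket, err := tx.CreateTopLevelBucket([]byte("demo"))
			if err != nil {
				return err
			}
			return bucket.Put([]byte("key"), []byte("value"))
		},
	})
}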
package batch
import (
"sync"
"time"
"github.com/lightningnetwork/lnd/kvdb"
)
// TimeScheduler is a batching engine that executes requests within a fixed
// horizon. When the first request is received, a TimeScheduler waits a
// configurable duration for other concurrent requests to join the batch. Once
// this time has elapsed, the batch is closed and executed. Subsequent requests
// are then added to a new batch which undergoes the same process.
type TimeScheduler struct {
db kvdb.Backend
locker sync.Locker
duration time.Duration
mu sync.Mutex
b *batch
}
// NewTimeScheduler initializes a new TimeScheduler with a fixed duration at
// which to schedule batches. If the operation needs to modify a higher-level
// cache, the cache's lock should be provided so that external consistency
// can be maintained, as successful db operations will cause a request's
// OnCommit method to be executed while holding this lock.
func NewTimeScheduler(db kvdb.Backend, locker sync.Locker,
duration time.Duration) *TimeScheduler {
return &TimeScheduler{
db: db,
locker: locker,
duration: duration,
}
}
// Execute schedules the provided request for batch execution along with other
// concurrent requests. The request will be executed within a fixed horizon,
// parameterized by the duration of the scheduler. The error from the
// underlying operation is returned to the caller.
//
// NOTE: Part of the Scheduler interface.
func (s *TimeScheduler) Execute(r *Request) error {
req := request{
Request: r,
errChan: make(chan error, 1),
}
// Add the request to the current batch. If the batch has been cleared
// or no batch exists, create a new one.
s.mu.Lock()
if s.b == nil {
s.b = &batch{
db: s.db,
clear: s.clear,
locker: s.locker,
}
time.AfterFunc(s.duration, s.b.trigger)
}
s.b.reqs = append(s.b.reqs, &req)
// If this is a non-lazy request, we'll execute the batch immediately.
if !r.lazy {
go s.b.trigger()
}
s.mu.Unlock()
// Wait for the batch to process the request. If the batch didn't
// ask us to execute the request individually, simply return the error.
err := <-req.errChan
if err != errSolo {
return err
}
// Obtain exclusive access to the cache if this scheduler needs to
// modify the cache in OnCommit.
if s.locker != nil {
s.locker.Lock()
defer s.locker.Unlock()
}
// Otherwise, run the request on its own.
commitErr := kvdb.Update(s.db, req.Update, func() {
if req.Reset != nil {
req.Reset()
}
})
// Finally, return the commit error directly or execute the OnCommit
// closure with the commit error if present.
if req.OnCommit != nil {
return req.OnCommit(commitErr)
}
return commitErr
}
// clear resets the scheduler's batch to nil so that no more requests can be
// added.
func (s *TimeScheduler) clear(b *batch) {
s.mu.Lock()
if s.b == b {
s.b = nil
}
s.mu.Unlock()
}
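// A minimal construction sketch (exampleScheduler is hypothetical; the 500ms
// horizon is an arbitrary choice): lazy requests wait up to the full duration
// to be batched, while non-lazy requests trigger the batch immediately.
func exampleScheduler(db kvdb.Backend, cacheMu *sync.Mutex) Scheduler {
	return NewTimeScheduler(db, cacheMu, 500*time.Millisecond)
}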
package blockcache
import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/neutrino"
"github.com/lightninglabs/neutrino/cache"
"github.com/lightninglabs/neutrino/cache/lru"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/multimutex"
)
// BlockCache is an lru cache for blocks.
type BlockCache struct {
Cache *lru.Cache[wire.InvVect, *neutrino.CacheableBlock]
HashMutex *multimutex.Mutex[lntypes.Hash]
}
// NewBlockCache creates a new BlockCache with the given maximum capacity.
func NewBlockCache(capacity uint64) *BlockCache {
return &BlockCache{
Cache: lru.NewCache[wire.InvVect, *neutrino.CacheableBlock](
capacity,
),
HashMutex: multimutex.NewMutex[lntypes.Hash](),
}
}
// GetBlock first checks to see if the BlockCache already contains the block
// with the given hash. If it does then the block is fetched from the cache and
// returned. Otherwise the getBlockImpl function is used in order to fetch the
// new block and then it is stored in the block cache and returned.
func (bc *BlockCache) GetBlock(hash *chainhash.Hash,
getBlockImpl func(hash *chainhash.Hash) (*wire.MsgBlock,
error)) (*wire.MsgBlock, error) {
bc.HashMutex.Lock(lntypes.Hash(*hash))
defer bc.HashMutex.Unlock(lntypes.Hash(*hash))
// Create an inv vector for getting the block.
inv := wire.NewInvVect(wire.InvTypeWitnessBlock, hash)
// Check if the block corresponding to the given hash is already
// stored in the blockCache and return it if it is.
cacheBlock, err := bc.Cache.Get(*inv)
if err != nil && err != cache.ErrElementNotFound {
return nil, err
}
if cacheBlock != nil {
return cacheBlock.MsgBlock(), nil
}
// Fetch the block from the chain backends.
msgBlock, err := getBlockImpl(hash)
if err != nil {
return nil, err
}
// Make a copy of the block so it won't escape to the heap.
msgBlockCopy := msgBlock.Copy()
block := btcutil.NewBlock(msgBlockCopy)
	// Add the new block to blockCache. If the Cache is at its maximum
	// capacity then the least recently used item will be evicted in
	// favour of this new block.
_, err = bc.Cache.Put(
*inv, &neutrino.CacheableBlock{
Block: block,
},
)
if err != nil {
return nil, err
}
return msgBlockCopy, nil
}
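// A hedged usage sketch (exampleFetch is hypothetical; fetchFromBackend and
// the 20 MB capacity stand in for a chain backend's block fetcher and a
// deployment-specific setting): the first call populates the cache, and
// repeated calls for the same hash are served from memory.
func exampleFetch(hash *chainhash.Hash,
	fetchFromBackend func(*chainhash.Hash) (*wire.MsgBlock,
		error)) (*wire.MsgBlock, error) {

	bc := NewBlockCache(20 * 1024 * 1024)
	return bc.GetBlock(hash, fetchFromBackend)
}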
package build
import (
"fmt"
"github.com/btcsuite/btclog/v2"
)
const (
callSiteOff = "off"
callSiteShort = "short"
callSiteLong = "long"
defaultLogCompressor = Gzip
// DefaultMaxLogFiles is the default maximum number of log files to
// keep.
DefaultMaxLogFiles = 10
// DefaultMaxLogFileSize is the default maximum log file size in MB.
DefaultMaxLogFileSize = 20
)
// LogConfig holds logging configuration options.
//
//nolint:ll
type LogConfig struct {
Console *consoleLoggerCfg `group:"console" namespace:"console" description:"The logger writing to stdout and stderr."`
File *FileLoggerConfig `group:"file" namespace:"file" description:"The logger writing to LND's standard log file."`
NoCommitHash bool `long:"no-commit-hash" description:"If set, the commit-hash of the current build will not be included in log lines by default."`
}
// Validate validates the LogConfig struct values.
func (c *LogConfig) Validate() error {
if !SupportedLogCompressor(c.File.Compressor) {
return fmt.Errorf("invalid log compressor: %v",
c.File.Compressor)
}
return nil
}
// LoggerConfig holds options for a particular logger.
//
//nolint:ll
type LoggerConfig struct {
Disable bool `long:"disable" description:"Disable this logger."`
NoTimestamps bool `long:"no-timestamps" description:"Omit timestamps from log lines."`
CallSite string `long:"call-site" description:"Include the call-site of each log line." choice:"off" choice:"short" choice:"long"`
}
// DefaultLogConfig returns the default logging config options.
func DefaultLogConfig() *LogConfig {
return &LogConfig{
Console: defaultConsoleLoggerCfg(),
File: &FileLoggerConfig{
Compressor: defaultLogCompressor,
MaxLogFiles: DefaultMaxLogFiles,
MaxLogFileSize: DefaultMaxLogFileSize,
LoggerConfig: &LoggerConfig{
CallSite: callSiteOff,
},
},
}
}
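// A small sketch (exampleDefaultConfig is hypothetical): the defaults are
// self-consistent, so validating a freshly constructed config succeeds.
func exampleDefaultConfig() error {
	cfg := DefaultLogConfig()
	// Returns nil: gzip is a supported compressor.
	return cfg.Validate()
}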
// HandlerOptions returns the set of btclog.HandlerOptions that the state of the
// config struct translates to.
func (cfg *LoggerConfig) HandlerOptions() []btclog.HandlerOption {
opts := []btclog.HandlerOption{
// We wrap the logger provided by the logging library with
// another layer of abstraction with the handlerSet, and so we
// need to increase the default skip depth by 1.
btclog.WithCallSiteSkipDepth(btclog.DefaultSkipDepth + 1),
}
if cfg.NoTimestamps {
opts = append(opts, btclog.WithNoTimestamp())
}
switch cfg.CallSite {
case callSiteShort:
opts = append(opts, btclog.WithCallerFlags(btclog.Lshortfile))
case callSiteLong:
opts = append(opts, btclog.WithCallerFlags(btclog.Llongfile))
}
return opts
}
// FileLoggerConfig extends LoggerConfig with specific log file options.
//
//nolint:ll
type FileLoggerConfig struct {
*LoggerConfig `yaml:",inline"`
Compressor string `long:"compressor" description:"Compression algorithm to use when rotating logs." choice:"gzip" choice:"zstd"`
MaxLogFiles int `long:"max-files" description:"Maximum logfiles to keep (0 for no rotation)"`
MaxLogFileSize int `long:"max-file-size" description:"Maximum logfile size in MB"`
}
//go:build !dev
// +build !dev
package build
// consoleLoggerCfg embeds the LoggerConfig struct along with any extensions
// specific to a production deployment.
//
//nolint:ll
type consoleLoggerCfg struct {
*LoggerConfig `yaml:",inline"`
}
// defaultConsoleLoggerCfg returns the default consoleLoggerCfg for the prod
// console logger.
func defaultConsoleLoggerCfg() *consoleLoggerCfg {
return &consoleLoggerCfg{
LoggerConfig: &LoggerConfig{
CallSite: callSiteOff,
},
}
}
package build
// DeploymentType is an enum specifying the deployment to compile.
type DeploymentType byte
const (
// Development is a deployment that includes extra testing hooks and
// logging configurations.
Development DeploymentType = iota
// Production is a deployment that strips out testing logic and uses
// Default logging.
Production
)
// String returns a human readable name for a build type.
func (b DeploymentType) String() string {
switch b {
case Development:
return "development"
case Production:
return "production"
default:
return "unknown"
}
}
// IsProdBuild returns true if this is a production build.
func IsProdBuild() bool {
return Deployment == Production
}
// IsDevBuild returns true if this is a development build.
func IsDevBuild() bool {
return Deployment == Development
}
package build
import (
"context"
"log/slog"
btclogv1 "github.com/btcsuite/btclog"
"github.com/btcsuite/btclog/v2"
)
// handlerSet is an implementation of Handler that abstracts away multiple
// Handlers.
type handlerSet struct {
level btclogv1.Level
set []btclog.Handler
}
// newHandlerSet constructs a new HandlerSet.
func newHandlerSet(level btclogv1.Level, set ...btclog.Handler) *handlerSet {
h := &handlerSet{
set: set,
level: level,
}
h.SetLevel(level)
return h
}
// Enabled reports whether the handler handles records at the given level.
//
// NOTE: this is part of the slog.Handler interface.
func (h *handlerSet) Enabled(ctx context.Context, level slog.Level) bool {
for _, handler := range h.set {
if !handler.Enabled(ctx, level) {
return false
}
}
return true
}
// Handle handles the Record.
//
// NOTE: this is part of the slog.Handler interface.
func (h *handlerSet) Handle(ctx context.Context, record slog.Record) error {
for _, handler := range h.set {
if err := handler.Handle(ctx, record); err != nil {
return err
}
}
return nil
}
// WithAttrs returns a new Handler whose attributes consist of both the
// receiver's attributes and the arguments.
//
// NOTE: this is part of the slog.Handler interface.
func (h *handlerSet) WithAttrs(attrs []slog.Attr) slog.Handler {
newSet := &reducedSet{set: make([]slog.Handler, len(h.set))}
for i, handler := range h.set {
newSet.set[i] = handler.WithAttrs(attrs)
}
return newSet
}
// WithGroup returns a new Handler with the given group appended to the
// receiver's existing groups.
//
// NOTE: this is part of the slog.Handler interface.
func (h *handlerSet) WithGroup(name string) slog.Handler {
newSet := &reducedSet{set: make([]slog.Handler, len(h.set))}
for i, handler := range h.set {
newSet.set[i] = handler.WithGroup(name)
}
return newSet
}
// SubSystem creates a new Handler with the given sub-system tag.
//
// NOTE: this is part of the Handler interface.
func (h *handlerSet) SubSystem(tag string) btclog.Handler {
newSet := &handlerSet{set: make([]btclog.Handler, len(h.set))}
for i, handler := range h.set {
newSet.set[i] = handler.SubSystem(tag)
}
return newSet
}
// SetLevel changes the logging level of the Handler to the passed
// level.
//
// NOTE: this is part of the btclog.Handler interface.
func (h *handlerSet) SetLevel(level btclogv1.Level) {
for _, handler := range h.set {
handler.SetLevel(level)
}
h.level = level
}
// Level returns the current logging level of the Handler.
//
// NOTE: this is part of the btclog.Handler interface.
func (h *handlerSet) Level() btclogv1.Level {
return h.level
}
// WithPrefix returns a copy of the Handler but with the given string prefixed
// to each log message.
//
// NOTE: this is part of the btclog.Handler interface.
func (h *handlerSet) WithPrefix(prefix string) btclog.Handler {
newSet := &handlerSet{set: make([]btclog.Handler, len(h.set))}
for i, handler := range h.set {
newSet.set[i] = handler.WithPrefix(prefix)
}
return newSet
}
// A compile-time check to ensure that handlerSet implements btclog.Handler.
var _ btclog.Handler = (*handlerSet)(nil)
// reducedSet is an implementation of the slog.Handler interface which is
// itself backed by multiple slog.Handlers. This is used by the handlerSet
// WithGroup and WithAttrs methods so that we can apply the WithGroup and
// WithAttrs to the underlying handlers in the set. These calls, however,
// produce slog.Handlers and not btclog.Handlers. So the reducedSet represents
// the resulting set produced.
type reducedSet struct {
set []slog.Handler
}
// Enabled reports whether the handler handles records at the given level.
//
// NOTE: this is part of the slog.Handler interface.
func (r *reducedSet) Enabled(ctx context.Context, level slog.Level) bool {
for _, handler := range r.set {
if !handler.Enabled(ctx, level) {
return false
}
}
return true
}
// Handle handles the Record.
//
// NOTE: this is part of the slog.Handler interface.
func (r *reducedSet) Handle(ctx context.Context, record slog.Record) error {
for _, handler := range r.set {
if err := handler.Handle(ctx, record); err != nil {
return err
}
}
return nil
}
// WithAttrs returns a new Handler whose attributes consist of both the
// receiver's attributes and the arguments.
//
// NOTE: this is part of the slog.Handler interface.
func (r *reducedSet) WithAttrs(attrs []slog.Attr) slog.Handler {
newSet := &reducedSet{set: make([]slog.Handler, len(r.set))}
for i, handler := range r.set {
newSet.set[i] = handler.WithAttrs(attrs)
}
return newSet
}
// WithGroup returns a new Handler with the given group appended to the
// receiver's existing groups.
//
// NOTE: this is part of the slog.Handler interface.
func (r *reducedSet) WithGroup(name string) slog.Handler {
newSet := &reducedSet{set: make([]slog.Handler, len(r.set))}
for i, handler := range r.set {
newSet.set[i] = handler.WithGroup(name)
}
return newSet
}
// A compile-time check to ensure that reducedSet implements slog.Handler.
var _ slog.Handler = (*reducedSet)(nil)
// subLogGenerator implements the SubLogCreator backed by a Handler.
type subLogGenerator struct {
handler btclog.Handler
}
// newSubLogGenerator constructs a new subLogGenerator from a Handler.
func newSubLogGenerator(handler btclog.Handler) *subLogGenerator {
return &subLogGenerator{
handler: handler,
}
}
// Logger returns a new logger for a particular sub-system.
//
// NOTE: this is part of the SubLogCreator interface.
func (b *subLogGenerator) Logger(subsystemTag string) btclog.Logger {
handler := b.handler.SubSystem(subsystemTag)
return btclog.NewSLogger(handler)
}
// A compile-time check to ensure that subLogGenerator implements the
// SubLogCreator interface.
var _ SubLogCreator = (*subLogGenerator)(nil)
package build
import (
"os"
"github.com/btcsuite/btclog/v2"
)
// NewDefaultLogHandlers returns the standard console logger and rotating log
// writer handlers that we generally want to use. It also applies the various
// config options to the loggers.
func NewDefaultLogHandlers(cfg *LogConfig,
rotator *RotatingLogWriter) []btclog.Handler {
var handlers []btclog.Handler
consoleLogHandler := btclog.NewDefaultHandler(
os.Stdout, cfg.Console.HandlerOptions()...,
)
logFileHandler := btclog.NewDefaultHandler(
rotator, cfg.File.HandlerOptions()...,
)
maybeAddLogger := func(cmdOptionDisable bool, handler btclog.Handler) {
if !cmdOptionDisable {
handlers = append(handlers, handler)
}
}
switch LoggingType {
case LogTypeStdOut:
maybeAddLogger(cfg.Console.Disable, consoleLogHandler)
case LogTypeDefault:
maybeAddLogger(cfg.Console.Disable, consoleLogHandler)
maybeAddLogger(cfg.File.Disable, logFileHandler)
}
return handlers
}
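// A hedged wiring sketch (exampleLogHandlers is hypothetical): build the
// default config, create the rotating writer, and derive the handler set that
// the logging subsystem fans out to.
func exampleLogHandlers() []btclog.Handler {
	cfg := DefaultLogConfig()
	rotator := NewRotatingLogWriter()
	return NewDefaultLogHandlers(cfg, rotator)
}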
package build
import (
"os"
"github.com/btcsuite/btclog/v2"
)
// LogType is an enum indicating the type of logging specified by the build
// flag.
type LogType byte
const (
// LogTypeNone indicates no logging.
LogTypeNone LogType = iota
// LogTypeStdOut all logging is written directly to stdout.
LogTypeStdOut
// LogTypeDefault logs to both stdout and a given io.PipeWriter.
LogTypeDefault
)
// String returns a human readable identifier for the logging type.
func (t LogType) String() string {
switch t {
case LogTypeNone:
return "none"
case LogTypeStdOut:
return "stdout"
case LogTypeDefault:
return "default"
default:
return "unknown"
}
}
// Declare the supported log file compressors as exported consts for easier use
// from other projects.
const (
// Gzip is the default compressor.
Gzip = "gzip"
// Zstd is a modern compressor that compresses better than Gzip, in less
// time.
Zstd = "zstd"
)
// logCompressors maps the identifier for each supported compression algorithm
// to the extension used for the compressed log files.
var logCompressors = map[string]string{
Gzip: "gz",
Zstd: "zst",
}
// SupportedLogCompressor returns whether or not logCompressor is a supported
// compression algorithm for log files.
func SupportedLogCompressor(logCompressor string) bool {
_, ok := logCompressors[logCompressor]
return ok
}
// NewSubLogger constructs a new subsystem log from the current LogWriter
// implementation. This is primarily intended for use with stdlog, as the actual
// writer is shared amongst all instantiations.
func NewSubLogger(subsystem string,
genSubLogger func(string) btclog.Logger) btclog.Logger {
switch Deployment {
// For production builds, generate a new subsystem logger from the
// primary log backend. If no function is provided, logging will be
// disabled.
case Production:
if genSubLogger != nil {
return genSubLogger(subsystem)
}
// For development builds, we must handle two distinct types of logging:
// unit tests and running the live daemon, e.g. for integration testing.
case Development:
switch LoggingType {
// Default logging is used when running the standalone daemon.
// We'll use the optional sublogger constructor to mimic the
// production behavior.
case LogTypeDefault:
if genSubLogger != nil {
return genSubLogger(subsystem)
}
// Logging to stdout is used in unit tests. It is not important
// that they share the same backend, since all output is written
		// to stdout.
case LogTypeStdOut:
backend := btclog.NewDefaultHandler(os.Stdout)
logger := btclog.NewSLogger(
backend.SubSystem(subsystem),
)
// Set the logging level of the stdout logger to use the
// configured logging level specified by build flags.
level, _ := btclog.LevelFromString(LogLevel)
logger.SetLevel(level)
return logger
}
}
// For any other configurations, we'll disable logging.
return btclog.Disabled
}
package build
import (
"context"
"github.com/btcsuite/btclog/v2"
)
// ShutdownLogger wraps an existing logger with a shutdown function which will
// be called on Critical/Criticalf to prompt shutdown.
type ShutdownLogger struct {
btclog.Logger
shutdown func()
}
// NewShutdownLogger creates a shutdown logger for the log provided which will
// use the signal package to request shutdown on critical errors.
func NewShutdownLogger(logger btclog.Logger, shutdown func()) *ShutdownLogger {
return &ShutdownLogger{
Logger: logger,
shutdown: shutdown,
}
}
// Criticalf formats message according to format specifier and writes to
// log with LevelCritical. It will then call the shutdown logger's shutdown
// function to prompt safe shutdown.
//
// Note: it is part of the btclog.Logger interface.
func (s *ShutdownLogger) Criticalf(format string, params ...interface{}) {
s.Logger.Criticalf(format, params...)
s.Logger.Info("Sending request for shutdown")
s.shutdown()
}
// Critical formats message using the default formats for its operands
// and writes to log with LevelCritical. It will then call the shutdown
// logger's shutdown function to prompt safe shutdown.
//
// Note: it is part of the btclog.Logger interface.
func (s *ShutdownLogger) Critical(v ...interface{}) {
	s.Logger.Critical(v...)
s.Logger.Info("Sending request for shutdown")
s.shutdown()
}
// CriticalS writes a structured log with the given message and key-value pair
// attributes with LevelCritical to the log. It will then call the shutdown
// logger's shutdown function to prompt safe shutdown.
//
// Note: it is part of the btclog.Logger interface.
func (s *ShutdownLogger) CriticalS(ctx context.Context, msg string, err error,
attr ...interface{}) {
s.Logger.CriticalS(ctx, msg, err, attr...)
s.Logger.Info("Sending request for shutdown")
s.shutdown()
}
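// A minimal usage sketch: wrap an existing logger so that any critical log
// triggers a graceful daemon stop. The baseLogger and requestShutdown values
// are hypothetical stand-ins for a real logger and the signal package's
// shutdown request, respectively.
func exampleWrapWithShutdown(baseLogger btclog.Logger,
	requestShutdown func()) btclog.Logger {

	// Every Critical/Criticalf/CriticalS call on the returned logger will
	// first write the log entry and then invoke requestShutdown.
	return NewShutdownLogger(baseLogger, requestShutdown)
}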
package build
import (
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"github.com/jrick/logrotate/rotator"
"github.com/klauspost/compress/zstd"
)
// RotatingLogWriter is a wrapper around the LogWriter that supports log file
// rotation.
type RotatingLogWriter struct {
// pipe is the write-end pipe for writing to the log rotator.
pipe *io.PipeWriter
rotator *rotator.Rotator
}
// NewRotatingLogWriter creates a new file rotating log writer.
//
// NOTE: `InitLogRotator` must be called to set up log rotation after creating
// the writer.
func NewRotatingLogWriter() *RotatingLogWriter {
return &RotatingLogWriter{}
}
// InitLogRotator initializes the log file rotator to write logs to logFile and
// create roll files in the same directory. It should be called as early as
// possible on startup and must be closed on shutdown by calling `Close`.
func (r *RotatingLogWriter) InitLogRotator(cfg *FileLoggerConfig,
logFile string) error {
logDir, _ := filepath.Split(logFile)
err := os.MkdirAll(logDir, 0700)
if err != nil {
return fmt.Errorf("failed to create log directory: %w", err)
}
r.rotator, err = rotator.New(
logFile, int64(cfg.MaxLogFileSize*1024), false, cfg.MaxLogFiles,
)
if err != nil {
return fmt.Errorf("failed to create file rotator: %w", err)
}
// Reject unknown compressors.
if !SupportedLogCompressor(cfg.Compressor) {
return fmt.Errorf("unknown log compressor: %v", cfg.Compressor)
}
var c rotator.Compressor
switch cfg.Compressor {
case Gzip:
c = gzip.NewWriter(nil)
case Zstd:
c, err = zstd.NewWriter(nil)
if err != nil {
return fmt.Errorf("failed to create zstd compressor: "+
"%w", err)
}
}
// Apply the compressor and its file suffix to the log rotator.
r.rotator.SetCompressor(c, logCompressors[cfg.Compressor])
// Run rotator as a goroutine now but make sure we catch any errors
// that happen in case something with the rotation goes wrong during
// runtime (like running out of disk space or not being allowed to
// create a new logfile for whatever reason).
pr, pw := io.Pipe()
go func() {
err := r.rotator.Run(pr)
if err != nil {
_, _ = fmt.Fprintf(os.Stderr,
"failed to run file rotator: %v\n", err)
}
}()
r.pipe = pw
return nil
}
// Write writes the byte slice to the log rotator, if present.
func (r *RotatingLogWriter) Write(b []byte) (int, error) {
if r.rotator != nil {
return r.rotator.Write(b)
}
return len(b), nil
}
// Close closes the underlying log rotator if it has already been created.
func (r *RotatingLogWriter) Close() error {
if r.rotator != nil {
return r.rotator.Close()
}
return nil
}
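// A minimal usage sketch of the rotating writer, assuming a FileLoggerConfig
// with illustrative values; the field values shown are not lnd's defaults.
func exampleInitRotator(logFile string) (*RotatingLogWriter, error) {
	writer := NewRotatingLogWriter()

	cfg := &FileLoggerConfig{
		Compressor:     Gzip,
		MaxLogFiles:    3,  // Keep at most three rolled files.
		MaxLogFileSize: 10, // Roll once the file exceeds 10 MB.
	}

	// Wire up rotation; the writer must be closed on shutdown.
	if err := writer.InitLogRotator(cfg, logFile); err != nil {
		return nil, err
	}

	return writer, nil
}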
package build
import (
"fmt"
"sort"
"strings"
"sync"
"github.com/btcsuite/btclog/v2"
)
// SubLogCreator can be used to create a new logger for a particular subsystem.
type SubLogCreator interface {
// Logger returns a new logger for a particular subsystem.
Logger(subsystemTag string) btclog.Logger
}
// SubLoggerManager manages a set of subsystem loggers. Level updates will be
// applied to all the loggers managed by the manager.
type SubLoggerManager struct {
genLogger SubLogCreator
loggers SubLoggers
mu sync.Mutex
}
// A compile time check to ensure SubLoggerManager implements the
// LeveledSubLogger interface.
var _ LeveledSubLogger = (*SubLoggerManager)(nil)
// NewSubLoggerManager constructs a new SubLoggerManager.
func NewSubLoggerManager(handlers ...btclog.Handler) *SubLoggerManager {
return &SubLoggerManager{
loggers: make(SubLoggers),
genLogger: newSubLogGenerator(
newHandlerSet(btclog.LevelInfo, handlers...),
),
}
}
// GenSubLogger creates a new sub-logger and adds it to the set managed by the
// SubLoggerManager. A shutdown callback function is provided to be able to shut
// down in case of a critical error.
func (r *SubLoggerManager) GenSubLogger(subsystem string,
shutdown func()) btclog.Logger {
// Create a new logger with the given subsystem tag.
logger := r.genLogger.Logger(subsystem)
// Wrap the new logger in a Shutdown logger so that the shutdown
// call back is called if a critical log is ever written via this new
// logger.
l := NewShutdownLogger(logger, shutdown)
r.RegisterSubLogger(subsystem, l)
return l
}
// RegisterSubLogger registers the given logger under the given subsystem name.
func (r *SubLoggerManager) RegisterSubLogger(subsystem string,
logger btclog.Logger) {
// Add the new logger to the set of loggers managed by the manager.
r.mu.Lock()
r.loggers[subsystem] = logger
r.mu.Unlock()
}
// SubLoggers returns all currently registered subsystem loggers for this log
// writer.
//
// NOTE: This is part of the LeveledSubLogger interface.
func (r *SubLoggerManager) SubLoggers() SubLoggers {
r.mu.Lock()
defer r.mu.Unlock()
return r.loggers
}
// SupportedSubsystems returns a sorted string slice of all keys in the
// subsystems map, corresponding to the names of the subsystems.
//
// NOTE: This is part of the LeveledSubLogger interface.
func (r *SubLoggerManager) SupportedSubsystems() []string {
r.mu.Lock()
defer r.mu.Unlock()
// Convert the subsystemLoggers map keys to a string slice.
subsystems := make([]string, 0, len(r.loggers))
for subsysID := range r.loggers {
subsystems = append(subsystems, subsysID)
}
// Sort the subsystems for stable display.
sort.Strings(subsystems)
return subsystems
}
// SetLogLevel sets the logging level for provided subsystem. Invalid
// subsystems are ignored. Uninitialized subsystems are dynamically created as
// needed.
//
// NOTE: This is part of the LeveledSubLogger interface.
func (r *SubLoggerManager) SetLogLevel(subsystemID string, logLevel string) {
r.mu.Lock()
defer r.mu.Unlock()
r.setLogLevelUnsafe(subsystemID, logLevel)
}
// setLogLevelUnsafe sets the logging level for provided subsystem. Invalid
// subsystems are ignored. Uninitialized subsystems are dynamically created as
// needed.
//
// NOTE: the SubLoggerManager mutex must be held before calling this method.
func (r *SubLoggerManager) setLogLevelUnsafe(subsystemID string,
logLevel string) {
// Ignore invalid subsystems.
logger, ok := r.loggers[subsystemID]
if !ok {
return
}
// Defaults to info if the log level is invalid.
level, _ := btclog.LevelFromString(logLevel)
logger.SetLevel(level)
}
// SetLogLevels sets the log level for all subsystem loggers to the passed
// level. It also dynamically creates the subsystem loggers as needed, so it
// can be used to initialize the logging system.
//
// NOTE: This is part of the LeveledSubLogger interface.
func (r *SubLoggerManager) SetLogLevels(logLevel string) {
r.mu.Lock()
defer r.mu.Unlock()
// Configure all sub-systems with the new logging level. Dynamically
// create loggers as needed.
for subsystemID := range r.loggers {
r.setLogLevelUnsafe(subsystemID, logLevel)
}
}
// SubLoggers is a type that holds a map of subsystem loggers keyed by their
// subsystem name.
type SubLoggers map[string]btclog.Logger
// LeveledSubLogger provides the ability to retrieve the subsystem loggers of
// a logger and set their log levels individually or all at once.
type LeveledSubLogger interface {
// SubLoggers returns the map of all registered subsystem loggers.
SubLoggers() SubLoggers
// SupportedSubsystems returns a slice of strings containing the names
// of the supported subsystems. Should ideally correspond to the keys
// of the subsystem logger map and be sorted.
SupportedSubsystems() []string
// SetLogLevel assigns an individual subsystem logger a new log level.
SetLogLevel(subsystemID string, logLevel string)
// SetLogLevels assigns all subsystem loggers the same new log level.
SetLogLevels(logLevel string)
}
// ParseAndSetDebugLevels attempts to parse the specified debug level and set
// the levels accordingly on the given logger. An appropriate error is returned
// if anything is invalid.
func ParseAndSetDebugLevels(level string, logger LeveledSubLogger) error {
// Split at the delimiter.
levels := strings.Split(level, ",")
if len(levels) == 0 {
return fmt.Errorf("invalid log level: %v", level)
}
// If the first entry has no =, treat it as the log level for all
// subsystems.
globalLevel := levels[0]
if !strings.Contains(globalLevel, "=") {
// Validate debug log level.
if !validLogLevel(globalLevel) {
str := "the specified debug level [%v] is invalid"
return fmt.Errorf(str, globalLevel)
}
// Change the logging level for all subsystems.
logger.SetLogLevels(globalLevel)
// The rest will target specific subsystems.
levels = levels[1:]
}
// Go through the subsystem/level pairs while detecting issues and
// update the log levels accordingly.
for _, logLevelPair := range levels {
if !strings.Contains(logLevelPair, "=") {
str := "the specified debug level contains an " +
"invalid subsystem/level pair [%v]"
return fmt.Errorf(str, logLevelPair)
}
// Extract the specified subsystem and log level.
fields := strings.Split(logLevelPair, "=")
if len(fields) != 2 {
str := "the specified debug level has an invalid " +
"format [%v] -- use format subsystem1=level1," +
"subsystem2=level2"
return fmt.Errorf(str, logLevelPair)
}
subsysID, logLevel := fields[0], fields[1]
subLoggers := logger.SubLoggers()
// Validate subsystem.
if _, exists := subLoggers[subsysID]; !exists {
str := "the specified subsystem [%v] is invalid -- " +
"supported subsystems are %v"
return fmt.Errorf(
str, subsysID, logger.SupportedSubsystems(),
)
}
// Validate log level.
if !validLogLevel(logLevel) {
str := "the specified debug level [%v] is invalid"
return fmt.Errorf(str, logLevel)
}
logger.SetLogLevel(subsysID, logLevel)
}
return nil
}
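// A minimal sketch of the level strings ParseAndSetDebugLevels accepts: a
// single global level, or a global level followed by per-subsystem
// overrides. The manager value and the subsystem names are illustrative.
func exampleSetLevels(manager *SubLoggerManager) error {
	// Set every registered subsystem to debug.
	if err := ParseAndSetDebugLevels("debug", manager); err != nil {
		return err
	}

	// Set a global info level, then override two already-registered
	// subsystems individually.
	return ParseAndSetDebugLevels("info,PEER=trace,CHIO=warn", manager)
}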
// validLogLevel returns whether or not logLevel is a valid debug log level.
func validLogLevel(logLevel string) bool {
switch logLevel {
case "trace", "debug", "info", "warn", "error", "critical", "off":
return true
}
return false
}
// Copyright (c) 2013-2017 The btcsuite developers
// Copyright (c) 2015-2016 The Decred developers
// Heavily inspired by https://github.com/btcsuite/btcd/blob/master/version.go
// Copyright (C) 2015-2022 The Lightning Network Developers
package build
import (
"context"
"encoding/hex"
"fmt"
"runtime/debug"
"strings"
"github.com/btcsuite/btclog/v2"
)
var (
// Commit stores the current commit of this build, which includes the
// most recent tag, the number of commits since that tag (if non-zero),
// the commit hash, and a dirty marker. This should be set using the
// -ldflags during compilation.
Commit string
// CommitHash stores the current commit hash of this build.
CommitHash string
// RawTags contains the raw set of build tags, separated by commas.
RawTags string
// GoVersion stores the go version that the executable was compiled
// with.
GoVersion string
)
// semanticAlphabet is the set of characters that are permitted for use in an
// AppPreRelease.
const semanticAlphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-."
// These constants define the application version and follow the semantic
// versioning 2.0.0 spec (http://semver.org/).
const (
// AppMajor defines the major version of this binary.
AppMajor uint = 0
// AppMinor defines the minor version of this binary.
AppMinor uint = 19
// AppPatch defines the application patch for this binary.
AppPatch uint = 0
// AppPreRelease MUST only contain characters from semanticAlphabet per
// the semantic versioning spec.
AppPreRelease = "beta.rc2"
)
func init() {
// Assert that AppPreRelease is valid according to the semantic
// versioning guidelines for pre-release version and build metadata
// strings. In particular it MUST only contain characters in
// semanticAlphabet.
for _, r := range AppPreRelease {
if !strings.ContainsRune(semanticAlphabet, r) {
panic(fmt.Errorf("rune: %v is not in the semantic "+
"alphabet", r))
}
}
// Get build information from the runtime.
if info, ok := debug.ReadBuildInfo(); ok {
GoVersion = info.GoVersion
for _, setting := range info.Settings {
switch setting.Key {
case "vcs.revision":
CommitHash = setting.Value
case "-tags":
RawTags = setting.Value
}
}
}
}
// Version returns the application version as a properly formed string per the
// semantic versioning 2.0.0 spec (http://semver.org/).
func Version() string {
// Start with the major, minor, and patch versions.
version := fmt.Sprintf("%d.%d.%d", AppMajor, AppMinor, AppPatch)
// Append pre-release version if there is one. The hyphen called for by
// the semantic versioning spec is automatically appended and should not
// be contained in the pre-release string.
if AppPreRelease != "" {
version = fmt.Sprintf("%s-%s", version, AppPreRelease)
}
return version
}
// Tags returns the list of build tags that were compiled into the executable.
func Tags() []string {
if len(RawTags) == 0 {
return nil
}
return strings.Split(RawTags, ",")
}
// WithBuildInfo derives a child context with the build information attached as
// attributes. At the moment, this only includes the current build's commit
// hash.
func WithBuildInfo(ctx context.Context, cfg *LogConfig) (context.Context,
error) {
if cfg.NoCommitHash {
return ctx, nil
}
// Convert the commit hash to a byte slice.
commitHash, err := hex.DecodeString(CommitHash)
if err != nil {
return nil, fmt.Errorf("unable to decode commit hash: %w", err)
}
return btclog.WithCtx(ctx, btclog.Hex3("rev", commitHash)), nil
}
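// A minimal sketch tying the helpers above together at startup: render the
// version string, list the build tags, and derive a logging context that
// carries the commit hash. The cfg value is a hypothetical LogConfig.
func exampleStartupInfo(cfg *LogConfig) (context.Context, error) {
	// With the constants defined above, Version() renders
	// "0.19.0-beta.rc2": major.minor.patch plus the pre-release tag.
	fmt.Printf("version %s, tags=%v\n", Version(), Tags())

	return WithBuildInfo(context.Background(), cfg)
}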
package chainio
import (
"fmt"
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/chainntnfs"
)
// Beat implements the Blockbeat interface. It contains the block epoch and a
// customized logger.
//
// TODO(yy): extend this to check for confirmation status - which serves as the
// single source of truth, to avoid the potential race between receiving blocks
// and `GetTransactionDetails/RegisterSpendNtfn/RegisterConfirmationsNtfn`.
type Beat struct {
// epoch is the current block epoch the blockbeat is aware of.
epoch chainntnfs.BlockEpoch
// log is the customized logger for the blockbeat which prints the
// block height.
log btclog.Logger
}
// Compile-time check to ensure Beat satisfies the Blockbeat interface.
var _ Blockbeat = (*Beat)(nil)
// NewBeat creates a new beat with the specified block epoch and a customized
// logger.
func NewBeat(epoch chainntnfs.BlockEpoch) *Beat {
b := &Beat{
epoch: epoch,
}
// Create a customized logger for the blockbeat.
logPrefix := fmt.Sprintf("Height[%6d]:", b.Height())
b.log = clog.WithPrefix(logPrefix)
return b
}
// Height returns the height of the block epoch.
//
// NOTE: Part of the Blockbeat interface.
func (b *Beat) Height() int32 {
return b.epoch.Height
}
// logger returns the logger for the blockbeat.
//
// NOTE: Part of the private blockbeat interface.
func (b *Beat) logger() btclog.Logger {
return b.log
}
package chainio
// BeatConsumer defines a supplementary component that should be used by
// subsystems which implement the `Consumer` interface. It partially implements
// the `Consumer` interface by providing the method `ProcessBlock` such that
// subsystems don't need to re-implement it.
//
// While inheritance is not commonly used in Go, subsystems embedding this
// struct cannot pass the interface check for `Consumer` because the `Name`
// method is not implemented, which gives us a "mortise and tenon" structure.
// In addition to reducing code duplication, this design allows `ProcessBlock`
// to work on the concrete type `Beat` to access its internal states.
type BeatConsumer struct {
// BlockbeatChan is a channel to receive blocks from Blockbeat. The
// received block contains the best known height and the txns confirmed
// in this block.
BlockbeatChan chan Blockbeat
// name is the name of the consumer which embeds the BeatConsumer.
name string
// quit is a channel that closes when the subsystem embedding the
// BeatConsumer is shutting down.
//
// NOTE: this quit channel should be mounted to the same quit channel
// used by the subsystem.
quit chan struct{}
// errChan is a buffered chan that receives an error returned from
// processing this block.
errChan chan error
}
// NewBeatConsumer creates a new BeatConsumer.
func NewBeatConsumer(quit chan struct{}, name string) BeatConsumer {
// Refuse to start `lnd` if the quit channel is not initialized. We
// treat this case as if we were facing a nil pointer dereference, as
// there's no point in returning an error here; the node would fail to
// start either way.
if quit == nil {
panic("quit channel is nil")
}
b := BeatConsumer{
BlockbeatChan: make(chan Blockbeat),
name: name,
errChan: make(chan error, 1),
quit: quit,
}
return b
}
// ProcessBlock takes a blockbeat and sends it to the consumer's blockbeat
// channel. It will send it to the subsystem's BlockbeatChan, and block until
// the processed result is received from the subsystem. The subsystem must call
// `NotifyBlockProcessed` after it has finished processing the block.
//
// NOTE: part of the `chainio.Consumer` interface.
func (b *BeatConsumer) ProcessBlock(beat Blockbeat) error {
// Trace the arrival of the new beat at this consumer.
beat.logger().Tracef("set current height for [%s]", b.name)
select {
// Send the beat to the blockbeat channel. It's expected that the
// consumer will read from this channel and process the block. Once
// processed, it should return the error or nil to the beat.Err chan.
case b.BlockbeatChan <- beat:
beat.logger().Tracef("Sent blockbeat to [%s]", b.name)
case <-b.quit:
beat.logger().Debugf("[%s] received shutdown before sending "+
"beat", b.name)
return nil
}
// Check the consumer's err chan. We expect the consumer to call
// `beat.NotifyBlockProcessed` to send the error back here.
select {
case err := <-b.errChan:
beat.logger().Tracef("[%s] processed beat: err=%v", b.name, err)
return err
case <-b.quit:
beat.logger().Debugf("[%s] received shutdown", b.name)
}
return nil
}
// NotifyBlockProcessed signals that the block has been processed. It takes the
// blockbeat being processed and an error resulting from processing it. This
// error is then sent back to the consumer's err chan to unblock
// `ProcessBlock`.
//
// NOTE: This method must be called by the subsystem after it has finished
// processing the block.
func (b *BeatConsumer) NotifyBlockProcessed(beat Blockbeat, err error) {
// Trace that the subsystem has finished processing the beat.
beat.logger().Tracef("[%s]: notifying beat processed", b.name)
select {
case b.errChan <- err:
beat.logger().Tracef("[%s]: notified beat processed, err=%v",
b.name, err)
case <-b.quit:
beat.logger().Debugf("[%s] received shutdown before notifying "+
"beat processed", b.name)
}
}
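// A minimal sketch of the consumer side of the handshake described above: a
// hypothetical subsystem embeds BeatConsumer, reads beats from its
// BlockbeatChan, does its work, and acks via NotifyBlockProcessed so that
// the dispatcher's ProcessBlock call unblocks.
type exampleSubsystem struct {
	BeatConsumer

	quit chan struct{}
}

// newExampleSubsystem mounts the subsystem's quit channel into the embedded
// BeatConsumer, as required by the NOTE on the quit field above.
func newExampleSubsystem() *exampleSubsystem {
	quit := make(chan struct{})
	return &exampleSubsystem{
		BeatConsumer: NewBeatConsumer(quit, "example"),
		quit:         quit,
	}
}

// Name completes the Consumer interface that BeatConsumer intentionally
// leaves unfinished.
func (s *exampleSubsystem) Name() string { return "example" }

// mainLoop is the subsystem's main event loop.
func (s *exampleSubsystem) mainLoop() {
	for {
		select {
		case beat := <-s.BlockbeatChan:
			// Handle the new block, then send the result back to
			// unblock the dispatcher.
			err := s.handleBlock(beat)
			s.NotifyBlockProcessed(beat, err)

		case <-s.quit:
			return
		}
	}
}

// handleBlock stands in for the subsystem's real block handling logic.
func (s *exampleSubsystem) handleBlock(beat Blockbeat) error {
	beat.logger().Debugf("[example] handling height=%d", beat.Height())
	return nil
}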
package chainio
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/lnutils"
"golang.org/x/sync/errgroup"
)
// DefaultProcessBlockTimeout is the timeout value used when waiting for one
// consumer to finish processing the new block epoch.
var DefaultProcessBlockTimeout = 60 * time.Second
// ErrProcessBlockTimeout is the error returned when a consumer takes too long
// to process the block.
var ErrProcessBlockTimeout = errors.New("process block timeout")
// BlockbeatDispatcher is a service that handles dispatching new blocks to
// `lnd`'s subsystems. During startup, subsystems that are block-driven should
// implement the `Consumer` interface and register themselves via
// `RegisterQueue`. When two subsystems are independent of each other, they
// should be registered in different queues so blocks are notified concurrently.
// Otherwise, when living in the same queue, the subsystems are notified of the
// new blocks sequentially, which means it's critical to understand the
// relationship of these systems to properly handle the order.
type BlockbeatDispatcher struct {
wg sync.WaitGroup
// notifier is used to receive new block epochs.
notifier chainntnfs.ChainNotifier
// beat is the latest blockbeat received.
beat Blockbeat
// consumerQueues is a map of consumers that will receive blocks. Its
// key is a unique counter and its value is a queue of consumers. Each
// queue is notified concurrently, and consumers in the same queue are
// notified sequentially.
consumerQueues map[uint32][]Consumer
// counter is used to assign a unique id to each queue.
counter atomic.Uint32
// quit is used to signal the BlockbeatDispatcher to stop.
quit chan struct{}
// queryHeightChan is used to receive queries on the current height of
// the dispatcher.
queryHeightChan chan *query
}
// query is used to fetch the internal state of the dispatcher.
type query struct {
// respChan is used to send back the current height back to the caller.
//
// NOTE: This channel must be buffered.
respChan chan int32
}
// newQuery creates a query to be used to fetch the internal state of the
// dispatcher.
func newQuery() *query {
return &query{
respChan: make(chan int32, 1),
}
}
// NewBlockbeatDispatcher returns a new blockbeat dispatcher instance.
func NewBlockbeatDispatcher(n chainntnfs.ChainNotifier) *BlockbeatDispatcher {
return &BlockbeatDispatcher{
notifier: n,
quit: make(chan struct{}),
consumerQueues: make(map[uint32][]Consumer),
queryHeightChan: make(chan *query, 1),
}
}
// RegisterQueue takes a list of consumers and registers them in the same
// queue.
//
// NOTE: these consumers are notified sequentially.
func (b *BlockbeatDispatcher) RegisterQueue(consumers []Consumer) {
qid := b.counter.Add(1)
b.consumerQueues[qid] = append(b.consumerQueues[qid], consumers...)
clog.Infof("Registered queue=%d with %d blockbeat consumers", qid,
len(consumers))
for _, c := range consumers {
clog.Debugf("Consumer [%s] registered in queue %d", c.Name(),
qid)
}
}
// Start starts the blockbeat dispatcher - it registers a block notification
// and monitors and dispatches new blocks in a goroutine. It will refuse to
// start if there are no registered consumers.
func (b *BlockbeatDispatcher) Start() error {
// Make sure consumers are registered.
if len(b.consumerQueues) == 0 {
return fmt.Errorf("no consumers registered")
}
// Start listening to new block epochs. We should get a notification
// with the current best block immediately.
blockEpochs, err := b.notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return fmt.Errorf("register block epoch ntfn: %w", err)
}
clog.Infof("BlockbeatDispatcher is starting with %d consumer queues",
len(b.consumerQueues))
defer clog.Debug("BlockbeatDispatcher started")
b.wg.Add(1)
go b.dispatchBlocks(blockEpochs)
return nil
}
// Stop shuts down the blockbeat dispatcher.
func (b *BlockbeatDispatcher) Stop() {
clog.Info("BlockbeatDispatcher is stopping")
defer clog.Debug("BlockbeatDispatcher stopped")
// Signal the dispatchBlocks goroutine to stop.
close(b.quit)
b.wg.Wait()
}
func (b *BlockbeatDispatcher) log() btclog.Logger {
return b.beat.logger()
}
// dispatchBlocks listens for new block epochs and dispatches them to all the
// consumers. Each queue is notified concurrently, and the consumers in the
// same queue are notified sequentially.
//
// NOTE: Must be run as a goroutine.
func (b *BlockbeatDispatcher) dispatchBlocks(
blockEpochs *chainntnfs.BlockEpochEvent) {
defer b.wg.Done()
defer blockEpochs.Cancel()
for {
select {
case blockEpoch, ok := <-blockEpochs.Epochs:
if !ok {
clog.Debugf("Block epoch channel closed")
return
}
// Log a separator so it's easier to identify when a
// new block arrives for subsystems.
clog.Debugf("%v", lnutils.NewSeparatorClosure())
clog.Debugf("Received new block %v at height %d, "+
"notifying consumers...", blockEpoch.Hash,
blockEpoch.Height)
// Record the time it takes the consumer to process
// this block.
start := time.Now()
// Update the current block epoch.
b.beat = NewBeat(*blockEpoch)
// Notify all consumers.
err := b.notifyQueues()
if err != nil {
b.log().Errorf("Notify block failed: %v", err)
}
b.log().Debugf("Notified all consumers on new block "+
"in %v", time.Since(start))
// A query has been made to fetch the current height, so we now
// send back the height from the current beat.
case query := <-b.queryHeightChan:
// The beat may not be set yet, e.g., during startup a query
// may arrive before the first block epoch is received.
height := int32(0)
if b.beat != nil {
height = b.beat.Height()
}
query.respChan <- height
case <-b.quit:
b.log().Debugf("BlockbeatDispatcher quit signal " +
"received")
return
}
}
}
// CurrentHeight returns the current best height known to the dispatcher. 0 is
// returned if the dispatcher is shutting down.
func (b *BlockbeatDispatcher) CurrentHeight() int32 {
query := newQuery()
select {
case b.queryHeightChan <- query:
case <-b.quit:
clog.Debugf("BlockbeatDispatcher quit before query")
return 0
}
select {
case height := <-query.respChan:
clog.Debugf("Responded current height: %v", height)
return height
case <-b.quit:
clog.Debugf("BlockbeatDispatcher quit before response")
return 0
}
}
// notifyQueues notifies each queue concurrently about the latest block epoch.
func (b *BlockbeatDispatcher) notifyQueues() error {
// errChans is a map of channels that will be used to receive errors
// returned from notifying the consumers.
errChans := make(map[uint32]chan error, len(b.consumerQueues))
// Notify each queue in goroutines.
for qid, consumers := range b.consumerQueues {
b.log().Debugf("Notifying queue=%d with %d consumers", qid,
len(consumers))
// Create a signal chan.
errChan := make(chan error, 1)
errChans[qid] = errChan
// Notify each queue concurrently.
go func(qid uint32, c []Consumer, beat Blockbeat) {
// Notify each consumer in this queue sequentially.
errChan <- DispatchSequential(beat, c)
}(qid, consumers, b.beat)
}
// Wait for all consumers in each queue to finish.
for qid, errChan := range errChans {
select {
case err := <-errChan:
if err != nil {
return fmt.Errorf("queue=%d got err: %w", qid,
err)
}
b.log().Debugf("Notified queue=%d", qid)
case <-b.quit:
b.log().Debugf("BlockbeatDispatcher quit signal " +
"received, exit notifyQueues")
return nil
}
}
return nil
}
// DispatchSequential takes a list of consumers and notifies them about the
// new epoch sequentially. It requires each consumer to finish processing the
// block within the specified time, otherwise a timeout error is returned.
func DispatchSequential(b Blockbeat, consumers []Consumer) error {
for _, c := range consumers {
// Send the beat to the consumer.
err := notifyAndWait(b, c, DefaultProcessBlockTimeout)
if err != nil {
b.logger().Errorf("Failed to process block: %v", err)
return err
}
}
return nil
}
// DispatchConcurrent notifies each consumer concurrently about the blockbeat.
// It requires the consumer to finish processing the block within the specified
// time, otherwise a timeout error is returned.
func DispatchConcurrent(b Blockbeat, consumers []Consumer) error {
eg := &errgroup.Group{}
// Notify each consumer concurrently in its own goroutine.
for _, c := range consumers {
eg.Go(func() error {
// Send the beat to the consumer.
err := notifyAndWait(b, c, DefaultProcessBlockTimeout)
// Exit early if there's no error.
if err == nil {
return nil
}
b.logger().Errorf("Consumer=%v failed to process "+
"block: %v", c.Name(), err)
return err
})
}
// Wait for all consumers to finish.
if err := eg.Wait(); err != nil {
return err
}
return nil
}
// notifyAndWait sends the blockbeat to the specified consumer. It requires the
// consumer to finish processing the block within the specified time, otherwise
// a timeout error is returned.
func notifyAndWait(b Blockbeat, c Consumer, timeout time.Duration) error {
b.logger().Debugf("Waiting for consumer[%s] to process it", c.Name())
// Record the time it takes the consumer to process this block.
start := time.Now()
errChan := make(chan error, 1)
go func() {
errChan <- c.ProcessBlock(b)
}()
// We expect the consumer to finish processing this block within the
// given timeout, otherwise a timeout error is returned.
select {
case err := <-errChan:
if err == nil {
break
}
return fmt.Errorf("%s got err in ProcessBlock: %w", c.Name(),
err)
case <-time.After(timeout):
return fmt.Errorf("consumer %s: %w", c.Name(),
ErrProcessBlockTimeout)
}
b.logger().Debugf("Consumer[%s] processed block in %v", c.Name(),
time.Since(start))
return nil
}
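// A minimal sketch of wiring the dispatcher: consumers whose block handling
// must be ordered share one queue, while independent consumers would go in
// separate queues. The sweeper and chainArb values are hypothetical Consumer
// implementations, and the ordering constraint is assumed for illustration.
func exampleWireDispatcher(notifier chainntnfs.ChainNotifier,
	sweeper, chainArb Consumer) (*BlockbeatDispatcher, error) {

	dispatcher := NewBlockbeatDispatcher(notifier)

	// Suppose the sweeper must see each block before the chain
	// arbitrator: register both in a single queue, sweeper first, so they
	// are notified sequentially in that order.
	dispatcher.RegisterQueue([]Consumer{sweeper, chainArb})

	// Start must be called after at least one queue is registered, and
	// the dispatcher must be stopped via Stop on shutdown.
	if err := dispatcher.Start(); err != nil {
		return nil, fmt.Errorf("start dispatcher: %w", err)
	}

	return dispatcher, nil
}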
package chainio
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
// Subsystem defines the logging code for this subsystem.
const Subsystem = "CHIO"
// clog is a logger that is initialized with no output filters. This means the
// package will not perform any logging by default until the caller requests
// it.
var clog btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger(Subsystem, nil))
}
// DisableLog disables all library log output. Logging output is disabled by
// default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info. This
// should be used in preference to SetLogWriter if the caller is also using
// btclog.
func UseLogger(logger btclog.Logger) {
clog = logger
}
package chainio
import (
"github.com/btcsuite/btclog/v2"
"github.com/stretchr/testify/mock"
)
// MockConsumer is a mock implementation of the Consumer interface.
type MockConsumer struct {
mock.Mock
}
// Compile-time constraint to ensure MockConsumer implements Consumer.
var _ Consumer = (*MockConsumer)(nil)
// Name returns a human-readable string for this subsystem.
func (m *MockConsumer) Name() string {
args := m.Called()
return args.String(0)
}
// ProcessBlock takes a blockbeat and processes it, returning an error if
// processing fails.
func (m *MockConsumer) ProcessBlock(b Blockbeat) error {
args := m.Called(b)
return args.Error(0)
}
// MockBlockbeat is a mock implementation of the Blockbeat interface.
type MockBlockbeat struct {
mock.Mock
}
// Compile-time constraint to ensure MockBlockbeat implements Blockbeat.
var _ Blockbeat = (*MockBlockbeat)(nil)
// Height returns the current block height.
func (m *MockBlockbeat) Height() int32 {
args := m.Called()
return args.Get(0).(int32)
}
// logger returns the logger for the blockbeat.
func (m *MockBlockbeat) logger() btclog.Logger {
args := m.Called()
return args.Get(0).(btclog.Logger)
}
package chainntnfs
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/wire"
)
// BestBlockView is an interface that allows the querying of the most
// up-to-date blockchain state with low overhead. Valid implementations of this
// interface must track the latest chain state.
type BestBlockView interface {
// BestHeight gets the most recent block height known to the view.
BestHeight() (uint32, error)
// BestBlockHeader gets the most recent block header known to the view.
BestBlockHeader() (*wire.BlockHeader, error)
}
// BestBlockTracker is a tiny subsystem that tracks the blockchain tip
// and saves the most recent tip information in memory for querying. It is a
// valid implementation of BestBlockView and additionally includes
// methods for starting and stopping the system.
type BestBlockTracker struct {
notifier ChainNotifier
blockNtfnStream *BlockEpochEvent
current atomic.Pointer[BlockEpoch]
mu sync.Mutex
quit chan struct{}
wg sync.WaitGroup
}
// This is a compile time check to ensure that BestBlockTracker implements
// BestBlockView.
var _ BestBlockView = (*BestBlockTracker)(nil)
// NewBestBlockTracker creates a new BestBlockTracker that isn't running yet.
// It will not provide up-to-date information unless it has been started. The
// ChainNotifier parameter must also be started prior to starting the
// BestBlockTracker.
func NewBestBlockTracker(chainNotifier ChainNotifier) *BestBlockTracker {
return &BestBlockTracker{
notifier: chainNotifier,
blockNtfnStream: nil,
quit: make(chan struct{}),
}
}
// BestHeight gets the most recent block height known to the
// BestBlockTracker.
func (t *BestBlockTracker) BestHeight() (uint32, error) {
epoch := t.current.Load()
if epoch == nil {
return 0, errors.New("best block height not yet known")
}
return uint32(epoch.Height), nil
}
// BestBlockHeader gets the most recent block header known to the
// BestBlockTracker.
func (t *BestBlockTracker) BestBlockHeader() (*wire.BlockHeader, error) {
epoch := t.current.Load()
if epoch == nil {
return nil, errors.New("best block header not yet known")
}
return epoch.BlockHeader, nil
}
// updateLoop is a helper that subscribes to the underlying BlockEpochEvent
// stream and updates the internal values to match the new BlockEpochs that
// are discovered.
//
// MUST be run as a goroutine.
func (t *BestBlockTracker) updateLoop() {
defer t.wg.Done()
for {
select {
case epoch, ok := <-t.blockNtfnStream.Epochs:
if !ok {
Log.Error("dead epoch stream in " +
"BestBlockTracker")
return
}
t.current.Store(epoch)
case <-t.quit:
t.current.Store(nil)
return
}
}
}
// Start starts the BestBlockTracker. It is an error to start it if it
// is already started.
func (t *BestBlockTracker) Start() error {
t.mu.Lock()
defer t.mu.Unlock()
if t.blockNtfnStream != nil {
return fmt.Errorf("BestBlockTracker is already started")
}
var err error
t.blockNtfnStream, err = t.notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return err
}
t.wg.Add(1)
go t.updateLoop()
return nil
}
// Stop stops the BestBlockTracker. It is an error to stop it if it has
// not been started or if it has already been stopped.
func (t *BestBlockTracker) Stop() error {
t.mu.Lock()
defer t.mu.Unlock()
if t.blockNtfnStream == nil {
return fmt.Errorf("BestBlockTracker is not running")
}
close(t.quit)
t.wg.Wait()
t.blockNtfnStream.Cancel()
t.blockNtfnStream = nil
return nil
}
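// A minimal sketch of the tracker's lifecycle: start it once the notifier is
// running, query the cached tip with low overhead, and stop it on shutdown.
// The notifier argument is assumed to be an already-started ChainNotifier.
func exampleTrackTip(notifier ChainNotifier) error {
	tracker := NewBestBlockTracker(notifier)
	if err := tracker.Start(); err != nil {
		return err
	}
	defer func() {
		_ = tracker.Stop()
	}()

	// BestHeight errors until the first block epoch arrives, so callers
	// must tolerate this brief window right after startup.
	height, err := tracker.BestHeight()
	if err != nil {
		return err
	}
	Log.Infof("current tip height: %d", height)

	return nil
}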
package chainntnfs
import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"strings"
"sync"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/btcjson"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
)
var (
// ErrChainNotifierShuttingDown is returned when an operation on the
// notifier is attempted while it is already shutting down.
ErrChainNotifierShuttingDown = errors.New("chain notifier shutting down")
)
// TxConfStatus denotes the status of a transaction's lookup.
type TxConfStatus uint8
const (
// TxFoundMempool denotes that the transaction was found within the
// backend node's mempool.
TxFoundMempool TxConfStatus = iota
// TxFoundIndex denotes that the transaction was found within the
// backend node's txindex.
TxFoundIndex
// TxNotFoundIndex denotes that the transaction was not found within the
// backend node's txindex.
TxNotFoundIndex
// TxFoundManually denotes that the transaction was found within the
// chain by scanning for it manually.
TxFoundManually
// TxNotFoundManually denotes that the transaction was not found within
// the chain by scanning for it manually.
TxNotFoundManually
)
// String returns the string representation of the TxConfStatus.
func (t TxConfStatus) String() string {
switch t {
case TxFoundMempool:
return "TxFoundMempool"
case TxFoundIndex:
return "TxFoundIndex"
case TxNotFoundIndex:
return "TxNotFoundIndex"
case TxFoundManually:
return "TxFoundManually"
case TxNotFoundManually:
return "TxNotFoundManually"
default:
return "unknown"
}
}
// notifierOptions is a set of functional options that allow callers to further
// modify the type of chain event notifications they receive.
type notifierOptions struct {
// includeBlock if true, then the dispatched confirmation notification
// will include the block that mined the transaction.
includeBlock bool
}
// defaultNotifierOptions returns the set of default options for the notifier.
func defaultNotifierOptions() *notifierOptions {
return &notifierOptions{}
}
// NotifierOption is a functional option that allows a caller to modify the
// events received from the notifier.
type NotifierOption func(*notifierOptions)
// WithIncludeBlock is an optional argument that allows the caller to specify
// that the block that mined a transaction should be included in the response.
func WithIncludeBlock() NotifierOption {
return func(o *notifierOptions) {
o.includeBlock = true
}
}
// ChainNotifier represents a trusted source to receive notifications concerning
// targeted events on the Bitcoin blockchain. The interface specification is
// intentionally general in order to support a wide array of chain notification
// implementations such as: btcd's websockets notifications, Bitcoin Core's
// ZeroMQ notifications, various Bitcoin API services, Electrum servers, etc.
//
// Concrete implementations of ChainNotifier should be able to support multiple
// concurrent client requests, as well as multiple concurrent notification events.
type ChainNotifier interface {
// RegisterConfirmationsNtfn registers an intent to be notified once
// txid reaches numConfs confirmations. We also pass in the pkScript,
// since light clients need to match on the scripts created in the
// block rather than on the txid. If a nil txid is passed in, then not
// only should we match
// on the script, but we should also dispatch once the transaction
// containing the script reaches numConfs confirmations. This can be
// useful in instances where we only know the script in advance, but not
// the transaction containing it.
//
// The returned ConfirmationEvent should properly notify the client once
// the specified number of confirmations has been reached for the txid,
// as well as if the original tx gets re-org'd out of the mainchain. The
// heightHint parameter is provided as a convenience to light clients.
// It denotes the earliest height in the blockchain in which
// the target txid _could_ have been included in the chain. This can be
// used to bound the search space when checking to see if a notification
// can immediately be dispatched due to historical data.
//
// NOTE: Dispatching notifications to multiple clients subscribed to
// the same (txid, numConfs) tuple MUST be supported.
RegisterConfirmationsNtfn(txid *chainhash.Hash, pkScript []byte,
numConfs, heightHint uint32,
opts ...NotifierOption) (*ConfirmationEvent, error)
// RegisterSpendNtfn registers an intent to be notified once the target
// outpoint is successfully spent within a transaction. The script that
// the outpoint creates must also be specified. This allows this
// interface to be implemented by BIP 158-like filtering. If a nil
// outpoint is passed in, then not only should we match on the script,
// but we should also dispatch once a transaction spends the output
// containing said script. This can be useful in instances where we only
// know the script in advance, but not the outpoint itself.
//
// The returned SpendEvent will receive a send on the 'Spend'
// transaction once a transaction spending the input is detected on the
// blockchain. The heightHint parameter is provided as a convenience to
// light clients. It denotes the earliest height in the blockchain in
// which the target output could have been spent.
//
// NOTE: The notification should only be triggered when the spending
// transaction receives a single confirmation.
//
// NOTE: Dispatching notifications to multiple clients subscribed to a
// spend of the same outpoint MUST be supported.
RegisterSpendNtfn(outpoint *wire.OutPoint, pkScript []byte,
heightHint uint32) (*SpendEvent, error)
// RegisterBlockEpochNtfn registers an intent to be notified of each
// new block connected to the tip of the main chain. The returned
// BlockEpochEvent struct contains a channel which will be sent upon
// for each new block discovered.
//
// Clients have the option of passing in their best known block.
// If they specify a block, the ChainNotifier checks whether the client
// is behind on blocks. If they are, the ChainNotifier sends a backlog
// of block notifications for the missed blocks. If they do not provide
// one, then a notification will be dispatched immediately for the
// current tip of the chain upon a successful registration.
RegisterBlockEpochNtfn(*BlockEpoch) (*BlockEpochEvent, error)
// Start the ChainNotifier. Once started, the implementation should be
// ready and able to receive notification registrations from clients.
Start() error
// Started returns true if this instance has been started, and false otherwise.
Started() bool
// Stop shuts down the concrete ChainNotifier. Once stopped, the ChainNotifier
// should disallow any future requests from potential clients.
// Additionally, all pending client notifications will be canceled
// by closing the related channels on the *Event's.
Stop() error
}
// TxConfirmation carries some additional block-level details of the exact
// block that the specified transaction was confirmed in.
type TxConfirmation struct {
// BlockHash is the hash of the block that confirmed the original
// transaction.
BlockHash *chainhash.Hash
// BlockHeight is the height of the block in which the transaction was
// confirmed.
BlockHeight uint32
// TxIndex is the index of the confirmed transaction within the
// block.
TxIndex uint32
// Tx is the transaction for which the notification was requested.
Tx *wire.MsgTx
// Block is the block that contains the transaction referenced above.
//
// NOTE: This is only specified if the confirmation request opts to
// have the response include the block itself.
Block *wire.MsgBlock
}
// ConfirmationEvent encapsulates a confirmation notification. With this struct,
// callers can be notified of: the moment the target txid reaches the targeted
// number of confirmations, how many confirmations are left for the target txid
// to be fully confirmed at every new block height, and also in the event that
// the original txid becomes disconnected from the blockchain as a result of a
// re-org.
//
// Once the txid reaches the specified number of confirmations, the 'Confirmed'
// channel will be sent upon fulfilling the notification.
//
// In the event that the original transaction becomes re-org'd out of the main
// chain, the 'NegativeConf' will be sent upon with a value representing the
// depth of the re-org.
//
// NOTE: If the caller wishes to cancel their registered confirmation
// notification, the Cancel closure MUST be called.
type ConfirmationEvent struct {
// Confirmed is a channel that will be sent upon once the transaction
// has been fully confirmed. The struct sent will contain all the
// details of the channel's confirmation.
//
// NOTE: This channel must be buffered.
Confirmed chan *TxConfirmation
// Updates is a channel that will be sent upon, at every incremental
// confirmation, how many confirmations are left to declare the
// transaction as fully confirmed.
//
// NOTE: This channel must be buffered with the number of required
// confirmations.
Updates chan uint32
// NegativeConf is a channel that will be sent upon if the transaction
// confirms, but is later reorged out of the chain. The integer sent
// through the channel represents the reorg depth.
//
// NOTE: This channel must be buffered.
NegativeConf chan int32
// Done is a channel that gets sent upon once the confirmation request
// is no longer under the risk of being reorged out of the chain.
//
// NOTE: This channel must be buffered.
Done chan struct{}
// Cancel is a closure that should be executed by the caller in the case
// that they wish to prematurely abandon their registered confirmation
// notification.
Cancel func()
}
// NewConfirmationEvent constructs a new ConfirmationEvent with newly opened
// channels.
func NewConfirmationEvent(numConfs uint32, cancel func()) *ConfirmationEvent {
return &ConfirmationEvent{
// We cannot rely on the subscriber to immediately read from
// the channel so we need to create a larger buffer to avoid
// blocking the notifier.
Confirmed: make(chan *TxConfirmation, 1),
Updates: make(chan uint32, numConfs),
NegativeConf: make(chan int32, 1),
Done: make(chan struct{}, 1),
Cancel: cancel,
}
}
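// A minimal sketch of consuming a ConfirmationEvent: register for numConfs
// confirmations, then select across the channels described above until the
// transaction confirms or is reorged out. The txid, pkScript, numConfs and
// heightHint arguments are assumed to come from the caller.
func exampleWaitForConf(notifier ChainNotifier, txid *chainhash.Hash,
	pkScript []byte, numConfs, heightHint uint32) error {

	confEvent, err := notifier.RegisterConfirmationsNtfn(
		txid, pkScript, numConfs, heightHint,
	)
	if err != nil {
		return err
	}
	defer confEvent.Cancel()

	for {
		select {
		case conf := <-confEvent.Confirmed:
			Log.Infof("tx confirmed in block %v (height=%d)",
				conf.BlockHash, conf.BlockHeight)
			return nil

		case left := <-confEvent.Updates:
			Log.Debugf("%d confirmations remaining", left)

		case depth := <-confEvent.NegativeConf:
			return fmt.Errorf("tx reorged out at depth %d", depth)
		}
	}
}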
// SpendDetail contains details pertaining to a spent output. This struct itself
// is the spentness notification. It includes the original outpoint which triggered
// the notification, the hash of the transaction spending the output, the
// spending transaction itself, and finally the input index which spent the
// target output.
type SpendDetail struct {
SpentOutPoint *wire.OutPoint
SpenderTxHash *chainhash.Hash
SpendingTx *wire.MsgTx
SpenderInputIndex uint32
SpendingHeight int32
}
// HasSpenderWitness returns true if the spending transaction has non-empty
// witness.
func (s *SpendDetail) HasSpenderWitness() bool {
tx := s.SpendingTx
// If there are no inputs, then there is no witness.
if len(tx.TxIn) == 0 {
return false
}
// If the spender input index is out of range for the transaction's
// inputs, then we don't have a witness and this is an error case so we
// log it.
if uint32(len(tx.TxIn)) <= s.SpenderInputIndex {
Log.Errorf("SpenderInputIndex %d is out of range for tx %v",
s.SpenderInputIndex, tx.TxHash())
return false
}
// If the witness is empty, then there is no witness.
if len(tx.TxIn[s.SpenderInputIndex].Witness) == 0 {
return false
}
// If the witness is non-empty, then we have a witness.
return true
}
// String returns a string representation of SpendDetail.
func (s *SpendDetail) String() string {
return fmt.Sprintf("%v[%d] spending %v at height=%v", s.SpenderTxHash,
s.SpenderInputIndex, s.SpentOutPoint, s.SpendingHeight)
}
// SpendEvent encapsulates a spentness notification. Its 'Spend' channel will
// be sent upon once the target output passed into RegisterSpendNtfn has been
// spent on the blockchain.
//
// NOTE: If the caller wishes to cancel their registered spend notification,
// the Cancel closure MUST be called.
type SpendEvent struct {
// Spend is a receive only channel which will be sent upon once the
// target outpoint has been spent.
//
// NOTE: This channel must be buffered.
Spend chan *SpendDetail
// Reorg is a channel that will be sent upon once we detect the spending
// transaction of the outpoint in question has been reorged out of the
// chain.
//
// NOTE: This channel must be buffered.
Reorg chan struct{}
// Done is a channel that gets sent upon once the spend request is no
// longer under the risk of being reorged out of the chain.
//
// NOTE: This channel must be buffered.
Done chan struct{}
// Cancel is a closure that should be executed by the caller in the case
// that they wish to prematurely abandon their registered spend
// notification.
Cancel func()
}
// NewSpendEvent constructs a new SpendEvent with newly opened channels.
func NewSpendEvent(cancel func()) *SpendEvent {
return &SpendEvent{
Spend: make(chan *SpendDetail, 1),
Reorg: make(chan struct{}, 1),
Done: make(chan struct{}, 1),
Cancel: cancel,
}
}
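// A minimal sketch of consuming a SpendEvent, mirroring the confirmation
// example above: register for the spend of an outpoint and block until
// either the spend details arrive or the caller-supplied quit channel
// closes. The outpoint, pkScript, heightHint and quit values are assumed to
// come from the caller.
func exampleWaitForSpend(notifier ChainNotifier, op *wire.OutPoint,
	pkScript []byte, heightHint uint32, quit chan struct{}) error {

	spendEvent, err := notifier.RegisterSpendNtfn(op, pkScript, heightHint)
	if err != nil {
		return err
	}
	defer spendEvent.Cancel()

	select {
	case spend := <-spendEvent.Spend:
		Log.Infof("outpoint spent: %v", spend)
		return nil

	case <-quit:
		return ErrChainNotifierShuttingDown
	}
}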
// BlockEpoch represents metadata concerning each new block connected to the
// main chain.
type BlockEpoch struct {
// Hash is the block hash of the latest block to be added to the tip of
// the main chain.
Hash *chainhash.Hash
// Height is the height of the latest block to be added to the tip of
// the main chain.
Height int32
// BlockHeader is the block header of this new height.
BlockHeader *wire.BlockHeader
}
// BlockEpochEvent encapsulates an on-going stream of block epoch
// notifications. Its 'Epochs' channel will be sent upon for each new block
// connected to the main-chain.
//
// NOTE: If the caller wishes to cancel their registered block epoch
// notification, the Cancel closure MUST be called.
type BlockEpochEvent struct {
// Epochs is a receive only channel that will be sent upon each time a
// new block is connected to the end of the main chain.
//
// NOTE: This channel must be buffered.
Epochs <-chan *BlockEpoch
// Cancel is a closure that should be executed by the caller in the case
// that they wish to abandon their registered block epochs notification.
Cancel func()
}
// NotifierDriver represents a "driver" for a particular interface. A driver is
// identified by a globally unique string identifier along with a 'New()'
// method which is responsible for initializing a particular ChainNotifier
// concrete implementation.
type NotifierDriver struct {
// NotifierType is a string which uniquely identifies the ChainNotifier
// that this driver drives.
NotifierType string
// New creates a new instance of a concrete ChainNotifier
// implementation given a variadic set of arguments. The function takes
// a variadic number of interface parameters in order to provide
// initialization flexibility, thereby accommodating several potential
// ChainNotifier implementations.
New func(args ...interface{}) (ChainNotifier, error)
}
var (
notifiers = make(map[string]*NotifierDriver)
registerMtx sync.Mutex
)
// RegisteredNotifiers returns a slice of all currently registered notifiers.
//
// NOTE: This function is safe for concurrent access.
func RegisteredNotifiers() []*NotifierDriver {
registerMtx.Lock()
defer registerMtx.Unlock()
drivers := make([]*NotifierDriver, 0, len(notifiers))
for _, driver := range notifiers {
drivers = append(drivers, driver)
}
return drivers
}
// RegisterNotifier registers a NotifierDriver which is capable of driving a
// concrete ChainNotifier interface. In the case that this driver has already
// been registered, an error is returned.
//
// NOTE: This function is safe for concurrent access.
func RegisterNotifier(driver *NotifierDriver) error {
registerMtx.Lock()
defer registerMtx.Unlock()
if _, ok := notifiers[driver.NotifierType]; ok {
return fmt.Errorf("notifier already registered")
}
notifiers[driver.NotifierType] = driver
return nil
}
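// A minimal sketch of how a concrete implementation would register its
// driver, typically from an init function in its own package. The "example"
// notifier type and the stubbed constructor are hypothetical.
func exampleRegisterDriver() error {
	return RegisterNotifier(&NotifierDriver{
		NotifierType: "example",
		New: func(args ...interface{}) (ChainNotifier, error) {
			// A real driver would construct its concrete
			// ChainNotifier from args here.
			return nil, fmt.Errorf("example driver not " +
				"implemented")
		},
	})
}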
// SupportedNotifiers returns a slice of strings that represent the notifier
// drivers that have been registered and are therefore supported.
//
// NOTE: This function is safe for concurrent access.
func SupportedNotifiers() []string {
registerMtx.Lock()
defer registerMtx.Unlock()
supportedNotifiers := make([]string, 0, len(notifiers))
for driverName := range notifiers {
supportedNotifiers = append(supportedNotifiers, driverName)
}
return supportedNotifiers
}
// ChainConn enables notifiers to pass in their chain backend to interface
// functions that require it.
type ChainConn interface {
// GetBlockHeader returns the block header for a hash.
GetBlockHeader(blockHash *chainhash.Hash) (*wire.BlockHeader, error)
// GetBlockHeaderVerbose returns the verbose block header for a hash.
GetBlockHeaderVerbose(blockHash *chainhash.Hash) (
*btcjson.GetBlockHeaderVerboseResult, error)
// GetBlockHash returns the hash from a block height.
GetBlockHash(blockHeight int64) (*chainhash.Hash, error)
}
// GetCommonBlockAncestorHeight takes in:
// (1) the hash of a block that has been reorged out of the main chain
// (2) the hash of the block of the same height from the main chain
// It returns the height of the nearest common ancestor between the two
// hashes, or an error.
func GetCommonBlockAncestorHeight(chainConn ChainConn, reorgHash,
chainHash chainhash.Hash) (int32, error) {
for reorgHash != chainHash {
reorgHeader, err := chainConn.GetBlockHeader(&reorgHash)
if err != nil {
return 0, fmt.Errorf("unable to get header for "+
"hash=%v: %w", reorgHash, err)
}
chainHeader, err := chainConn.GetBlockHeader(&chainHash)
if err != nil {
return 0, fmt.Errorf("unable to get header for "+
"hash=%v: %w", chainHash, err)
}
reorgHash = reorgHeader.PrevBlock
chainHash = chainHeader.PrevBlock
}
verboseHeader, err := chainConn.GetBlockHeaderVerbose(&chainHash)
if err != nil {
return 0, fmt.Errorf("unable to get verbose header for "+
"hash=%v: %w", chainHash, err)
}
return verboseHeader.Height, nil
}
// GetClientMissedBlocks uses a client's best block to determine what blocks
// it missed being notified about, and returns them in a slice. Its
// backendStoresReorgs parameter tells it whether or not the notifier's
// chainConn stores information about blocks that have been reorged out of the
// chain, which allows GetClientMissedBlocks to find out whether the client's
// best block has been reorged out of the chain, rewind to the common ancestor
// and return blocks starting right after the common ancestor.
func GetClientMissedBlocks(chainConn ChainConn, clientBestBlock *BlockEpoch,
notifierBestHeight int32, backendStoresReorgs bool) ([]BlockEpoch, error) {
startingHeight := clientBestBlock.Height
if backendStoresReorgs {
// If a reorg causes the client's best hash to be incorrect,
// retrieve the closest common ancestor and dispatch
// notifications from there.
hashAtBestHeight, err := chainConn.GetBlockHash(
int64(clientBestBlock.Height))
if err != nil {
return nil, fmt.Errorf("unable to find blockhash for "+
"height=%d: %v", clientBestBlock.Height, err)
}
startingHeight, err = GetCommonBlockAncestorHeight(
chainConn, *clientBestBlock.Hash, *hashAtBestHeight,
)
if err != nil {
return nil, fmt.Errorf("unable to find common ancestor: "+
"%v", err)
}
}
// We want to start dispatching historical notifications from the block
// right after the client's best block, to avoid a redundant notification.
missedBlocks, err := getMissedBlocks(
chainConn, startingHeight+1, notifierBestHeight+1,
)
if err != nil {
return nil, fmt.Errorf("unable to get missed blocks: %w", err)
}
return missedBlocks, nil
}
// RewindChain handles internal state updates for the notifier's TxNotifier. It
// has no effect if given a height greater than or equal to our current best
// known height. It returns the new best block for the notifier.
func RewindChain(chainConn ChainConn, txNotifier *TxNotifier,
currBestBlock BlockEpoch, targetHeight int32) (BlockEpoch, error) {
newBestBlock := BlockEpoch{
Height: currBestBlock.Height,
Hash: currBestBlock.Hash,
BlockHeader: currBestBlock.BlockHeader,
}
for height := currBestBlock.Height; height > targetHeight; height-- {
hash, err := chainConn.GetBlockHash(int64(height - 1))
if err != nil {
return newBestBlock, fmt.Errorf("unable to "+
"find blockhash for disconnected height=%d: %v",
height, err)
}
header, err := chainConn.GetBlockHeader(hash)
if err != nil {
return newBestBlock, fmt.Errorf("unable to get block "+
"header for height=%v", height-1)
}
Log.Infof("Block disconnected from main chain: "+
"height=%v, sha=%v", height, newBestBlock.Hash)
err = txNotifier.DisconnectTip(uint32(height))
if err != nil {
return newBestBlock, fmt.Errorf("unable to "+
" disconnect tip for height=%d: %v",
height, err)
}
newBestBlock.Height = height - 1
newBestBlock.Hash = hash
newBestBlock.BlockHeader = header
}
return newBestBlock, nil
}
// HandleMissedBlocks is called when the chain backend for a notifier misses a
// series of blocks, handling a reorg if necessary. Its backendStoresReorgs
// parameter tells it whether or not the notifier's chainConn stores
// information about blocks that have been reorged out of the chain, which allows
// HandleMissedBlocks to check whether the notifier's best block has been
// reorged out, and rewind the chain accordingly. It returns the best block for
// the notifier and a slice of the missed blocks. The new best block needs to be
// returned in case a chain rewind occurs and partially completes before
// erroring. In the case where there is no rewind, the notifier's
// current best block is returned.
func HandleMissedBlocks(chainConn ChainConn, txNotifier *TxNotifier,
currBestBlock BlockEpoch, newHeight int32,
backendStoresReorgs bool) (BlockEpoch, []BlockEpoch, error) {
startingHeight := currBestBlock.Height
if backendStoresReorgs {
// If a reorg causes our best hash to be incorrect, rewind the
// chain so our best block is set to the closest common
// ancestor, then dispatch notifications from there.
hashAtBestHeight, err := chainConn.GetBlockHash(
int64(currBestBlock.Height),
)
if err != nil {
return currBestBlock, nil, fmt.Errorf("unable to find "+
"blockhash for height=%d: %v",
currBestBlock.Height, err)
}
startingHeight, err = GetCommonBlockAncestorHeight(
chainConn, *currBestBlock.Hash, *hashAtBestHeight,
)
if err != nil {
return currBestBlock, nil, fmt.Errorf("unable to find "+
"common ancestor: %v", err)
}
currBestBlock, err = RewindChain(
chainConn, txNotifier, currBestBlock, startingHeight,
)
if err != nil {
return currBestBlock, nil, fmt.Errorf("unable to "+
"rewind chain: %v", err)
}
}
// We want to start dispatching historical notifications from the block
// right after our best block, to avoid a redundant notification.
missedBlocks, err := getMissedBlocks(chainConn, startingHeight+1, newHeight)
if err != nil {
return currBestBlock, nil, fmt.Errorf("unable to get missed "+
"blocks: %v", err)
}
return currBestBlock, missedBlocks, nil
}
// getMissedBlocks returns a slice of blocks: [startingHeight, endingHeight)
// fetched from the chain.
func getMissedBlocks(chainConn ChainConn, startingHeight,
endingHeight int32) ([]BlockEpoch, error) {
numMissedBlocks := endingHeight - startingHeight
if numMissedBlocks < 0 {
return nil, fmt.Errorf("starting height %d is greater than "+
"ending height %d", startingHeight, endingHeight)
}
missedBlocks := make([]BlockEpoch, 0, numMissedBlocks)
for height := startingHeight; height < endingHeight; height++ {
hash, err := chainConn.GetBlockHash(int64(height))
if err != nil {
return nil, fmt.Errorf("unable to find blockhash for "+
"height=%d: %v", height, err)
}
header, err := chainConn.GetBlockHeader(hash)
if err != nil {
return nil, fmt.Errorf("unable to find block header "+
"for height=%d: %v", height, err)
}
missedBlocks = append(
missedBlocks,
BlockEpoch{
Hash: hash,
Height: height,
BlockHeader: header,
},
)
}
return missedBlocks, nil
}
// TxIndexConn abstracts an RPC backend with txindex enabled.
type TxIndexConn interface {
// GetRawTransactionVerbose returns the transaction identified by the
// passed chain hash, and returns additional information such as the
// block that the transaction confirmed.
GetRawTransactionVerbose(*chainhash.Hash) (*btcjson.TxRawResult, error)
// GetBlock returns the block identified by the chain hash.
GetBlock(*chainhash.Hash) (*wire.MsgBlock, error)
}
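// For reference, btcd's rpcclient.Client exposes both of these methods with
// matching signatures, so (assuming the rpcclient package is imported) a
// compile-time assertion along these lines would hold for a txindex-enabled
// btcd backend:
//
//	var _ TxIndexConn = (*rpcclient.Client)(nil)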
// ConfDetailsFromTxIndex looks up whether a transaction is already included in
// a block in the active chain by using the backend node's transaction index.
// If the transaction is found, its TxConfStatus is returned: TxFoundMempool if
// it was found in the mempool, TxFoundIndex if it was found in a block, and
// TxNotFoundIndex otherwise. If the tx is found in a block, its confirmation
// details are also returned.
func ConfDetailsFromTxIndex(chainConn TxIndexConn, r ConfRequest,
txNotFoundErr string) (*TxConfirmation, TxConfStatus, error) {
// If the transaction has some or all of its confirmations required,
// then we may be able to dispatch it immediately.
rawTxRes, err := chainConn.GetRawTransactionVerbose(&r.TxID)
if err != nil {
// If the lookup failed only because the transaction wasn't found
// within the index itself, then we can exit early. We'll also
// need to look at the error message returned, as the same error
// code is used for multiple errors.
jsonErr, ok := err.(*btcjson.RPCError)
if ok && jsonErr.Code == btcjson.ErrRPCNoTxInfo &&
strings.Contains(jsonErr.Message, txNotFoundErr) {
return nil, TxNotFoundIndex, nil
}
return nil, TxNotFoundIndex,
fmt.Errorf("unable to query for txid %v: %w",
r.TxID, err)
}
// Deserialize the hex-encoded transaction to include it in the
// confirmation details.
rawTx, err := hex.DecodeString(rawTxRes.Hex)
if err != nil {
return nil, TxNotFoundIndex,
fmt.Errorf("unable to deserialize tx %v: %w",
r.TxID, err)
}
var tx wire.MsgTx
if err := tx.Deserialize(bytes.NewReader(rawTx)); err != nil {
return nil, TxNotFoundIndex,
fmt.Errorf("unable to deserialize tx %v: %w",
r.TxID, err)
}
// Ensure the transaction matches our confirmation request in terms of
// txid and pkscript.
if !r.MatchesTx(&tx) {
return nil, TxNotFoundIndex,
fmt.Errorf("unable to locate tx %v", r.TxID)
}
// Make sure we actually retrieved a transaction that is included in a
// block. If not, the transaction must be unconfirmed (in the mempool),
// and we'll return TxFoundMempool together with a nil TxConfirmation.
if rawTxRes.BlockHash == "" {
return nil, TxFoundMempool, nil
}
// As we need to fully populate the returned TxConfirmation struct,
// grab the block in which the transaction was confirmed so we can
// locate its exact index within the block.
blockHash, err := chainhash.NewHashFromStr(rawTxRes.BlockHash)
if err != nil {
return nil, TxNotFoundIndex,
fmt.Errorf("unable to get block hash %v for "+
"historical dispatch: %w", rawTxRes.BlockHash,
err)
}
block, err := chainConn.GetBlock(blockHash)
if err != nil {
return nil, TxNotFoundIndex,
fmt.Errorf("unable to get block with hash %v for "+
"historical dispatch: %w", blockHash, err)
}
// In the modern chain (the only one we really care about for LN), the
// coinbase transaction of all blocks will include the block height.
// Therefore we can save another query, and just use that height
// directly.
blockHeight, err := blockchain.ExtractCoinbaseHeight(
btcutil.NewTx(block.Transactions[0]),
)
if err != nil {
return nil, TxNotFoundIndex, fmt.Errorf("unable to extract "+
"coinbase height: %w", err)
}
// If the block was obtained, locate the transaction's index within the
// block so we can give the subscriber full confirmation details.
for txIndex, blockTx := range block.Transactions {
if blockTx.TxHash() != r.TxID {
continue
}
return &TxConfirmation{
Tx: tx.Copy(),
BlockHash: blockHash,
BlockHeight: uint32(blockHeight),
TxIndex: uint32(txIndex),
Block: block,
}, TxFoundIndex, nil
}
// We return an error because we should have found the transaction
// within the block, but didn't.
return nil, TxNotFoundIndex, fmt.Errorf("unable to locate "+
"tx %v in block %v", r.TxID, blockHash)
}
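// A caller performing a historical dispatch against a txindex-enabled backend
// might use the result like this (a minimal sketch; conn and req are
// assumptions, and txNotFoundErr is the backend-specific "tx not found"
// message fragment):
//
//	conf, status, err := ConfDetailsFromTxIndex(conn, req, txNotFoundErr)
//	if err != nil {
//		return err
//	}
//	switch status {
//	case TxFoundIndex:
//		// conf is fully populated; dispatch it immediately.
//	case TxFoundMempool:
//		// Unconfirmed; keep watching for inclusion at tip.
//	case TxNotFoundIndex:
//		// Fall back to a manual rescan of the relevant blocks.
//	}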
// SpendHintCache is an interface whose duty is to cache spend hints for
// outpoints. A spend hint is defined as the earliest height in the chain at
// which an outpoint could have been spent within.
type SpendHintCache interface {
// CommitSpendHint commits a spend hint for the outpoints to the cache.
CommitSpendHint(height uint32, spendRequests ...SpendRequest) error
// QuerySpendHint returns the latest spend hint for an outpoint.
// ErrSpendHintNotFound is returned if a spend hint does not exist
// within the cache for the outpoint.
QuerySpendHint(spendRequest SpendRequest) (uint32, error)
// PurgeSpendHint removes the spend hint for the outpoints from the
// cache.
PurgeSpendHint(spendRequests ...SpendRequest) error
}
// ConfirmHintCache is an interface whose duty is to cache confirm hints for
// transactions. A confirm hint is defined as the earliest height in the chain
// at which a transaction could have been included in a block.
type ConfirmHintCache interface {
// CommitConfirmHint commits a confirm hint for the transactions to the
// cache.
CommitConfirmHint(height uint32, confRequests ...ConfRequest) error
// QueryConfirmHint returns the latest confirm hint for a transaction
// hash. ErrConfirmHintNotFound is returned if a confirm hint does not
// exist within the cache for the transaction hash.
QueryConfirmHint(confRequest ConfRequest) (uint32, error)
// PurgeConfirmHint removes the confirm hint for the transactions from
// the cache.
PurgeConfirmHint(confRequests ...ConfRequest) error
}
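// A minimal in-memory ConfirmHintCache (illustrative only; production
// implementations persist hints so they survive restarts) could look like:
//
//	type memConfirmHintCache struct {
//		mu    sync.Mutex
//		hints map[ConfRequest]uint32
//	}
//
//	func (c *memConfirmHintCache) CommitConfirmHint(height uint32,
//		reqs ...ConfRequest) error {
//
//		c.mu.Lock()
//		defer c.mu.Unlock()
//		for _, req := range reqs {
//			c.hints[req] = height
//		}
//		return nil
//	}
//
//	func (c *memConfirmHintCache) QueryConfirmHint(
//		req ConfRequest) (uint32, error) {
//
//		c.mu.Lock()
//		defer c.mu.Unlock()
//		hint, ok := c.hints[req]
//		if !ok {
//			return 0, ErrConfirmHintNotFound
//		}
//		return hint, nil
//	}
//
//	func (c *memConfirmHintCache) PurgeConfirmHint(
//		reqs ...ConfRequest) error {
//
//		c.mu.Lock()
//		defer c.mu.Unlock()
//		for _, req := range reqs {
//			delete(c.hints, req)
//		}
//		return nil
//	}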
// MempoolWatcher defines an interface that allows the caller to query
// information in the mempool.
type MempoolWatcher interface {
// SubscribeMempoolSpent allows the caller to register a subscription
// to watch for a spend of an outpoint in the mempool. The event will
// be dispatched once the outpoint is spent in the mempool.
SubscribeMempoolSpent(op wire.OutPoint) (*MempoolSpendEvent, error)
// CancelMempoolSpendEvent allows the caller to cancel a subscription to
// watch for a spend of an outpoint in the mempool.
CancelMempoolSpendEvent(sub *MempoolSpendEvent)
// LookupInputMempoolSpend looks up the mempool to find a spending tx
// which spends the given outpoint. A fn.None is returned if it's not
// found.
LookupInputMempoolSpend(op wire.OutPoint) fn.Option[wire.MsgTx]
}
package chainntnfs
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
// Log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var Log btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("NTFN", nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
Log = logger
}
package chainntnfs
import (
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnutils"
)
// inputsWithTx is a map of outpoints to the tx that spends them.
type inputsWithTx map[wire.OutPoint]*SpendDetail
// MempoolNotifier defines an internal mempool notifier that's used to notify
// the spending of given inputs. This is mounted to either `BitcoindNotifier`
// or `BtcdNotifier` depending on the chain backend.
type MempoolNotifier struct {
wg sync.WaitGroup
// subscribedInputs stores the inputs that we want to watch their
// spending event for.
subscribedInputs *lnutils.SyncMap[wire.OutPoint,
*lnutils.SyncMap[uint64, *MempoolSpendEvent]]
// sCounter is used to generate unique subscription IDs.
sCounter atomic.Uint64
// quit is closed when the notifier is torn down.
quit chan struct{}
}
// MempoolSpendEvent is returned to the subscriber to watch for the spending
// event for the requested input.
type MempoolSpendEvent struct {
// Spend is a receive-only channel that will be sent upon once the
// target outpoint has been spent.
//
// NOTE: This channel must be buffered.
Spend <-chan *SpendDetail
// id is the unique identifier of this subscription.
id uint64
// outpoint is the subscribed outpoint.
outpoint wire.OutPoint
// event is the channel that will be sent upon once the target outpoint
// is spent.
event chan *SpendDetail
// cancel cancels the subscription.
cancel chan struct{}
}
// newMempoolSpendEvent returns a new instance of MempoolSpendEvent.
func newMempoolSpendEvent(id uint64, op wire.OutPoint) *MempoolSpendEvent {
sub := &MempoolSpendEvent{
id: id,
outpoint: op,
event: make(chan *SpendDetail, 1),
cancel: make(chan struct{}),
}
// Mount the receive only channel to the event channel.
sub.Spend = (<-chan *SpendDetail)(sub.event)
return sub
}
// NewMempoolNotifier returns a new mempool notifier.
func NewMempoolNotifier() *MempoolNotifier {
return &MempoolNotifier{
subscribedInputs: &lnutils.SyncMap[
wire.OutPoint, *lnutils.SyncMap[
uint64, *MempoolSpendEvent,
]]{},
quit: make(chan struct{}),
}
}
// SubscribeInput takes an outpoint of the input and returns a channel that the
// subscriber can listen to for the spending event.
func (m *MempoolNotifier) SubscribeInput(
outpoint wire.OutPoint) *MempoolSpendEvent {
// Get the current subscribers for this input or create a new one.
clients := &lnutils.SyncMap[uint64, *MempoolSpendEvent]{}
clients, _ = m.subscribedInputs.LoadOrStore(outpoint, clients)
// Increment the subscription counter and return the new value.
subscriptionID := m.sCounter.Add(1)
// Create a new subscription.
sub := newMempoolSpendEvent(subscriptionID, outpoint)
// Add the subscriber with a unique id.
clients.Store(subscriptionID, sub)
// Update the subscribed inputs.
m.subscribedInputs.Store(outpoint, clients)
Log.Debugf("Subscribed(id=%v) mempool event for input=%s",
subscriptionID, outpoint)
return sub
}
// UnsubscribeInput removes all the subscriptions for the given outpoint.
func (m *MempoolNotifier) UnsubscribeInput(outpoint wire.OutPoint) {
Log.Debugf("Unsubscribing MempoolSpendEvent for input %s", outpoint)
m.subscribedInputs.Delete(outpoint)
}
// UnsubscribeEvent removes a given subscriber for the given MempoolSpendEvent.
func (m *MempoolNotifier) UnsubscribeEvent(sub *MempoolSpendEvent) {
Log.Debugf("Unsubscribing(id=%v) MempoolSpendEvent for input=%s",
sub.id, sub.outpoint)
// Load all the subscribers for this input.
clients, loaded := m.subscribedInputs.Load(sub.outpoint)
if !loaded {
Log.Debugf("No subscribers for input %s", sub.outpoint)
return
}
// Load the subscriber.
subscriber, loaded := clients.Load(sub.id)
if !loaded {
Log.Debugf("No subscribers for input %s with id %v",
sub.outpoint, sub.id)
return
}
// Close the cancel channel in case it's been used in a goroutine.
close(subscriber.cancel)
// Remove the subscriber.
clients.Delete(sub.id)
}
// UnsubsribeConfirmedSpentTx takes a transaction and removes the subscriptions
// identified using its inputs.
func (m *MempoolNotifier) UnsubsribeConfirmedSpentTx(tx *btcutil.Tx) {
Log.Tracef("Unsubscribe confirmed tx %s", tx.Hash())
// Get the spent inputs of interest.
spentInputs, err := m.findRelevantInputs(tx)
if err != nil {
Log.Errorf("Unable to find relevant inputs for tx %s: %v",
tx.Hash(), err)
return
}
// Unsubscribe the subscribers.
for outpoint := range spentInputs {
m.UnsubscribeInput(outpoint)
}
Log.Tracef("Finished unsubscribing confirmed tx %s, found %d inputs",
tx.Hash(), len(spentInputs))
}
// ProcessRelevantSpendTx takes a transaction and checks whether it spends any
// of the subscribed inputs. If so, spend notifications are sent to the
// relevant subscribers.
func (m *MempoolNotifier) ProcessRelevantSpendTx(tx *btcutil.Tx) error {
Log.Tracef("Processing mempool tx %s", tx.Hash())
defer Log.Tracef("Finished processing mempool tx %s", tx.Hash())
// Get the spent inputs of interest.
spentInputs, err := m.findRelevantInputs(tx)
if err != nil {
return err
}
// Notify the subscribers.
m.notifySpent(spentInputs)
return nil
}
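// Putting the pieces together, a chain backend that streams mempool
// transactions might drive the notifier roughly like this (a minimal sketch;
// notifier, op, tx, and quit are assumptions):
//
//	sub := notifier.SubscribeInput(op)
//
//	// Feed each newly seen mempool transaction through the notifier.
//	if err := notifier.ProcessRelevantSpendTx(tx); err != nil {
//		return err
//	}
//
//	// The subscriber is notified once op is spent.
//	select {
//	case detail := <-sub.Spend:
//		Log.Debugf("input spent by %v", detail.SpenderTxHash)
//	case <-quit:
//	}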
// TearDown stops the notifier and cleans up resources.
func (m *MempoolNotifier) TearDown() {
Log.Infof("Stopping mempool notifier")
defer Log.Debug("mempool notifier stopped")
close(m.quit)
m.wg.Wait()
}
// findRelevantInputs takes a transaction to find the subscribed inputs and
// returns them.
func (m *MempoolNotifier) findRelevantInputs(tx *btcutil.Tx) (inputsWithTx,
error) {
txid := tx.Hash()
watchedInputs := make(inputsWithTx)
// NOTE: we may have found multiple targeted inputs in the same tx.
for i, input := range tx.MsgTx().TxIn {
op := &input.PreviousOutPoint
// Check whether this input is subscribed.
_, loaded := m.subscribedInputs.Load(*op)
if !loaded {
continue
}
// If found, save it to watchedInputs to notify the
// subscriber later.
Log.Debugf("Found input %s, spent in %s", op, txid)
// Construct the spend details.
details := &SpendDetail{
SpentOutPoint: op,
SpenderTxHash: txid,
SpendingTx: tx.MsgTx().Copy(),
SpenderInputIndex: uint32(i),
SpendingHeight: 0,
}
watchedInputs[*op] = details
// Sanity check the witness stack. If it's not empty, continue
// to next iteration.
if details.HasSpenderWitness() {
continue
}
// Return an error if the witness data is not present in the
// spending transaction.
Log.Criticalf("Found spending tx for outpoint=%v in mempool, "+
"but the transaction %v does not have witness",
op, details.SpendingTx.TxHash())
return nil, ErrEmptyWitnessStack
}
return watchedInputs, nil
}
// notifySpent iterates all the spentInputs and notifies the subscribers about
// the spent details.
func (m *MempoolNotifier) notifySpent(spentInputs inputsWithTx) {
// notifySingle sends a notification to a single subscriber about the
// spending event.
//
// NOTE: must be used inside a goroutine.
notifySingle := func(id uint64, sub *MempoolSpendEvent,
op wire.OutPoint, detail *SpendDetail) {
defer m.wg.Done()
// Send the spend details to the subscriber.
select {
case sub.event <- detail:
Log.Debugf("Notified(id=%v) mempool spent for input %s",
sub.id, op)
case <-sub.cancel:
Log.Debugf("Subscription(id=%v) canceled, skipped "+
"notifying spent for input %s", sub.id, op)
case <-m.quit:
Log.Debugf("Mempool notifier quit, skipped notifying "+
"mempool spent for input %s", op)
}
}
// notifyAll is a helper closure that constructs a spend detail and
// sends it to all the subscribers of that particular input.
//
// NOTE: must be used inside a goroutine.
notifyAll := func(detail *SpendDetail, op wire.OutPoint) {
defer m.wg.Done()
txid := detail.SpendingTx.TxHash()
Log.Debugf("Notifying all clients for the spend of %s in tx %s",
op, txid)
// Load the subscriber.
subs, loaded := m.subscribedInputs.Load(op)
if !loaded {
Log.Errorf("Sub not found for %s", op)
return
}
// Iterate all the subscribers for this input and notify them.
subs.ForEach(func(id uint64, sub *MempoolSpendEvent) error {
m.wg.Add(1)
go notifySingle(id, sub, op, detail)
return nil
})
}
// Iterate the spent inputs to notify the subscribers concurrently.
for op, tx := range spentInputs {
op, tx := op, tx
m.wg.Add(1)
go notifyAll(tx, op)
}
}
package chainntnfs
import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/stretchr/testify/mock"
)
// MockMempoolWatcher is a mock implementation of the MempoolWatcher interface.
// This is used by other subsystems to mock the behavior of the mempool
// watcher.
type MockMempoolWatcher struct {
mock.Mock
}
// NewMockMempoolWatcher returns a new instance of a mock mempool watcher.
func NewMockMempoolWatcher() *MockMempoolWatcher {
return &MockMempoolWatcher{}
}
// Compile-time check to ensure MockMempoolWatcher implements MempoolWatcher.
var _ MempoolWatcher = (*MockMempoolWatcher)(nil)
// SubscribeMempoolSpent implements the MempoolWatcher interface.
func (m *MockMempoolWatcher) SubscribeMempoolSpent(
op wire.OutPoint) (*MempoolSpendEvent, error) {
args := m.Called(op)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*MempoolSpendEvent), args.Error(1)
}
// CancelMempoolSpendEvent implements the MempoolWatcher interface.
func (m *MockMempoolWatcher) CancelMempoolSpendEvent(
sub *MempoolSpendEvent) {
m.Called(sub)
}
// LookupInputMempoolSpend looks up the mempool to find a spending tx which
// spends the given outpoint.
func (m *MockMempoolWatcher) LookupInputMempoolSpend(
op wire.OutPoint) fn.Option[wire.MsgTx] {
args := m.Called(op)
return args.Get(0).(fn.Option[wire.MsgTx])
}
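// In a unit test, the mock watcher might be primed like this (illustrative
// only, using testify's expectation API):
//
//	watcher := NewMockMempoolWatcher()
//	watcher.On("LookupInputMempoolSpend", op).Return(
//		fn.None[wire.MsgTx](),
//	)
//
//	require.True(t, watcher.LookupInputMempoolSpend(op).IsNone())
//	watcher.AssertExpectations(t)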
// MockChainNotifier is a mock implementation of the ChainNotifier interface.
type MockChainNotifier struct {
mock.Mock
}
// Compile-time check to ensure MockChainNotifier implements ChainNotifier.
var _ ChainNotifier = (*MockChainNotifier)(nil)
// RegisterConfirmationsNtfn registers an intent to be notified once txid
// reaches numConfs confirmations.
func (m *MockChainNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash,
pkScript []byte, numConfs, heightHint uint32,
opts ...NotifierOption) (*ConfirmationEvent, error) {
args := m.Called(txid, pkScript, numConfs, heightHint)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*ConfirmationEvent), args.Error(1)
}
// RegisterSpendNtfn registers an intent to be notified once the target
// outpoint is successfully spent within a transaction.
func (m *MockChainNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
pkScript []byte, heightHint uint32) (*SpendEvent, error) {
args := m.Called(outpoint, pkScript, heightHint)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*SpendEvent), args.Error(1)
}
// RegisterBlockEpochNtfn registers an intent to be notified of each new block
// connected to the tip of the main chain.
func (m *MockChainNotifier) RegisterBlockEpochNtfn(epoch *BlockEpoch) (
*BlockEpochEvent, error) {
args := m.Called(epoch)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*BlockEpochEvent), args.Error(1)
}
// Start the ChainNotifier. Once started, the implementation should be ready
// and able to receive notification registrations from clients.
func (m *MockChainNotifier) Start() error {
args := m.Called()
return args.Error(0)
}
// Started returns true if this instance has been started, and false otherwise.
func (m *MockChainNotifier) Started() bool {
args := m.Called()
return args.Bool(0)
}
// Stop stops the concrete ChainNotifier.
func (m *MockChainNotifier) Stop() error {
args := m.Called()
return args.Error(0)
}
package chainntnfs
import (
"bytes"
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
)
const (
// ReorgSafetyLimit is the chain depth beyond which it is assumed a
// block will not be reorganized out of the chain. This is used to
// determine when to prune old confirmation requests so that reorgs are
// handled correctly. The average number of blocks in a day is a
// reasonable value to use.
ReorgSafetyLimit = 144
// MaxNumConfs is the maximum number of confirmations that can be
// requested on a transaction.
MaxNumConfs = ReorgSafetyLimit
)
var (
// ZeroHash is the value that should be used as the txid when
// registering for the confirmation of a script on-chain. This allows
// the notifier to match _and_ dispatch upon the inclusion of the script
// on-chain, rather than the txid.
ZeroHash chainhash.Hash
// ZeroOutPoint is the value that should be used as the outpoint when
// registering for the spend of a script on-chain. This allows the
// notifier to match _and_ dispatch upon detecting the spend of the
// script on-chain, rather than the outpoint.
ZeroOutPoint wire.OutPoint
// zeroV1KeyPush is a pkScript that pushes an all-zero 32-byte Taproot
// SegWit v1 key to the stack.
zeroV1KeyPush = [34]byte{
txscript.OP_1, txscript.OP_DATA_32, // 32 zero bytes follow.
}
// ZeroTaprootPkScript is the parsed txscript.PkScript of an empty
// Taproot SegWit v1 key being pushed to the stack. This allows the
// notifier to match _and_ dispatch upon detecting the spend of the
// outpoint on-chain, rather than the pkScript (which cannot be derived
// from the witness alone in the SegWit v1 case).
ZeroTaprootPkScript, _ = txscript.ParsePkScript(zeroV1KeyPush[:])
)
var (
// ErrTxNotifierExiting is an error returned when attempting to interact
// with the TxNotifier but it has already been shut down.
ErrTxNotifierExiting = errors.New("TxNotifier is exiting")
// ErrNoScript is an error returned when a confirmation/spend
// registration is attempted without providing an accompanying output
// script.
ErrNoScript = errors.New("an output script must be provided")
// ErrNoHeightHint is an error returned when a confirmation/spend
// registration is attempted without providing an accompanying height
// hint.
ErrNoHeightHint = errors.New("a height hint greater than 0 must be " +
"provided")
// ErrNumConfsOutOfRange is an error returned when a confirmation/spend
// registration is attempted and the number of confirmations provided is
// out of range.
ErrNumConfsOutOfRange = fmt.Errorf("number of confirmations must be "+
"between %d and %d", 1, MaxNumConfs)
// ErrEmptyWitnessStack is returned when a spending transaction has an
// empty witness stack. More details in,
// - https://github.com/bitcoin/bitcoin/issues/28730
ErrEmptyWitnessStack = errors.New("witness stack is empty")
)
// rescanState indicates the progression of a registration before the notifier
// can begin dispatching confirmations at tip.
type rescanState byte
const (
// rescanNotStarted is the initial state, denoting that a historical
// dispatch may be required.
rescanNotStarted rescanState = iota
// rescanPending indicates that a dispatch has already been made, and we
// are waiting for its completion. No other rescans should be dispatched
// while in this state.
rescanPending
// rescanComplete signals either that a rescan was dispatched and has
// completed, or that we began watching at tip immediately. In either
// case, the notifier can only dispatch notifications from tip when in
// this state.
rescanComplete
)
// confNtfnSet holds all known, registered confirmation notifications for a
// txid/output script. If duplicate notifications are requested, only one
// historical dispatch will be spawned to ensure redundant scans are not
// permitted. A single conf detail will be constructed and dispatched to all
// interested clients.
type confNtfnSet struct {
// ntfns keeps track of all the active client notification requests
// for a transaction/output script.
ntfns map[uint64]*ConfNtfn
// rescanStatus represents the current rescan state for the
// transaction/output script.
rescanStatus rescanState
// details serves as a cache of the confirmation details of a
// transaction that we'll use to determine if a transaction/output
// script has already confirmed at the time of registration.
// details is also used to make sure that in case of an address reuse
// (funds sent to a previously confirmed script) no additional
// notification is registered which would lead to an inconsistent state.
details *TxConfirmation
}
// newConfNtfnSet constructs a fresh confNtfnSet for a group of clients
// interested in a notification for a particular txid.
func newConfNtfnSet() *confNtfnSet {
return &confNtfnSet{
ntfns: make(map[uint64]*ConfNtfn),
rescanStatus: rescanNotStarted,
}
}
// spendNtfnSet holds all known, registered spend notifications for a spend
// request (outpoint/output script). If duplicate notifications are requested,
// only one historical dispatch will be spawned to ensure redundant scans are
// not permitted.
type spendNtfnSet struct {
// ntfns keeps track of all the active client notification requests
// for an outpoint/output script.
ntfns map[uint64]*SpendNtfn
// rescanStatus represents the current rescan state for the spend
// request (outpoint/output script).
rescanStatus rescanState
// details serves as a cache of the spend details for an outpoint/output
// script that we'll use to determine if it has already been spent at
// the time of registration.
details *SpendDetail
}
// newSpendNtfnSet constructs a new spend notification set.
func newSpendNtfnSet() *spendNtfnSet {
return &spendNtfnSet{
ntfns: make(map[uint64]*SpendNtfn),
rescanStatus: rescanNotStarted,
}
}
// ConfRequest encapsulates a request for a confirmation notification of either
// a txid or output script.
type ConfRequest struct {
// TxID is the hash of the transaction for which confirmation
// notifications are requested. If set to a zero hash, then a
// confirmation notification will be dispatched upon inclusion of the
// _script_, rather than the txid.
TxID chainhash.Hash
// PkScript is the public key script of an outpoint created in this
// transaction.
PkScript txscript.PkScript
}
// NewConfRequest creates a request for a confirmation notification of either a
// txid or output script. A nil txid or an allocated ZeroHash can be used to
// dispatch the confirmation notification on the script.
func NewConfRequest(txid *chainhash.Hash, pkScript []byte) (ConfRequest, error) {
var r ConfRequest
outputScript, err := txscript.ParsePkScript(pkScript)
if err != nil {
return r, err
}
// We'll only set a txid for which we'll dispatch a confirmation
// notification on this request if one was provided. Otherwise, we'll
// default to dispatching on the confirmation of the script instead.
if txid != nil {
r.TxID = *txid
}
r.PkScript = outputScript
return r, nil
}
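// For example (a minimal sketch; txid and pkScript are assumptions), the two
// dispatch modes are selected purely by the txid argument:
//
//	// Dispatch on confirmation of this specific transaction.
//	txidReq, err := NewConfRequest(&txid, pkScript)
//
//	// Dispatch on confirmation of any transaction paying to pkScript.
//	scriptReq, err := NewConfRequest(nil, pkScript)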
// String returns the string representation of the ConfRequest.
func (r ConfRequest) String() string {
if r.TxID != ZeroHash {
return fmt.Sprintf("txid=%v", r.TxID)
}
return fmt.Sprintf("script=%v", r.PkScript)
}
// MatchesTx determines whether the given transaction satisfies the confirmation
// request. If the confirmation request is for a script, then we'll check all of
// the outputs of the transaction to determine if it matches. Otherwise, we'll
// match on the txid.
func (r ConfRequest) MatchesTx(tx *wire.MsgTx) bool {
scriptMatches := func() bool {
pkScript := r.PkScript.Script()
for _, txOut := range tx.TxOut {
if bytes.Equal(txOut.PkScript, pkScript) {
return true
}
}
return false
}
if r.TxID != ZeroHash {
return r.TxID == tx.TxHash() && scriptMatches()
}
return scriptMatches()
}
// ConfNtfn represents a notifier client's request to receive a notification
// once the target transaction/output script gets sufficient confirmations. The
// client is asynchronously notified via the ConfirmationEvent channels.
type ConfNtfn struct {
// ConfID uniquely identifies the confirmation notification request for
// the specified transaction/output script.
ConfID uint64
// ConfRequest represents either the txid or script we should detect
// inclusion of within the chain.
ConfRequest
// NumConfirmations is the number of confirmations after which the
// notification is to be sent.
NumConfirmations uint32
// Event contains references to the channels that the notifications are
// to be sent over.
Event *ConfirmationEvent
// HeightHint is the minimum height in the chain that we expect to find
// this txid.
HeightHint uint32
// dispatched is false if the confirmed notification has not been sent
// yet.
dispatched bool
// includeBlock is true if the dispatched notification should also have
// the block included with it.
includeBlock bool
// numConfsLeft is the number of confirmations left to be sent to the
// subscriber.
numConfsLeft uint32
}
// HistoricalConfDispatch parametrizes a manual rescan for a particular
// transaction/output script. The parameters include the start and end block
// heights specifying the range of blocks to scan.
type HistoricalConfDispatch struct {
// ConfRequest represents either the txid or script we should detect
// inclusion of within the chain.
ConfRequest
// StartHeight specifies the block height at which to begin the
// historical rescan.
StartHeight uint32
// EndHeight specifies the last block height (inclusive) that the
// historical scan should consider.
EndHeight uint32
}
// ConfRegistration encompasses all of the information required for callers to
// retrieve details about a confirmation event.
type ConfRegistration struct {
// Event contains references to the channels that the notifications are
// to be sent over.
Event *ConfirmationEvent
// HistoricalDispatch, if non-nil, signals to the client who registered
// the notification that they are responsible for attempting to manually
// rescan blocks for the txid/output script between the start and end
// heights.
HistoricalDispatch *HistoricalConfDispatch
// Height is the height of the TxNotifier at the time the confirmation
// notification was registered. This can be used so that backends can
// request to be notified of confirmations from this point forwards.
Height uint32
}
// SpendRequest encapsulates a request for a spend notification of either an
// outpoint or output script.
type SpendRequest struct {
// OutPoint is the outpoint for which a client has requested a spend
// notification for. If set to a zero outpoint, then a spend
// notification will be dispatched upon detecting the spend of the
// _script_, rather than the outpoint.
OutPoint wire.OutPoint
// PkScript is the script of the outpoint. If a zero outpoint is set,
// then this can be an arbitrary script.
PkScript txscript.PkScript
}
// NewSpendRequest creates a request for a spend notification of either an
// outpoint or output script. A nil outpoint or an allocated ZeroOutPoint can
// be used to dispatch the spend notification on the script.
func NewSpendRequest(op *wire.OutPoint, pkScript []byte) (SpendRequest, error) {
var r SpendRequest
outputScript, err := txscript.ParsePkScript(pkScript)
if err != nil {
return r, err
}
// We'll only set an outpoint for which we'll dispatch a spend
// notification on this request if one was provided. Otherwise, we'll
// default to dispatching on the spend of the script instead.
if op != nil {
r.OutPoint = *op
}
r.PkScript = outputScript
// For Taproot spends, the main problem is that for the key spend path
// we cannot derive the pkScript by looking only at the input's
// witness. So we need to rely on the outpoint information instead.
//
// TODO(guggero): For script path spends we can derive the pkScript from
// the witness, since we have the full control block and the spent
// script available.
if outputScript.Class() == txscript.WitnessV1TaprootTy {
if op == nil {
return r, fmt.Errorf("cannot register witness v1 " +
"spend request without outpoint")
}
// We have an outpoint, so we can set the pkScript to an all
// zero Taproot key that we'll compare this spend request to.
r.PkScript = ZeroTaprootPkScript
}
return r, nil
}
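// For example (a sketch; op and taprootPkScript are assumptions), a Taproot
// spend registration must be outpoint-based, while other script types may be
// watched by script alone:
//
//	// Valid: watch the concrete outpoint.
//	req, err := NewSpendRequest(&op, taprootPkScript)
//
//	// Invalid: a witness v1 script cannot be recomputed from spending
//	// witnesses, so registering without an outpoint returns an error.
//	_, err = NewSpendRequest(nil, taprootPkScript)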
// String returns the string representation of the SpendRequest.
func (r SpendRequest) String() string {
if r.OutPoint != ZeroOutPoint {
return fmt.Sprintf("outpoint=%v, script=%v", r.OutPoint,
r.PkScript)
}
return fmt.Sprintf("outpoint=<zero>, script=%v", r.PkScript)
}
// MatchesTx determines whether the given transaction satisfies the spend
// request. If the spend request is for an outpoint, then we'll check all of
// the outputs being spent by the inputs of the transaction to determine if it
// matches. Otherwise, we'll need to match on the output script being spent, so
// we'll recompute it for each input of the transaction to determine if it
// matches.
func (r SpendRequest) MatchesTx(tx *wire.MsgTx) (bool, uint32, error) {
if r.OutPoint != ZeroOutPoint {
for i, txIn := range tx.TxIn {
if txIn.PreviousOutPoint == r.OutPoint {
return true, uint32(i), nil
}
}
return false, 0, nil
}
for i, txIn := range tx.TxIn {
pkScript, err := txscript.ComputePkScript(
txIn.SignatureScript, txIn.Witness,
)
if err == txscript.ErrUnsupportedScriptType {
continue
}
if err != nil {
return false, 0, err
}
if bytes.Equal(pkScript.Script(), r.PkScript.Script()) {
return true, uint32(i), nil
}
}
return false, 0, nil
}
// SpendNtfn represents a client's request to receive a notification once an
// outpoint/output script has been spent on-chain. The client is asynchronously
// notified via the SpendEvent channels.
type SpendNtfn struct {
// SpendID uniquely identifies the spend notification request for the
// specified outpoint/output script.
SpendID uint64
// SpendRequest represents either the outpoint or script we should
// detect the spend of.
SpendRequest
// Event contains references to the channels that the notifications are
// to be sent over.
Event *SpendEvent
// HeightHint is the earliest height in the chain that we expect to find
// the spending transaction of the specified outpoint/output script.
// This value will be overridden by the spend hint cache if it contains
// an entry for it.
HeightHint uint32
// dispatched signals whether a spend notification has been dispatched
// to the client.
dispatched bool
}
// HistoricalSpendDispatch parametrizes a manual rescan to determine the
// spending details (if any) of an outpoint/output script. The parameters
// include the start and end block heights specifying the range of blocks to
// scan.
type HistoricalSpendDispatch struct {
// SpendRequest represents either the outpoint or script we should
// detect the spend of.
SpendRequest
// StartHeight specifies the block height at which to begin the
// historical rescan.
StartHeight uint32
// EndHeight specifies the last block height (inclusive) that the
// historical rescan should consider.
EndHeight uint32
}
// SpendRegistration encompasses all of the information required for callers to
// retrieve details about a spend event.
type SpendRegistration struct {
// Event contains references to the channels that the notifications are
// to be sent over.
Event *SpendEvent
// HistoricalDispatch, if non-nil, signals to the client who registered
// the notification that they are responsible for attempting to manually
// rescan blocks for the txid/output script between the start and end
// heights.
HistoricalDispatch *HistoricalSpendDispatch
// Height is the height of the TxNotifier at the time the spend
// notification was registered. This can be used so that backends can
// request to be notified of spends from this point forwards.
Height uint32
}
// TxNotifier is a struct responsible for delivering transaction notifications
// to subscribers. These notifications can be of two different types:
// transaction/output script confirmations and/or outpoint/output script spends.
// The TxNotifier will watch the blockchain as new blocks come in, in order to
// satisfy its client requests.
type TxNotifier struct {
confClientCounter uint64 // To be used atomically.
spendClientCounter uint64 // To be used atomically.
// currentHeight is the height of the tracked blockchain. It is used to
// determine the number of confirmations a tx has and ensure blocks are
// connected and disconnected in order.
currentHeight uint32
// reorgSafetyLimit is the chain depth beyond which it is assumed a
// block will not be reorganized out of the chain. This is used to
// determine when to prune old notification requests so that reorgs are
// handled correctly. The coinbase maturity period is a reasonable value
// to use.
reorgSafetyLimit uint32
// reorgDepth is the depth of a chain organization that this system is
// being informed of. This is incremented as long as a sequence of
// blocks are disconnected without being interrupted by a new block.
reorgDepth uint32
// confNotifications is an index of confirmation notification requests
// by transaction hash/output script.
confNotifications map[ConfRequest]*confNtfnSet
// confsByInitialHeight is an index of watched transactions/output
// scripts by the height at which they were included in the chain.
// This is tracked so that incorrect notifications are not sent if a
// transaction/output script is reorged out of the chain and so that
// negative confirmations can be recognized.
confsByInitialHeight map[uint32]map[ConfRequest]struct{}
// ntfnsByConfirmHeight is an index of notification requests by the
// height at which the transaction/output script will have sufficient
// confirmations.
ntfnsByConfirmHeight map[uint32]map[*ConfNtfn]struct{}
// spendNotifications is an index of all active notification requests
// per outpoint/output script.
spendNotifications map[SpendRequest]*spendNtfnSet
// spendsByHeight is an index that keeps track of the spending height
// of outpoints/output scripts we are currently tracking notifications
// for. This is used in order to recover from spending transactions
// being reorged out of the chain.
spendsByHeight map[uint32]map[SpendRequest]struct{}
// confirmHintCache is a cache used to maintain the latest height hints
// for transactions/output scripts. Each height hint represents the
// earliest height at which the scripts could have been confirmed
// within the chain.
confirmHintCache ConfirmHintCache
// spendHintCache is a cache used to maintain the latest height hints
// for outpoints/output scripts. Each height hint represents the
// earliest height at which they could have been spent within the chain.
spendHintCache SpendHintCache
// quit is closed in order to signal that the notifier is gracefully
// exiting.
quit chan struct{}
sync.Mutex
}
// NewTxNotifier creates a TxNotifier. The current height of the blockchain is
// accepted as a parameter. The different hint caches (confirm and spend) are
// used as an optimization in order to retrieve a better starting point when
// dispatching a rescan for a historical event in the chain.
func NewTxNotifier(startHeight uint32, reorgSafetyLimit uint32,
confirmHintCache ConfirmHintCache,
spendHintCache SpendHintCache) *TxNotifier {
return &TxNotifier{
currentHeight: startHeight,
reorgSafetyLimit: reorgSafetyLimit,
confNotifications: make(map[ConfRequest]*confNtfnSet),
confsByInitialHeight: make(map[uint32]map[ConfRequest]struct{}),
ntfnsByConfirmHeight: make(map[uint32]map[*ConfNtfn]struct{}),
spendNotifications: make(map[SpendRequest]*spendNtfnSet),
spendsByHeight: make(map[uint32]map[SpendRequest]struct{}),
confirmHintCache: confirmHintCache,
spendHintCache: spendHintCache,
quit: make(chan struct{}),
}
}
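// A backend typically constructs its notifier once it knows its best height
// (a minimal sketch; bestHeight, confCache, and spendCache are assumptions):
//
//	txNotifier := NewTxNotifier(
//		bestHeight, ReorgSafetyLimit, confCache, spendCache,
//	)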
// newConfNtfn validates all of the parameters required to successfully create
// and register a confirmation notification.
func (n *TxNotifier) newConfNtfn(txid *chainhash.Hash,
pkScript []byte, numConfs, heightHint uint32,
opts *notifierOptions) (*ConfNtfn, error) {
// An accompanying output script must always be provided.
if len(pkScript) == 0 {
return nil, ErrNoScript
}
// Enforce that we will not dispatch confirmations beyond the reorg
// safety limit.
if numConfs == 0 || numConfs > n.reorgSafetyLimit {
return nil, ErrNumConfsOutOfRange
}
// A height hint must be provided to prevent scanning from the genesis
// block.
if heightHint == 0 {
return nil, ErrNoHeightHint
}
// Ensure the output script is of a supported type.
confRequest, err := NewConfRequest(txid, pkScript)
if err != nil {
return nil, err
}
confID := atomic.AddUint64(&n.confClientCounter, 1)
return &ConfNtfn{
ConfID: confID,
ConfRequest: confRequest,
NumConfirmations: numConfs,
Event: NewConfirmationEvent(numConfs, func() {
n.CancelConf(confRequest, confID)
}),
HeightHint: heightHint,
includeBlock: opts.includeBlock,
numConfsLeft: numConfs,
}, nil
}
// RegisterConf handles a new confirmation notification request. The client will
// be notified when the transaction/output script gets a sufficient number of
// confirmations in the blockchain.
//
// NOTE: If the transaction/output script has already been included in a block
// on the chain, the confirmation details must be provided with the
// UpdateConfDetails method, otherwise we will wait for the transaction/output
// script to confirm even though it already has.
func (n *TxNotifier) RegisterConf(txid *chainhash.Hash, pkScript []byte,
numConfs, heightHint uint32,
optFuncs ...NotifierOption) (*ConfRegistration, error) {
select {
case <-n.quit:
return nil, ErrTxNotifierExiting
default:
}
opts := defaultNotifierOptions()
for _, optFunc := range optFuncs {
optFunc(opts)
}
// We'll start by performing a series of validation checks.
ntfn, err := n.newConfNtfn(txid, pkScript, numConfs, heightHint, opts)
if err != nil {
return nil, err
}
// Before proceeding to register the notification, we'll query our
// height hint cache to determine whether a better one exists.
//
// TODO(conner): verify that all submitted height hints are identical.
startHeight := ntfn.HeightHint
hint, err := n.confirmHintCache.QueryConfirmHint(ntfn.ConfRequest)
if err == nil {
if hint > startHeight {
Log.Debugf("Using height hint %d retrieved from cache "+
"for %v instead of %d for conf subscription",
hint, ntfn.ConfRequest, startHeight)
startHeight = hint
}
} else if err != ErrConfirmHintNotFound {
Log.Errorf("Unable to query confirm hint for %v: %v",
ntfn.ConfRequest, err)
}
Log.Infof("New confirmation subscription: conf_id=%d, %v, "+
"num_confs=%v height_hint=%d", ntfn.ConfID, ntfn.ConfRequest,
numConfs, startHeight)
n.Lock()
defer n.Unlock()
confSet, ok := n.confNotifications[ntfn.ConfRequest]
if !ok {
// If this is the first registration for this request, construct
// a confSet to coalesce all notifications for the same request.
confSet = newConfNtfnSet()
n.confNotifications[ntfn.ConfRequest] = confSet
}
confSet.ntfns[ntfn.ConfID] = ntfn
switch confSet.rescanStatus {
// A prior rescan has already completed and we are actively watching at
// tip for this request.
case rescanComplete:
// If the confirmation details for this set of notifications have
// already been found, we'll attempt to deliver them immediately
// to this client.
Log.Debugf("Attempting to dispatch confirmation for %v on "+
"registration since rescan has finished, conf_id=%v",
ntfn.ConfRequest, ntfn.ConfID)
// The default notification we assigned above includes the
// block along with the rest of the details. However, not all
// clients want the block, so we make a copy here w/o the block
// if needed so we can give clients only what they ask for.
confDetails := confSet.details
if !ntfn.includeBlock && confDetails != nil {
confDetailsCopy := *confDetails
confDetailsCopy.Block = nil
confDetails = &confDetailsCopy
}
// Deliver the details to every subscriber in the conf set that
// this ntfn belongs to.
for _, subscriber := range confSet.ntfns {
err := n.dispatchConfDetails(subscriber, confDetails)
if err != nil {
return nil, err
}
}
return &ConfRegistration{
Event: ntfn.Event,
HistoricalDispatch: nil,
Height: n.currentHeight,
}, nil
// A rescan is already in progress, return here to prevent dispatching
// another. When the rescan returns, this notification's details will be
// updated as well.
case rescanPending:
Log.Debugf("Waiting for pending rescan to finish before "+
"notifying %v at tip", ntfn.ConfRequest)
return &ConfRegistration{
Event: ntfn.Event,
HistoricalDispatch: nil,
Height: n.currentHeight,
}, nil
// If no rescan has been dispatched, attempt to do so now.
case rescanNotStarted:
}
// If the provided or cached height hint indicates that the
// transaction with the given txid/output script is to be confirmed at a
// height greater than the notifier's current height, we'll refrain from
// spawning a historical dispatch.
if startHeight > n.currentHeight {
Log.Debugf("Height hint is above current height, not "+
"dispatching historical confirmation rescan for %v",
ntfn.ConfRequest)
// Set the rescan status to complete, which will allow the
// notifier to start delivering messages for this set
// immediately.
confSet.rescanStatus = rescanComplete
return &ConfRegistration{
Event: ntfn.Event,
HistoricalDispatch: nil,
Height: n.currentHeight,
}, nil
}
Log.Debugf("Dispatching historical confirmation rescan for %v",
ntfn.ConfRequest)
// Construct the parameters for historical dispatch, scanning the range
// of blocks between our best known height hint and the notifier's
// current height. The notifier will also begin watching for
// confirmations at tip starting with the next block.
dispatch := &HistoricalConfDispatch{
ConfRequest: ntfn.ConfRequest,
StartHeight: startHeight,
EndHeight: n.currentHeight,
}
// Set this confSet's status to pending, ensuring subsequent
// registrations don't also attempt a dispatch.
confSet.rescanStatus = rescanPending
return &ConfRegistration{
Event: ntfn.Event,
HistoricalDispatch: dispatch,
Height: n.currentHeight,
}, nil
}
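// The returned registration tells the caller whether it must perform the
// historical rescan itself; if so, the results are handed back through
// UpdateConfDetails (a minimal sketch; scanForConf stands in for a
// backend-specific rescan of the given height range):
//
//	reg, err := txNotifier.RegisterConf(&txid, pkScript, 6, heightHint)
//	if err != nil {
//		return err
//	}
//	if dispatch := reg.HistoricalDispatch; dispatch != nil {
//		// Scan [StartHeight, EndHeight] with the backend, then hand
//		// any found details (possibly nil) back to the notifier.
//		details, err := scanForConf(dispatch)
//		if err != nil {
//			return err
//		}
//		err = txNotifier.UpdateConfDetails(
//			dispatch.ConfRequest, details,
//		)
//		if err != nil {
//			return err
//		}
//	}
//
//	conf := <-reg.Event.Confirmed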
// CancelConf cancels an existing request for a confirmation notification of a
// txid/output script. The request is identified by its confirmation ID.
func (n *TxNotifier) CancelConf(confRequest ConfRequest, confID uint64) {
select {
case <-n.quit:
return
default:
}
n.Lock()
defer n.Unlock()
confSet, ok := n.confNotifications[confRequest]
if !ok {
return
}
ntfn, ok := confSet.ntfns[confID]
if !ok {
return
}
Log.Debugf("Canceling confirmation notification: conf_id=%d, %v",
confID, confRequest)
// We'll close all the notification channels to let the client know
// their cancel request has been fulfilled.
close(ntfn.Event.Confirmed)
close(ntfn.Event.Updates)
close(ntfn.Event.NegativeConf)
// Finally, we'll clean up any lingering references to this
// notification.
delete(confSet.ntfns, confID)
// Remove the queued confirmation notification if the transaction has
// already confirmed, but hasn't met its required number of
// confirmations.
if confSet.details != nil {
confHeight := confSet.details.BlockHeight +
ntfn.NumConfirmations - 1
delete(n.ntfnsByConfirmHeight[confHeight], ntfn)
}
}
// UpdateConfDetails attempts to update the confirmation details for an active
// notification within the notifier. This should only be used in the case of a
// transaction/output script that has confirmed before the notifier's current
// height.
//
// NOTE: The notification should be registered first to ensure notifications are
// dispatched correctly.
func (n *TxNotifier) UpdateConfDetails(confRequest ConfRequest,
details *TxConfirmation) error {
select {
case <-n.quit:
return ErrTxNotifierExiting
default:
}
// Ensure we hold the lock throughout handling the notification to
// prevent the notifier from advancing its height underneath us.
n.Lock()
defer n.Unlock()
// First, we'll determine whether we have an active confirmation
// notification for the given txid/script.
confSet, ok := n.confNotifications[confRequest]
if !ok {
return fmt.Errorf("confirmation notification for %v not found",
confRequest)
}
// If the confirmation details were already found at tip, all existing
// notifications will have been dispatched or queued for dispatch. We
// can exit early to avoid sending too many notifications on the
// buffered channels.
if confSet.details != nil {
return nil
}
// The historical dispatch has been completed for this confSet. We'll
// update the rescan status and cache any details that were found. If
// the details are nil, that implies we did not find them and will
// continue to watch for them at tip.
confSet.rescanStatus = rescanComplete
// The notifier has yet to reach the height at which the
// transaction/output script was included in a block, so we'll defer
// handling it until then within ConnectTip.
if details == nil {
Log.Debugf("Confirmation details for %v not found during "+
"historical dispatch, waiting to dispatch at tip",
confRequest)
// We'll commit the current height as the confirm hint to
// prevent another potentially long rescan if we restart before
// a new block comes in.
err := n.confirmHintCache.CommitConfirmHint(
n.currentHeight, confRequest,
)
if err != nil {
// The error is not fatal as this is an optimistic
// optimization, so we'll avoid returning an error.
Log.Debugf("Unable to update confirm hint to %d for "+
"%v: %v", n.currentHeight, confRequest, err)
}
return nil
}
if details.BlockHeight > n.currentHeight {
Log.Debugf("Confirmation details for %v found above current "+
"height, waiting to dispatch at tip", confRequest)
return nil
}
Log.Debugf("Updating confirmation details for %v", confRequest)
err := n.confirmHintCache.CommitConfirmHint(
details.BlockHeight, confRequest,
)
if err != nil {
// The error is not fatal, so we should not return an error to
// the caller.
Log.Errorf("Unable to update confirm hint to %d for %v: %v",
details.BlockHeight, confRequest, err)
}
// Cache the details found in the rescan and attempt to dispatch any
// notifications that have not yet been delivered.
confSet.details = details
for _, ntfn := range confSet.ntfns {
// The default notification we assigned above includes the
// block along with the rest of the details. However, not all
// clients want the block, so we make a copy here w/o the block
// if needed so we can give clients only what they ask for.
confDetails := *details
if !ntfn.includeBlock {
confDetails.Block = nil
}
err = n.dispatchConfDetails(ntfn, &confDetails)
if err != nil {
return err
}
}
return nil
}
// dispatchConfDetails attempts to cache and dispatch details to a particular
// client if the transaction/output script has sufficiently confirmed. If the
// provided details are nil, this method will be a no-op.
func (n *TxNotifier) dispatchConfDetails(
ntfn *ConfNtfn, details *TxConfirmation) error {
// If there are no conf details to dispatch or if the notification has
// already been dispatched, then we can skip dispatching to this
// client.
if details == nil {
Log.Debugf("Skipped dispatching nil conf details for request "+
"%v, conf_id=%v", ntfn.ConfRequest, ntfn.ConfID)
return nil
}
if ntfn.dispatched {
Log.Debugf("Skipped dispatched conf details for request %v "+
"conf_id=%v", ntfn.ConfRequest, ntfn.ConfID)
return nil
}
// Now, we'll examine whether the transaction/output script of this
// request has reached its required number of confirmations. If it has,
// we'll dispatch a confirmation notification to the caller.
confHeight := details.BlockHeight + ntfn.NumConfirmations - 1
if confHeight <= n.currentHeight {
Log.Debugf("Dispatching %v confirmation notification for "+
"conf_id=%v, %v", ntfn.NumConfirmations, ntfn.ConfID,
ntfn.ConfRequest)
// We'll send a 0 value to the Updates channel,
// indicating that the transaction/output script has already
// been confirmed.
err := n.notifyNumConfsLeft(ntfn, 0)
if err != nil {
return err
}
select {
case ntfn.Event.Confirmed <- details:
ntfn.dispatched = true
case <-n.quit:
return ErrTxNotifierExiting
}
} else {
Log.Debugf("Queueing %v confirmation notification for %v at "+
"tip", ntfn.NumConfirmations, ntfn.ConfRequest)
// Otherwise, we'll keep track of the notification
// request by the height at which we should dispatch the
// confirmation notification.
ntfnSet, exists := n.ntfnsByConfirmHeight[confHeight]
if !exists {
ntfnSet = make(map[*ConfNtfn]struct{})
n.ntfnsByConfirmHeight[confHeight] = ntfnSet
}
ntfnSet[ntfn] = struct{}{}
// We'll also send an update to the client of how many
// confirmations are left for the transaction/output script to
// be confirmed.
numConfsLeft := confHeight - n.currentHeight
err := n.notifyNumConfsLeft(ntfn, numConfsLeft)
if err != nil {
return err
}
}
// As a final check, we'll also watch the transaction/output script if
// it's still possible for it to get reorged out of the chain.
reorgSafeHeight := details.BlockHeight + n.reorgSafetyLimit
if reorgSafeHeight > n.currentHeight {
txSet, exists := n.confsByInitialHeight[details.BlockHeight]
if !exists {
txSet = make(map[ConfRequest]struct{})
n.confsByInitialHeight[details.BlockHeight] = txSet
}
txSet[ntfn.ConfRequest] = struct{}{}
}
return nil
}
// newSpendNtfn validates all of the parameters required to successfully create
// and register a spend notification.
func (n *TxNotifier) newSpendNtfn(outpoint *wire.OutPoint,
pkScript []byte, heightHint uint32) (*SpendNtfn, error) {
// An accompanying output script must always be provided.
if len(pkScript) == 0 {
return nil, ErrNoScript
}
// A height hint must be provided to prevent scanning from the genesis
// block.
if heightHint == 0 {
return nil, ErrNoHeightHint
}
// Ensure the output script is of a supported type.
spendRequest, err := NewSpendRequest(outpoint, pkScript)
if err != nil {
return nil, err
}
spendID := atomic.AddUint64(&n.spendClientCounter, 1)
return &SpendNtfn{
SpendID: spendID,
SpendRequest: spendRequest,
Event: NewSpendEvent(func() {
n.CancelSpend(spendRequest, spendID)
}),
HeightHint: heightHint,
}, nil
}
// RegisterSpend handles a new spend notification request. The client will be
// notified once the outpoint/output script is detected as spent within the
// chain.
//
// NOTE: If the outpoint/output script has already been spent within the chain
// before the notifier's current tip, the spend details must be provided with
// the UpdateSpendDetails method, otherwise we will wait for the outpoint/output
// script to be spent at tip, even though it already has.
func (n *TxNotifier) RegisterSpend(outpoint *wire.OutPoint, pkScript []byte,
heightHint uint32) (*SpendRegistration, error) {
select {
case <-n.quit:
return nil, ErrTxNotifierExiting
default:
}
// We'll start by performing a series of validation checks.
ntfn, err := n.newSpendNtfn(outpoint, pkScript, heightHint)
if err != nil {
return nil, err
}
// Before proceeding to register the notification, we'll query our spend
// hint cache to determine whether a better one exists.
startHeight := ntfn.HeightHint
hint, err := n.spendHintCache.QuerySpendHint(ntfn.SpendRequest)
if err == nil {
if hint > startHeight {
Log.Debugf("Using height hint %d retrieved from cache "+
"for %v instead of %d for spend subscription",
hint, ntfn.SpendRequest, startHeight)
startHeight = hint
}
} else if err != ErrSpendHintNotFound {
Log.Errorf("Unable to query spend hint for %v: %v",
ntfn.SpendRequest, err)
}
n.Lock()
defer n.Unlock()
Log.Debugf("New spend subscription: spend_id=%d, %v, height_hint=%d",
ntfn.SpendID, ntfn.SpendRequest, startHeight)
// Keep track of the notification request so that we can properly
// dispatch a spend notification later on.
spendSet, ok := n.spendNotifications[ntfn.SpendRequest]
if !ok {
// If this is the first registration for the request, we'll
// construct a spendNtfnSet to coalesce all notifications.
spendSet = newSpendNtfnSet()
n.spendNotifications[ntfn.SpendRequest] = spendSet
}
spendSet.ntfns[ntfn.SpendID] = ntfn
// We'll now let the caller know whether a historical rescan is needed
// depending on the current rescan status.
switch spendSet.rescanStatus {
// If the spending details for this request have already been determined
// and cached, then we can use them to immediately dispatch the spend
// notification to the client.
case rescanComplete:
Log.Debugf("Attempting to dispatch spend for %v on "+
"registration since rescan has finished",
ntfn.SpendRequest)
err := n.dispatchSpendDetails(ntfn, spendSet.details)
if err != nil {
return nil, err
}
return &SpendRegistration{
Event: ntfn.Event,
HistoricalDispatch: nil,
Height: n.currentHeight,
}, nil
// If there is an active rescan to determine whether the request has
// been spent, then we won't trigger another one.
case rescanPending:
Log.Debugf("Waiting for pending rescan to finish before "+
"notifying %v at tip", ntfn.SpendRequest)
return &SpendRegistration{
Event: ntfn.Event,
HistoricalDispatch: nil,
Height: n.currentHeight,
}, nil
// Otherwise, we'll fall through and let the caller know that a rescan
// should be dispatched to determine whether the request has already
// been spent.
case rescanNotStarted:
}
// However, if the spend hint, either provided by the caller or
// retrieved from the cache, is found to be at a later height than the
// TxNotifier is aware of, then we'll refrain from dispatching a
// historical rescan and wait for the spend to come in at tip.
if startHeight > n.currentHeight {
Log.Debugf("Spend hint of %d for %v is above current height %d",
startHeight, ntfn.SpendRequest, n.currentHeight)
// We'll also set the rescan status as complete to ensure that
// spend hints for this request get updated upon
// connected/disconnected blocks.
spendSet.rescanStatus = rescanComplete
return &SpendRegistration{
Event: ntfn.Event,
HistoricalDispatch: nil,
Height: n.currentHeight,
}, nil
}
// We'll set the rescan status to pending to ensure subsequent
// notifications don't also attempt a historical dispatch.
spendSet.rescanStatus = rescanPending
Log.Debugf("Dispatching historical spend rescan for %v, start=%d, "+
"end=%d", ntfn.SpendRequest, startHeight, n.currentHeight)
return &SpendRegistration{
Event: ntfn.Event,
HistoricalDispatch: &HistoricalSpendDispatch{
SpendRequest: ntfn.SpendRequest,
StartHeight: startHeight,
EndHeight: n.currentHeight,
},
Height: n.currentHeight,
}, nil
}
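// The following is an illustrative sketch (not part of this package) of the
// caller-side contract described in RegisterSpend's NOTE: if a
// HistoricalDispatch is returned, the caller performs the rescan and reports
// its outcome through UpdateSpendDetails. scanChain is a hypothetical
// backend-specific rescan helper.
//
//	reg, err := n.RegisterSpend(&outpoint, pkScript, heightHint)
//	if err != nil {
//		return err
//	}
//	if reg.HistoricalDispatch != nil {
//		// Scan between StartHeight and EndHeight for the spend.
//		details, err := scanChain(reg.HistoricalDispatch)
//		if err != nil {
//			return err
//		}
//		// details may be nil if no spend was found; UpdateSpendDetails
//		// must be called either way.
//		err = n.UpdateSpendDetails(
//			reg.HistoricalDispatch.SpendRequest, details,
//		)
//		if err != nil {
//			return err
//		}
//	}
//	select {
//	case spendDetails := <-reg.Event.Spend:
//		// The outpoint/output script was spent.
//		_ = spendDetails
//	case <-reg.Event.Reorg:
//		// A previously dispatched spend was reorged out.
//	}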
// CancelSpend cancels an existing request for a spend notification of an
// outpoint/output script. The request is identified by its spend ID.
func (n *TxNotifier) CancelSpend(spendRequest SpendRequest, spendID uint64) {
select {
case <-n.quit:
return
default:
}
n.Lock()
defer n.Unlock()
spendSet, ok := n.spendNotifications[spendRequest]
if !ok {
return
}
ntfn, ok := spendSet.ntfns[spendID]
if !ok {
return
}
Log.Debugf("Canceling spend notification: spend_id=%d, %v", spendID,
spendRequest)
// We'll close all the notification channels to let the client know
// their cancel request has been fulfilled.
close(ntfn.Event.Spend)
close(ntfn.Event.Reorg)
close(ntfn.Event.Done)
delete(spendSet.ntfns, spendID)
}
// ProcessRelevantSpendTx processes a transaction provided externally. This
// will check whether the transaction is relevant to the notifier, i.e.,
// whether it spends any outpoints/output scripts for which we currently have
// registered notifications. If it is relevant, spend notifications will be
// dispatched to the caller.
func (n *TxNotifier) ProcessRelevantSpendTx(tx *btcutil.Tx,
blockHeight uint32) error {
select {
case <-n.quit:
return ErrTxNotifierExiting
default:
}
// Ensure we hold the lock throughout handling the notification to
// prevent the notifier from advancing its height underneath us.
n.Lock()
defer n.Unlock()
// We'll use a channel to coalesce all the spend requests that this
// transaction fulfills.
type spend struct {
request *SpendRequest
details *SpendDetail
}
// We'll set up the onSpend filter callback to gather all the fulfilled
// spends requests within this transaction.
var spends []spend
onSpend := func(request SpendRequest, details *SpendDetail) {
spends = append(spends, spend{&request, details})
}
n.filterTx(nil, tx, blockHeight, nil, onSpend)
// After the transaction has been filtered, we can finally dispatch
// notifications for each request.
for _, spend := range spends {
err := n.updateSpendDetails(*spend.request, spend.details)
if err != nil {
return err
}
}
return nil
}
// UpdateSpendDetails attempts to update the spend details for all active spend
// notification requests for an outpoint/output script. This method should be
// used once a historical scan of the chain has finished. If the historical scan
// did not find a spending transaction for it, the spend details may be nil.
//
// NOTE: A notification request for the outpoint/output script must be
// registered first to ensure notifications are delivered.
func (n *TxNotifier) UpdateSpendDetails(spendRequest SpendRequest,
details *SpendDetail) error {
select {
case <-n.quit:
return ErrTxNotifierExiting
default:
}
// Ensure we hold the lock throughout handling the notification to
// prevent the notifier from advancing its height underneath us.
n.Lock()
defer n.Unlock()
return n.updateSpendDetails(spendRequest, details)
}
// updateSpendDetails attempts to update the spend details for all active spend
// notification requests for an outpoint/output script. This method should be
// used once a historical scan of the chain has finished. If the historical scan
// did not find a spending transaction for it, the spend details may be nil.
//
// NOTE: This method must be called with the TxNotifier's lock held.
func (n *TxNotifier) updateSpendDetails(spendRequest SpendRequest,
details *SpendDetail) error {
// Mark the ongoing historical rescan for this request as finished. This
// will allow us to update the spend hints for it at tip.
spendSet, ok := n.spendNotifications[spendRequest]
if !ok {
return fmt.Errorf("spend notification for %v not found",
spendRequest)
}
// If the spend details have already been found, whether at tip or
// through a historical rescan, then the notifications should have
// already been dispatched, so we can exit early to prevent sending
// duplicate notifications.
if spendSet.details != nil {
return nil
}
// Since the historical rescan has completed for this request, we'll
// mark its rescan status as complete in order to ensure that the
// TxNotifier can properly update its spend hints upon
// connected/disconnected blocks.
spendSet.rescanStatus = rescanComplete
// If the historical rescan was not able to find a spending transaction
// for this request, then we can track the spend at tip.
if details == nil {
// We'll commit the current height as the spend hint to prevent
// another potentially long rescan if we restart before a new
// block comes in.
err := n.spendHintCache.CommitSpendHint(
n.currentHeight, spendRequest,
)
if err != nil {
// The error is not fatal as this is an optimistic
// optimization, so we'll avoid returning an error.
Log.Debugf("Unable to update spend hint to %d for %v: %v",
n.currentHeight, spendRequest, err)
}
Log.Debugf("Updated spend hint to height=%v for unconfirmed "+
"spend request %v", n.currentHeight, spendRequest)
return nil
}
// Return an error if the witness data is not present in the spending
// transaction.
//
// NOTE: if the witness stack is empty, we will do a critical log which
// shuts down the node.
if !details.HasSpenderWitness() {
Log.Criticalf("Found spending tx for outpoint=%v, but the "+
"transaction %v does not have witness",
spendRequest.OutPoint, details.SpendingTx.TxHash())
return ErrEmptyWitnessStack
}
// If the historical rescan found the spending transaction for this
// request, but it's at a later height than the notifier (this can
// happen due to latency with the backend during a reorg), then we'll
// defer handling the notification until the notifier has caught up to
// such height.
if uint32(details.SpendingHeight) > n.currentHeight {
return nil
}
// Now that we've determined the request has been spent, we'll commit
// its spending height as its hint in the cache and dispatch
// notifications to all of its respective clients.
err := n.spendHintCache.CommitSpendHint(
uint32(details.SpendingHeight), spendRequest,
)
if err != nil {
// The error is not fatal as this is an optimistic optimization,
// so we'll avoid returning an error.
Log.Debugf("Unable to update spend hint to %d for %v: %v",
details.SpendingHeight, spendRequest, err)
}
Log.Debugf("Updated spend hint to height=%v for confirmed spend "+
"request %v", details.SpendingHeight, spendRequest)
spendSet.details = details
for _, ntfn := range spendSet.ntfns {
err := n.dispatchSpendDetails(ntfn, spendSet.details)
if err != nil {
return err
}
}
return nil
}
// dispatchSpendDetails dispatches a spend notification to the client.
//
// NOTE: This must be called with the TxNotifier's lock held.
func (n *TxNotifier) dispatchSpendDetails(ntfn *SpendNtfn,
details *SpendDetail) error {
// If there are no spend details to dispatch or if the notification has
// already been dispatched, then we can skip dispatching to this client.
if details == nil || ntfn.dispatched {
Log.Debugf("Skipping dispatch of spend details(%v) for "+
"request %v, dispatched=%v", details, ntfn.SpendRequest,
ntfn.dispatched)
return nil
}
Log.Debugf("Dispatching confirmed spend notification for %v at "+
"current height=%d: %v", ntfn.SpendRequest, n.currentHeight,
details)
select {
case ntfn.Event.Spend <- details:
ntfn.dispatched = true
case <-n.quit:
return ErrTxNotifierExiting
}
spendHeight := uint32(details.SpendingHeight)
// We also add to spendsByHeight to notify on chain reorgs.
reorgSafeHeight := spendHeight + n.reorgSafetyLimit
if reorgSafeHeight > n.currentHeight {
txSet, exists := n.spendsByHeight[spendHeight]
if !exists {
txSet = make(map[SpendRequest]struct{})
n.spendsByHeight[spendHeight] = txSet
}
txSet[ntfn.SpendRequest] = struct{}{}
}
return nil
}
// ConnectTip handles a new block extending the current chain. It will go
// through every transaction and determine if it is relevant to any of its
// clients. A transaction can be relevant in either of the following two ways:
//
// 1. One of the inputs in the transaction spends an outpoint/output script
// for which we currently have an active spend registration.
//
// 2. The transaction has a txid or output script for which we currently have
// an active confirmation registration.
//
// In the event that the transaction is relevant, a confirmation/spend
// notification will be queued for dispatch to the relevant clients.
// Confirmation notifications will only be dispatched for transactions/output
// scripts that have met the number of confirmations required by the client.
//
// NOTE: In order to actually dispatch the relevant transaction notifications
// to clients, NotifyHeight must be called with the same block height.
func (n *TxNotifier) ConnectTip(block *btcutil.Block,
blockHeight uint32) error {
select {
case <-n.quit:
return ErrTxNotifierExiting
default:
}
n.Lock()
defer n.Unlock()
if blockHeight != n.currentHeight+1 {
return fmt.Errorf("received blocks out of order: "+
"current height=%d, new height=%d",
n.currentHeight, blockHeight)
}
n.currentHeight++
n.reorgDepth = 0
// First, we'll iterate over all the transactions found in this block to
// determine if any of them are relevant to the TxNotifier.
if block != nil {
Log.Debugf("Filtering %d txns for %d spend requests at "+
"height %d", len(block.Transactions()),
len(n.spendNotifications), blockHeight)
for _, tx := range block.Transactions() {
n.filterTx(
block, tx, blockHeight,
n.handleConfDetailsAtTip,
n.handleSpendDetailsAtTip,
)
}
}
// Now that we've determined which requests were confirmed and spent
// within the new block, we can update their entries in their respective
// caches, along with all of our unconfirmed and unspent requests.
n.updateHints(blockHeight)
// Finally, we'll clear the entries from our set of notifications for
// requests that are no longer under the risk of being reorged out of
// the chain.
if blockHeight >= n.reorgSafetyLimit {
matureBlockHeight := blockHeight - n.reorgSafetyLimit
for confRequest := range n.confsByInitialHeight[matureBlockHeight] {
confSet := n.confNotifications[confRequest]
for _, ntfn := range confSet.ntfns {
select {
case ntfn.Event.Done <- struct{}{}:
case <-n.quit:
return ErrTxNotifierExiting
}
}
delete(n.confNotifications, confRequest)
}
delete(n.confsByInitialHeight, matureBlockHeight)
for spendRequest := range n.spendsByHeight[matureBlockHeight] {
spendSet := n.spendNotifications[spendRequest]
for _, ntfn := range spendSet.ntfns {
select {
case ntfn.Event.Done <- struct{}{}:
case <-n.quit:
return ErrTxNotifierExiting
}
}
Log.Debugf("Deleting mature spend request %v at "+
"height=%d", spendRequest, blockHeight)
delete(n.spendNotifications, spendRequest)
}
delete(n.spendsByHeight, matureBlockHeight)
}
return nil
}
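// As an illustrative sketch of the NOTE above (the block stream and loop are
// hypothetical): every connected block must be handed to ConnectTip and then
// NotifyHeight with the same height, otherwise queued notifications are never
// dispatched.
//
//	for epoch := range blockEpochs { // hypothetical block stream
//		if err := n.ConnectTip(epoch.Block, epoch.Height); err != nil {
//			return err
//		}
//		if err := n.NotifyHeight(epoch.Height); err != nil {
//			return err
//		}
//	}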
// filterTx determines whether the transaction spends or confirms any
// outstanding pending requests. The onConf and onSpend callbacks can be used to
// retrieve all the requests fulfilled by this transaction as they occur.
func (n *TxNotifier) filterTx(block *btcutil.Block, tx *btcutil.Tx,
blockHeight uint32, onConf func(ConfRequest, *TxConfirmation),
onSpend func(SpendRequest, *SpendDetail)) {
// In order to determine if this transaction is relevant to the
// notifier, we'll check its inputs for any outstanding spend
// requests.
txHash := tx.Hash()
if onSpend != nil {
// notifyDetails is a helper closure that will construct the
// spend details of a request and hand them off to the onSpend
// callback.
notifyDetails := func(spendRequest SpendRequest,
prevOut wire.OutPoint, inputIdx uint32) {
Log.Debugf("Found spend of %v: spend_tx=%v, "+
"block_height=%d", spendRequest, txHash,
blockHeight)
onSpend(spendRequest, &SpendDetail{
SpentOutPoint: &prevOut,
SpenderTxHash: txHash,
SpendingTx: tx.MsgTx(),
SpenderInputIndex: inputIdx,
SpendingHeight: int32(blockHeight),
})
}
for i, txIn := range tx.MsgTx().TxIn {
// We'll re-derive the script of the output being spent
// to determine if the input spends any registered
// requests.
prevOut := txIn.PreviousOutPoint
pkScript, err := txscript.ComputePkScript(
txIn.SignatureScript, txIn.Witness,
)
if err != nil {
continue
}
spendRequest := SpendRequest{
OutPoint: prevOut,
PkScript: pkScript,
}
// If we have any, we'll record their spend height so
// that notifications get dispatched to the respective
// clients.
if _, ok := n.spendNotifications[spendRequest]; ok {
notifyDetails(spendRequest, prevOut, uint32(i))
}
// Now try with an empty taproot key pkScript, since we
// cannot derive the spent pkScript directly from the
// witness. But we have the outpoint, which should be
// enough.
spendRequest.PkScript = ZeroTaprootPkScript
if _, ok := n.spendNotifications[spendRequest]; ok {
notifyDetails(spendRequest, prevOut, uint32(i))
}
// Restore the pkScript but try with a zero outpoint
// instead (won't be possible for Taproot).
spendRequest.PkScript = pkScript
spendRequest.OutPoint = ZeroOutPoint
if _, ok := n.spendNotifications[spendRequest]; ok {
notifyDetails(spendRequest, prevOut, uint32(i))
}
}
}
// We'll also check its outputs to determine if there are any
// outstanding confirmation requests.
if onConf != nil {
// notifyDetails is a helper closure that will construct the
// confirmation details of a request and hand them off to the
// onConf callback.
notifyDetails := func(confRequest ConfRequest) {
Log.Debugf("Found initial confirmation of %v: "+
"height=%d, hash=%v", confRequest,
blockHeight, block.Hash())
details := &TxConfirmation{
Tx: tx.MsgTx(),
BlockHash: block.Hash(),
BlockHeight: blockHeight,
TxIndex: uint32(tx.Index()),
Block: block.MsgBlock(),
}
onConf(confRequest, details)
}
for _, txOut := range tx.MsgTx().TxOut {
// We'll parse the script of the output to determine if
// we have any registered requests for it or the
// transaction itself.
pkScript, err := txscript.ParsePkScript(txOut.PkScript)
if err != nil {
continue
}
confRequest := ConfRequest{
TxID: *txHash,
PkScript: pkScript,
}
// If we have any, we'll record their confirmed height
// so that notifications get dispatched when they
// reach the clients' desired number of confirmations.
if _, ok := n.confNotifications[confRequest]; ok {
notifyDetails(confRequest)
}
confRequest.TxID = ZeroHash
if _, ok := n.confNotifications[confRequest]; ok {
notifyDetails(confRequest)
}
}
}
}
// handleConfDetailsAtTip tracks the confirmation height of the txid/output
// script in order to properly dispatch a confirmation notification after
// meeting each request's desired number of confirmations for all current and
// future registered clients.
func (n *TxNotifier) handleConfDetailsAtTip(confRequest ConfRequest,
details *TxConfirmation) {
// TODO(wilmer): cancel pending historical rescans if any?
confSet := n.confNotifications[confRequest]
// If we already have details for this request, we don't want to add it
// again since we have already dispatched notifications for it.
if confSet.details != nil {
Log.Warnf("Ignoring address reuse for %s at height %d.",
confRequest, details.BlockHeight)
return
}
confSet.rescanStatus = rescanComplete
confSet.details = details
for _, ntfn := range confSet.ntfns {
// In the event that this notification was aware that the
// transaction/output script was reorged out of the chain, we'll
// consume the reorg notification if it hasn't been done yet
// already.
select {
case <-ntfn.Event.NegativeConf:
default:
}
// We'll note this client's required number of confirmations so
// that we can notify them when expected.
confHeight := details.BlockHeight + ntfn.NumConfirmations - 1
ntfnSet, exists := n.ntfnsByConfirmHeight[confHeight]
if !exists {
ntfnSet = make(map[*ConfNtfn]struct{})
n.ntfnsByConfirmHeight[confHeight] = ntfnSet
}
ntfnSet[ntfn] = struct{}{}
}
// We'll also note the initial confirmation height in order to correctly
// handle dispatching notifications when the transaction/output script
// gets reorged out of the chain.
txSet, exists := n.confsByInitialHeight[details.BlockHeight]
if !exists {
txSet = make(map[ConfRequest]struct{})
n.confsByInitialHeight[details.BlockHeight] = txSet
}
txSet[confRequest] = struct{}{}
}
// handleSpendDetailsAtTip tracks the spend height of the outpoint/output script
// in order to properly dispatch a spend notification for all current and future
// registered clients.
func (n *TxNotifier) handleSpendDetailsAtTip(spendRequest SpendRequest,
details *SpendDetail) {
// TODO(wilmer): cancel pending historical rescans if any?
spendSet := n.spendNotifications[spendRequest]
spendSet.rescanStatus = rescanComplete
spendSet.details = details
for _, ntfn := range spendSet.ntfns {
// In the event that this notification was aware that the
// spending transaction of its outpoint/output script was
// reorged out of the chain, we'll consume the reorg
// notification if it hasn't been done yet already.
select {
case <-ntfn.Event.Reorg:
default:
}
}
// We'll note the spending height of the request in order to correctly
// handle dispatching notifications when the spending transaction gets
// reorged out of the chain.
spendHeight := uint32(details.SpendingHeight)
opSet, exists := n.spendsByHeight[spendHeight]
if !exists {
opSet = make(map[SpendRequest]struct{})
n.spendsByHeight[spendHeight] = opSet
}
opSet[spendRequest] = struct{}{}
Log.Debugf("Spend request %v spent at tip=%d", spendRequest,
spendHeight)
}
// NotifyHeight dispatches confirmation and spend notifications to the clients
// who registered for a notification which has been fulfilled at the passed
// height.
func (n *TxNotifier) NotifyHeight(height uint32) error {
n.Lock()
defer n.Unlock()
// First, we'll dispatch an update to all of the notification clients
// for our watched requests with the number of confirmations left at
// this new height.
for _, confRequests := range n.confsByInitialHeight {
for confRequest := range confRequests {
confSet := n.confNotifications[confRequest]
for _, ntfn := range confSet.ntfns {
txConfHeight := confSet.details.BlockHeight +
ntfn.NumConfirmations - 1
numConfsLeft := txConfHeight - height
// Since we don't clear notifications until
// transactions/output scripts are no longer
// under the risk of being reorganized out of
// the chain, we'll skip sending updates for
// those that have already been confirmed.
if int32(numConfsLeft) < 0 {
continue
}
err := n.notifyNumConfsLeft(ntfn, numConfsLeft)
if err != nil {
return err
}
}
}
}
// Then, we'll dispatch notifications for all the requests that have
// become confirmed at this new block height.
for ntfn := range n.ntfnsByConfirmHeight[height] {
confSet := n.confNotifications[ntfn.ConfRequest]
// The default notification we assigned above includes the
// block along with the rest of the details. However, not all
// clients want the block, so we make a copy without the block
// if needed so we can give clients only what they ask for.
confDetails := *confSet.details
if !ntfn.includeBlock {
confDetails.Block = nil
}
// If the `confDetails` has already been sent before, we'll
// skip it and continue processing the next one.
if ntfn.dispatched {
Log.Debugf("Skipped dispatched conf details for "+
"request %v conf_id=%v", ntfn.ConfRequest,
ntfn.ConfID)
continue
}
Log.Debugf("Dispatching %v confirmation notification for "+
"conf_id=%v, %v", ntfn.NumConfirmations, ntfn.ConfID,
ntfn.ConfRequest)
select {
case ntfn.Event.Confirmed <- &confDetails:
ntfn.dispatched = true
case <-n.quit:
return ErrTxNotifierExiting
}
}
delete(n.ntfnsByConfirmHeight, height)
// Finally, we'll dispatch spend notifications for all the requests that
// were spent at this new block height.
for spendRequest := range n.spendsByHeight[height] {
spendSet := n.spendNotifications[spendRequest]
for _, ntfn := range spendSet.ntfns {
err := n.dispatchSpendDetails(ntfn, spendSet.details)
if err != nil {
return err
}
}
}
return nil
}
// DisconnectTip handles the tip of the current chain being disconnected during
// a chain reorganization. If any watched requests were included in this block,
// internal structures are updated to ensure confirmation/spend notifications
// are consumed (if not already), and reorg notifications are dispatched
// instead. Confirmation/spend notifications will be dispatched again upon block
// inclusion.
func (n *TxNotifier) DisconnectTip(blockHeight uint32) error {
select {
case <-n.quit:
return ErrTxNotifierExiting
default:
}
n.Lock()
defer n.Unlock()
if blockHeight != n.currentHeight {
return fmt.Errorf("received blocks out of order: "+
"current height=%d, disconnected height=%d",
n.currentHeight, blockHeight)
}
n.currentHeight--
n.reorgDepth++
// With the block disconnected, we'll update the confirm and spend hints
// for our notification requests to reflect the new height, except for
// those that have confirmed/spent at previous heights.
n.updateHints(blockHeight)
// We'll go through all of our watched confirmation requests and attempt
// to drain their notification channels to ensure sending notifications
// to the clients is always non-blocking.
for initialHeight, txHashes := range n.confsByInitialHeight {
for txHash := range txHashes {
// If the transaction/output script has been reorged out
// of the chain, we'll make sure to remove the cached
// confirmation details to prevent notifying clients
// with old information.
confSet := n.confNotifications[txHash]
if initialHeight == blockHeight {
confSet.details = nil
}
for _, ntfn := range confSet.ntfns {
// First, we'll attempt to drain an update
// from each notification to ensure sends to the
// Updates channel are always non-blocking.
select {
case <-ntfn.Event.Updates:
case <-n.quit:
return ErrTxNotifierExiting
default:
}
// We also reset the num of confs update.
ntfn.numConfsLeft = ntfn.NumConfirmations
// Then, we'll check if the current
// transaction/output script was included in the
// block currently being disconnected. If it
// was, we'll need to dispatch a reorg
// notification to the client.
if initialHeight == blockHeight {
err := n.dispatchConfReorg(
ntfn, blockHeight,
)
if err != nil {
return err
}
}
}
}
}
// We'll also go through our watched spend requests and attempt to drain
// their dispatched notifications to ensure dispatching notifications to
// clients later on is always non-blocking. We're only interested in
// requests whose spending transaction was included at the height being
// disconnected.
for op := range n.spendsByHeight[blockHeight] {
// Since the spending transaction is being reorged out of the
// chain, we'll need to clear out the spending details of the
// request.
spendSet := n.spendNotifications[op]
spendSet.details = nil
// For all requests which have had a spend notification
// dispatched, we'll attempt to drain it and send a reorg
// notification instead.
for _, ntfn := range spendSet.ntfns {
if err := n.dispatchSpendReorg(ntfn); err != nil {
return err
}
}
}
// Finally, we can remove the requests that were confirmed and/or spent
// at the height being disconnected. We'll still continue to track them
// until they have been confirmed/spent and are no longer under the risk
// of being reorged out of the chain again.
delete(n.confsByInitialHeight, blockHeight)
delete(n.spendsByHeight, blockHeight)
return nil
}
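// As an illustrative sketch (heights are hypothetical), a chain backend
// unwinding a reorg calls DisconnectTip once per stale block, from the old
// tip down to the fork point, before connecting the new chain via
// ConnectTip/NotifyHeight:
//
//	for height := oldTip; height > forkHeight; height-- {
//		if err := n.DisconnectTip(height); err != nil {
//			return err
//		}
//	}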
// updateHints attempts to update the confirm and spend hints for all relevant
// requests respectively. The height parameter is used to determine which
// requests we should update based on whether a new block is being
// connected/disconnected.
//
// NOTE: This must be called with the TxNotifier's lock held and after its
// height has already been reflected by a block being connected/disconnected.
func (n *TxNotifier) updateHints(height uint32) {
// TODO(wilmer): update under one database transaction.
//
// To update the height hint for all the required confirmation requests
// under one database transaction, we'll gather the set of unconfirmed
// requests along with the ones that confirmed at the height being
// connected/disconnected.
confRequests := n.unconfirmedRequests()
for confRequest := range n.confsByInitialHeight[height] {
confRequests = append(confRequests, confRequest)
}
err := n.confirmHintCache.CommitConfirmHint(
n.currentHeight, confRequests...,
)
if err != nil {
// The error is not fatal as this is an optimistic optimization,
// so we'll avoid returning an error.
Log.Debugf("Unable to update confirm hints to %d for "+
"%v: %v", n.currentHeight, confRequests, err)
}
// Similarly, to update the height hint for all the required spend
// requests under one database transaction, we'll gather the set of
// unspent requests along with the ones that were spent at the height
// being connected/disconnected.
spendRequests := n.unspentRequests()
for spendRequest := range n.spendsByHeight[height] {
spendRequests = append(spendRequests, spendRequest)
}
err = n.spendHintCache.CommitSpendHint(
n.currentHeight, spendRequests...,
)
if err != nil {
// The error is not fatal as this is an optimistic optimization,
// so we'll avoid returning an error.
Log.Debugf("Unable to update spend hints to %d for "+
"%v: %v", n.currentHeight, spendRequests, err)
}
}
// unconfirmedRequests returns the set of confirmation requests that are
// still seen as unconfirmed by the TxNotifier.
//
// NOTE: This method must be called with the TxNotifier's lock held.
func (n *TxNotifier) unconfirmedRequests() []ConfRequest {
var unconfirmed []ConfRequest
for confRequest, confNtfnSet := range n.confNotifications {
// If the notification is already aware of its confirmation
// details, or it's in the process of learning them, we'll skip
// it as we can't yet determine if it's confirmed or not.
if confNtfnSet.rescanStatus != rescanComplete ||
confNtfnSet.details != nil {
continue
}
unconfirmed = append(unconfirmed, confRequest)
}
return unconfirmed
}
// unspentRequests returns the set of spend requests that are still seen as
// unspent by the TxNotifier.
//
// NOTE: This method must be called with the TxNotifier's lock held.
func (n *TxNotifier) unspentRequests() []SpendRequest {
var unspent []SpendRequest
for spendRequest, spendNtfnSet := range n.spendNotifications {
// If the notification is already aware of its spend details, or
// it's in the process of learning them, we'll skip it as we
// can't yet determine if it's unspent or not.
if spendNtfnSet.rescanStatus != rescanComplete ||
spendNtfnSet.details != nil {
continue
}
unspent = append(unspent, spendRequest)
}
return unspent
}
// dispatchConfReorg dispatches a reorg notification to the client if the
// confirmation notification was already delivered.
//
// NOTE: This must be called with the TxNotifier's lock held.
func (n *TxNotifier) dispatchConfReorg(ntfn *ConfNtfn,
heightDisconnected uint32) error {
// If the request's confirmation notification has yet to be dispatched,
// we'll need to clear its entry within the ntfnsByConfirmHeight index
// to prevent from notifying the client once the notifier reaches the
// confirmation height.
if !ntfn.dispatched {
confHeight := heightDisconnected + ntfn.NumConfirmations - 1
ntfnSet, exists := n.ntfnsByConfirmHeight[confHeight]
if exists {
delete(ntfnSet, ntfn)
}
return nil
}
// Otherwise, the entry within the ntfnsByConfirmHeight has already been
// deleted, so we'll attempt to drain the confirmation notification to
// ensure sends to the Confirmed channel are always non-blocking.
select {
case <-ntfn.Event.Confirmed:
case <-n.quit:
return ErrTxNotifierExiting
default:
}
ntfn.dispatched = false
// Send a negative confirmation notification to the client indicating
// how many blocks have been disconnected successively.
select {
case ntfn.Event.NegativeConf <- int32(n.reorgDepth):
case <-n.quit:
return ErrTxNotifierExiting
}
return nil
}
// dispatchSpendReorg dispatches a reorg notification to the client if a spend
// notification was already delivered.
//
// NOTE: This must be called with the TxNotifier's lock held.
func (n *TxNotifier) dispatchSpendReorg(ntfn *SpendNtfn) error {
if !ntfn.dispatched {
return nil
}
// Attempt to drain the spend notification to ensure sends to the Spend
// channel are always non-blocking.
select {
case <-ntfn.Event.Spend:
default:
}
// Send a reorg notification to the client in order for them to
// correctly handle reorgs.
select {
case ntfn.Event.Reorg <- struct{}{}:
case <-n.quit:
return ErrTxNotifierExiting
}
ntfn.dispatched = false
return nil
}
// TearDown is to be called when the owner of the TxNotifier is exiting. This
// closes the event channels of all registered notifications that have not been
// dispatched yet.
func (n *TxNotifier) TearDown() {
close(n.quit)
n.Lock()
defer n.Unlock()
for _, confSet := range n.confNotifications {
for confID, ntfn := range confSet.ntfns {
close(ntfn.Event.Confirmed)
close(ntfn.Event.Updates)
close(ntfn.Event.NegativeConf)
close(ntfn.Event.Done)
delete(confSet.ntfns, confID)
}
}
for _, spendSet := range n.spendNotifications {
for spendID, ntfn := range spendSet.ntfns {
close(ntfn.Event.Spend)
close(ntfn.Event.Reorg)
close(ntfn.Event.Done)
delete(spendSet.ntfns, spendID)
}
}
}
// notifyNumConfsLeft sends the number of confirmations left to the
// notification subscriber through the Event.Updates channel.
//
// NOTE: must be used with the TxNotifier's lock held.
func (n *TxNotifier) notifyNumConfsLeft(ntfn *ConfNtfn, num uint32) error {
// If the number left is no less than the recorded value, we can skip
// sending it as it means this same value has already been sent before.
if num >= ntfn.numConfsLeft {
Log.Debugf("Skipped dispatched update (numConfsLeft=%v) for "+
"request %v conf_id=%v", num, ntfn.ConfRequest,
ntfn.ConfID)
return nil
}
// Update the number of confirmations left to the notification.
ntfn.numConfsLeft = num
select {
case ntfn.Event.Updates <- num:
case <-n.quit:
return ErrTxNotifierExiting
}
return nil
}
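// To illustrate the consumer side of the channels written to above, a
// confirmation client registered elsewhere (registration is not shown in
// this file; confEvent is hypothetical) would typically drain its event
// channels as follows:
//
//	for {
//		select {
//		case confsLeft := <-confEvent.Updates:
//			// confsLeft blocks remain until fully confirmed.
//			_ = confsLeft
//		case conf := <-confEvent.Confirmed:
//			// The desired confirmation depth was reached.
//			_ = conf
//		case depth := <-confEvent.NegativeConf:
//			// The transaction was reorged out after having
//			// confirmed; depth blocks were disconnected.
//			_ = depth
//		case <-confEvent.Done:
//			// The request is past the reorg safety limit and no
//			// further events will be delivered.
//			return
//		}
//	}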
package channeldb
import (
"errors"
"net"
"github.com/btcsuite/btcd/btcec/v2"
)
// AddrSource is an interface that allows us to get the addresses for a target
// node. It may combine the results of multiple address sources.
type AddrSource interface {
// AddrsForNode returns all known addresses for the target node public
// key. The returned boolean must indicate if the given node is unknown
// to the backing source.
AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error)
}
// multiAddrSource is an implementation of AddrSource which gathers all the
// known addresses for a given node from multiple backends and de-duplicates the
// results.
type multiAddrSource struct {
sources []AddrSource
}
// NewMultiAddrSource constructs a new AddrSource which will query all the
// provided sources for a node's addresses and will then de-duplicate the
// results.
func NewMultiAddrSource(sources ...AddrSource) AddrSource {
return &multiAddrSource{
sources: sources,
}
}
// AddrsForNode returns all known addresses for the target node public key. It
// queries all the address sources provided and de-duplicates the results. The
// returned boolean is false only if none of the backing sources know of the
// node.
//
// NOTE: this implements the AddrSource interface.
func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) (bool,
[]net.Addr, error) {
if len(c.sources) == 0 {
return false, nil, errors.New("no address sources")
}
// The multiple address sources will likely contain duplicate addresses,
// so we use a map here to de-dup them.
dedupedAddrs := make(map[string]net.Addr)
// known will be set to true if any backing source is aware of the node.
var known bool
// Iterate over all the address sources and query each one for the
// addresses it has for the node in question.
for _, src := range c.sources {
isKnown, addrs, err := src.AddrsForNode(nodePub)
if err != nil {
return false, nil, err
}
if isKnown {
known = true
}
for _, addr := range addrs {
dedupedAddrs[addr.String()] = addr
}
}
// Convert the map into a list we can return.
addrs := make([]net.Addr, 0, len(dedupedAddrs))
for _, addr := range dedupedAddrs {
addrs = append(addrs, addr)
}
return known, addrs, nil
}
// A compile-time check to ensure that multiAddrSource implements the AddrSource
// interface.
var _ AddrSource = (*multiAddrSource)(nil)
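// As an illustrative sketch, two hypothetical AddrSource implementations
// (graphSource and channelSource) can be combined so that a lookup succeeds
// if either backend knows the node, with duplicate addresses removed:
//
//	src := NewMultiAddrSource(graphSource, channelSource)
//	known, addrs, err := src.AddrsForNode(nodePub)
//	if err != nil {
//		return err
//	}
//	if !known {
//		// None of the backing sources know of this node.
//	}
//	for _, addr := range addrs {
//		// Each address appears once, even if several sources
//		// returned it.
//		_ = addr
//	}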
package channeldb
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// AbsoluteThawHeightThreshold is the threshold at which a thaw height
// begins to be interpreted as an absolute block height, rather than a
// relative one.
AbsoluteThawHeightThreshold uint32 = 500000
// HTLCBlindingPointTLV is the tlv type used for storing blinding
// points with HTLCs.
HTLCBlindingPointTLV tlv.Type = 0
)
var (
// closedChannelBucket stores summarization information concerning
// previously open, but now closed channels.
closedChannelBucket = []byte("closed-chan-bucket")
// openChannelBucket stores all the currently open channels. This bucket
// has a second, nested bucket which is keyed by a node's ID. Within
// that node ID bucket, all attributes required to track, update, and
// close a channel are stored.
//
// openChan -> nodeID -> chanPoint
//
// TODO(roasbeef): flesh out comment
openChannelBucket = []byte("open-chan-bucket")
// outpointBucket stores all of our channel outpoints and a tlv
// stream containing channel data.
//
// outpoint -> tlv stream.
//
outpointBucket = []byte("outpoint-bucket")
// chanIDBucket stores all of the 32-byte channel IDs we know about.
// These could be derived from outpointBucket, but it is more
// convenient to have these in their own bucket.
//
// chanID -> tlv stream.
//
chanIDBucket = []byte("chan-id-bucket")
// historicalChannelBucket stores all channels that have seen their
// commitment tx confirm. All information from their previous open state
// is retained.
historicalChannelBucket = []byte("historical-chan-bucket")
// chanInfoKey can be accessed within the bucket for a channel
// (identified by its chanPoint). This key stores all the static
// information for a channel which is decided at the end of the
// funding flow.
chanInfoKey = []byte("chan-info-key")
// localUpfrontShutdownKey can be accessed within the bucket for a channel
// (identified by its chanPoint). This key stores an optional upfront
// shutdown script for the local peer.
localUpfrontShutdownKey = []byte("local-upfront-shutdown-key")
// remoteUpfrontShutdownKey can be accessed within the bucket for a channel
// (identified by its chanPoint). This key stores an optional upfront
// shutdown script for the remote peer.
remoteUpfrontShutdownKey = []byte("remote-upfront-shutdown-key")
// chanCommitmentKey can be accessed within the sub-bucket for a
// particular channel. This key stores the up to date commitment state
// for a particular channel party. Appending a 0 to the end of this key
// indicates it's the commitment for the local party, and appending a 1
// to the end of this key indicates it's the commitment for the remote
// party.
chanCommitmentKey = []byte("chan-commitment-key")
// unsignedAckedUpdatesKey is an entry in the channel bucket that
// contains the remote updates that we have acked, but not yet signed
// for in one of our remote commits.
unsignedAckedUpdatesKey = []byte("unsigned-acked-updates-key")
// remoteUnsignedLocalUpdatesKey is an entry in the channel bucket that
// contains the local updates that the remote party has acked, but
// has not yet signed for in one of their local commits.
remoteUnsignedLocalUpdatesKey = []byte("remote-unsigned-local-updates-key")
// revocationStateKey stores their current revocation hash, our
// preimage producer and their preimage store.
revocationStateKey = []byte("revocation-state-key")
// dataLossCommitPointKey stores the commitment point received from the
// remote peer during a channel sync in case we have lost channel state.
dataLossCommitPointKey = []byte("data-loss-commit-point-key")
// forceCloseTxKey points to the unilateral closing tx that we
// broadcasted when moving the channel to state CommitBroadcasted.
forceCloseTxKey = []byte("closing-tx-key")
// coopCloseTxKey points to the cooperative closing tx that we
// broadcasted when moving the channel to state CoopBroadcasted.
coopCloseTxKey = []byte("coop-closing-tx-key")
// shutdownInfoKey points to the serialised shutdown info that has been
// persisted for a channel. The existence of this info means that we
// have sent the Shutdown message before and so should re-initiate the
// shutdown on re-establish.
shutdownInfoKey = []byte("shutdown-info-key")
// commitDiffKey stores the current pending commitment state we've
// extended to the remote party (if any). Each time we propose a new
// state, we store the information necessary to reconstruct this state
// from the prior commitment. This allows us to resync the remote party
// to their expected state in the case of message loss.
//
// TODO(roasbeef): rename to commit chain?
commitDiffKey = []byte("commit-diff-key")
// frozenChanKey is the key where we store the information for any
// active "frozen" channels. This key is present only in the leaf
// bucket for a given channel.
frozenChanKey = []byte("frozen-chans")
// lastWasRevokeKey is a key that stores true when the last update we
// sent was a revocation and false when it was a commitment signature.
// This is nil in the case of new channels with no updates exchanged.
lastWasRevokeKey = []byte("last-was-revoke")
// finalHtlcsBucket contains the htlcs that have been resolved
// definitively. Within this bucket, there is a sub-bucket for each
// channel. In each channel bucket, the htlc indices are stored along
// with final outcome.
//
// final-htlcs -> chanID -> htlcIndex -> outcome
//
// 'outcome' is a byte value that encodes:
//
// | true false
// ------+------------------
// bit 0 | settled failed
// bit 1 | offchain onchain
//
// This bucket is positioned at the root level, because its contents
// will be kept independent of the channel lifecycle. This avoids the
// situation where a channel force-closes autonomously, leaving the
// user unable to query for htlc outcomes anymore.
finalHtlcsBucket = []byte("final-htlcs")
)
var (
// ErrNoCommitmentsFound is returned when a channel has not set
// commitment states.
ErrNoCommitmentsFound = fmt.Errorf("no commitments found")
// ErrNoChanInfoFound is returned when a particular channel does not
// have any channel state.
ErrNoChanInfoFound = fmt.Errorf("no chan info found")
// ErrNoRevocationsFound is returned when revocation state for a
// particular channel cannot be found.
ErrNoRevocationsFound = fmt.Errorf("no revocations found")
// ErrNoPendingCommit is returned when there is not a pending
// commitment for a remote party. A new commitment is written to disk
// each time we write a new state in order to be properly fault
// tolerant.
ErrNoPendingCommit = fmt.Errorf("no pending commits found")
// ErrNoCommitPoint is returned when no data loss commit point is found
// in the database.
ErrNoCommitPoint = fmt.Errorf("no commit point found")
// ErrNoCloseTx is returned when no closing tx is found for a channel
// in the state CommitBroadcasted.
ErrNoCloseTx = fmt.Errorf("no closing tx found")
// ErrNoShutdownInfo is returned when no shutdown info has been
// persisted for a channel.
ErrNoShutdownInfo = errors.New("no shutdown info")
// ErrNoRestoredChannelMutation is returned when a caller attempts to
// mutate a channel that's been recovered.
ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " +
"channel state")
// ErrChanBorked is returned when a caller attempts to mutate a borked
// channel.
ErrChanBorked = fmt.Errorf("cannot mutate borked channel")
// ErrMissingIndexEntry is returned when a caller attempts to close a
// channel and the outpoint is missing from the index.
ErrMissingIndexEntry = fmt.Errorf("missing outpoint from index")
// ErrOnionBlobLength is returned when an onion blob with an incorrect
// length is read from disk.
ErrOnionBlobLength = errors.New("onion blob < 1366 bytes")
)
const (
// A tlv type definition used to serialize an outpoint's indexStatus
// for use in the outpoint index.
indexStatusType tlv.Type = 0
)
// openChannelTlvData houses the new data fields that are stored for each
// channel in a TLV stream within the root bucket. This is stored as a TLV
// stream appended to the existing hard-coded fields in the channel's root
// bucket. New fields being added to the channel state should be added here.
//
// NOTE: This struct is used for serialization purposes only and its fields
// should be accessed via the OpenChannel struct while in memory.
type openChannelTlvData struct {
// revokeKeyLoc is the key locator for the revocation key.
revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord]
// initialLocalBalance is the initial local balance of the channel.
initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64]
// initialRemoteBalance is the initial remote balance of the channel.
initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64]
// realScid is the real short channel ID of the channel corresponding to
// the on-chain outpoint.
realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID]
// memo is an optional text field that gives context to the user about
// the channel.
memo tlv.OptionalRecordT[tlv.TlvType5, []byte]
// tapscriptRoot is the optional Tapscript root the channel funding
// output commits to.
tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte]
// customBlob is an optional TLV encoded blob of data representing
// custom channel funding information.
customBlob tlv.OptionalRecordT[tlv.TlvType7, tlv.Blob]
}
// encode serializes the openChannelTlvData to the given io.Writer.
func (c *openChannelTlvData) encode(w io.Writer) error {
tlvRecords := []tlv.Record{
c.revokeKeyLoc.Record(),
c.initialLocalBalance.Record(),
c.initialRemoteBalance.Record(),
c.realScid.Record(),
}
c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) {
tlvRecords = append(tlvRecords, memo.Record())
})
c.tapscriptRoot.WhenSome(
func(root tlv.RecordT[tlv.TlvType6, [32]byte]) {
tlvRecords = append(tlvRecords, root.Record())
},
)
c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType7, tlv.Blob]) {
tlvRecords = append(tlvRecords, blob.Record())
})
// Create the tlv stream.
tlvStream, err := tlv.NewStream(tlvRecords...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
// decode deserializes the openChannelTlvData from the given io.Reader.
func (c *openChannelTlvData) decode(r io.Reader) error {
memo := c.memo.Zero()
tapscriptRoot := c.tapscriptRoot.Zero()
blob := c.customBlob.Zero()
// Create the tlv stream.
tlvStream, err := tlv.NewStream(
c.revokeKeyLoc.Record(),
c.initialLocalBalance.Record(),
c.initialRemoteBalance.Record(),
c.realScid.Record(),
memo.Record(),
tapscriptRoot.Record(),
blob.Record(),
)
if err != nil {
return err
}
tlvs, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil {
return err
}
if _, ok := tlvs[memo.TlvType()]; ok {
c.memo = tlv.SomeRecordT(memo)
}
if _, ok := tlvs[tapscriptRoot.TlvType()]; ok {
c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot)
}
if _, ok := tlvs[c.customBlob.TlvType()]; ok {
c.customBlob = tlv.SomeRecordT(blob)
}
return nil
}
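// As an illustrative sketch of the encode/decode pair above (the values are
// hypothetical), the TLV stream round-trips through any buffer, and optional
// records survive only if they were set before encoding:
//
//	var buf bytes.Buffer
//	var in openChannelTlvData
//	in.memo = tlv.SomeRecordT(
//		tlv.NewPrimitiveRecord[tlv.TlvType5]([]byte("my channel")),
//	)
//	if err := in.encode(&buf); err != nil {
//		return err
//	}
//	var out openChannelTlvData
//	if err := out.decode(&buf); err != nil {
//		return err
//	}
//	// out.memo is now populated, while out.tapscriptRoot and
//	// out.customBlob remain unset.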
// indexStatus is an enum-like type that describes what state the
// outpoint is in. Currently there are only two possible values.
type indexStatus uint8
const (
// outpointOpen represents an outpoint that is open in the outpoint index.
outpointOpen indexStatus = 0
// outpointClosed represents an outpoint that is closed in the outpoint
// index.
outpointClosed indexStatus = 1
)
// ChannelType is an enum-like type that describes one of several possible
// channel types. Each open channel is associated with a particular type as the
// channel type may determine how higher level operations are conducted such as
// fee negotiation, channel closing, the format of HTLCs, etc. Structure-wise,
// a ChannelType is a bit field, with each bit denoting a modification from the
// base channel type of single funder.
type ChannelType uint64
const (
// NOTE: iota isn't used here because this enum needs to be stable
// long-term, as it will be persisted to the database.
// SingleFunderBit represents a channel wherein one party solely funds
// the entire capacity of the channel.
SingleFunderBit ChannelType = 0
// DualFunderBit represents a channel wherein both parties contribute
// funds towards the total capacity of the channel. The channel may be
// funded symmetrically or asymmetrically.
DualFunderBit ChannelType = 1 << 0
// SingleFunderTweaklessBit is similar to the basic SingleFunder channel
// type, but it omits the tweak for one's key in the commitment
// transaction of the remote party.
SingleFunderTweaklessBit ChannelType = 1 << 1
// NoFundingTxBit denotes that we do not have the funding transaction
// stored locally on disk. This bit may be set if the funding
// transaction was crafted by a wallet external to the primary daemon.
NoFundingTxBit ChannelType = 1 << 2
// AnchorOutputsBit indicates that the channel makes use of anchor
// outputs to bump the commitment transaction's effective feerate. This
// channel type also uses a delayed to_remote output script.
AnchorOutputsBit ChannelType = 1 << 3
// FrozenBit indicates that the channel is a frozen channel, meaning
// that only the responder can decide to cooperatively close the
// channel.
FrozenBit ChannelType = 1 << 4
// ZeroHtlcTxFeeBit indicates that the channel should use zero-fee
// second-level HTLC transactions.
ZeroHtlcTxFeeBit ChannelType = 1 << 5
// LeaseExpirationBit indicates that the channel has been leased for a
// period of time, constraining every output that pays to the channel
// initiator with an additional CLTV of the lease maturity.
LeaseExpirationBit ChannelType = 1 << 6
// ZeroConfBit indicates that the channel is a zero-conf channel.
ZeroConfBit ChannelType = 1 << 7
// ScidAliasChanBit indicates that the channel has negotiated the
// scid-alias channel type.
ScidAliasChanBit ChannelType = 1 << 8
// ScidAliasFeatureBit indicates that the scid-alias feature bit was
// negotiated during the lifetime of this channel.
ScidAliasFeatureBit ChannelType = 1 << 9
// SimpleTaprootFeatureBit indicates that the simple-taproot-chans
// feature bit was negotiated during the lifetime of the channel.
SimpleTaprootFeatureBit ChannelType = 1 << 10
// TapscriptRootBit indicates that this is a MuSig2 channel with a top
// level tapscript commitment. This MUST be set along with the
// SimpleTaprootFeatureBit.
TapscriptRootBit ChannelType = 1 << 11
)
// IsSingleFunder returns true if the channel type is one of the known single
// funder variants.
func (c ChannelType) IsSingleFunder() bool {
return c&DualFunderBit == 0
}
// IsDualFunder returns true if the ChannelType has the DualFunderBit set.
func (c ChannelType) IsDualFunder() bool {
return c&DualFunderBit == DualFunderBit
}
// IsTweakless returns true if the target channel uses a commitment that
// doesn't tweak the key for the remote party.
func (c ChannelType) IsTweakless() bool {
return c&SingleFunderTweaklessBit == SingleFunderTweaklessBit
}
// HasFundingTx returns true if this channel type is one that has a funding
// transaction stored locally.
func (c ChannelType) HasFundingTx() bool {
return c&NoFundingTxBit == 0
}
// HasAnchors returns true if this channel type has anchor outputs on its
// commitment.
func (c ChannelType) HasAnchors() bool {
return c&AnchorOutputsBit == AnchorOutputsBit
}
// ZeroHtlcTxFee returns true if this channel type uses second-level HTLC
// transactions signed with zero-fee.
func (c ChannelType) ZeroHtlcTxFee() bool {
return c&ZeroHtlcTxFeeBit == ZeroHtlcTxFeeBit
}
// IsFrozen returns true if the channel is considered to be "frozen". A frozen
// channel means that only the responder can initiate a cooperative channel
// closure.
func (c ChannelType) IsFrozen() bool {
return c&FrozenBit == FrozenBit
}
// HasLeaseExpiration returns true if the channel originated from a lease.
func (c ChannelType) HasLeaseExpiration() bool {
return c&LeaseExpirationBit == LeaseExpirationBit
}
// HasZeroConf returns true if the channel is a zero-conf channel.
func (c ChannelType) HasZeroConf() bool {
return c&ZeroConfBit == ZeroConfBit
}
// HasScidAliasChan returns true if the scid-alias channel type was negotiated.
func (c ChannelType) HasScidAliasChan() bool {
return c&ScidAliasChanBit == ScidAliasChanBit
}
// HasScidAliasFeature returns true if the scid-alias feature bit was
// negotiated during the lifetime of this channel.
func (c ChannelType) HasScidAliasFeature() bool {
return c&ScidAliasFeatureBit == ScidAliasFeatureBit
}
// IsTaproot returns true if the channel is using taproot features.
func (c ChannelType) IsTaproot() bool {
return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit
}
// HasTapscriptRoot returns true if the channel is using a top level tapscript
// root commitment.
func (c ChannelType) HasTapscriptRoot() bool {
return c&TapscriptRootBit == TapscriptRootBit
}
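// As an illustrative sketch, a concrete ChannelType is built by OR-ing the
// bits above, and the predicate methods then answer questions about the
// composite type (note that TapscriptRootBit MUST be accompanied by
// SimpleTaprootFeatureBit):
//
//	chanType := SingleFunderTweaklessBit | AnchorOutputsBit | ZeroConfBit
//	chanType.IsSingleFunder() // true: DualFunderBit is unset
//	chanType.HasAnchors()     // true
//	chanType.HasZeroConf()    // true
//	chanType.IsTaproot()      // false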
// ChannelStateBounds are the parameters from OpenChannel and AcceptChannel
// that are responsible for providing bounds on the state space of the abstract
// channel state. These values must be remembered for normal channel operation
// but they do not impact how we compute the commitment transactions themselves.
type ChannelStateBounds struct {
// ChanReserve is an absolute reservation on the channel for the
// owner of this set of constraints. This means that the current
// settled balance for this node CANNOT dip below the reservation
// amount. This acts as a defense against costless attacks when
// either side no longer has any skin in the game.
ChanReserve btcutil.Amount
// MaxPendingAmount is the maximum pending HTLC value that the
// owner of these constraints can offer the remote node at a
// particular time.
MaxPendingAmount lnwire.MilliSatoshi
// MinHTLC is the minimum HTLC value that the owner of these
// constraints can offer the remote node. If any HTLCs below this
// amount are offered, then the HTLC will be rejected. This, in
// tandem with the dust limit, allows a node to regulate the
// smallest HTLC that it deems economically relevant.
MinHTLC lnwire.MilliSatoshi
// MaxAcceptedHtlcs is the maximum number of HTLCs that the owner of
// this set of constraints can offer the remote node. This allows each
// node to limit their overall exposure to HTLCs that may need to be
// acted upon in the case of a unilateral channel closure or a contract
// breach.
MaxAcceptedHtlcs uint16
}
// CommitmentParams are the parameters from OpenChannel and
// AcceptChannel that are required to render an abstract channel state to a
// concrete commitment transaction. These values are necessary to (re)compute
// the commitment transaction. We treat these differently than the state space
// bounds because their history needs to be stored in order to properly handle
// chain resolution.
type CommitmentParams struct {
// DustLimit is the threshold (in satoshis) below which any outputs
// should be trimmed. When an output is trimmed, it isn't materialized
// as an actual output, but is instead burned to miner's fees.
DustLimit btcutil.Amount
// CsvDelay is the relative time lock delay expressed in blocks. Any
// settled outputs that pay to the owner of this channel configuration
// MUST ensure that the delay branch uses this value as the relative
// time lock. Similarly, any HTLC's offered by this node should use
// this value as well.
CsvDelay uint16
}
// ChannelConfig is a struct that houses the various configuration options for
// channels. Each side maintains an instance of this configuration as it
// governs: how the funding and commitment transactions are to be created, the
// nature of HTLCs allotted, the keys to be used for delivery, and relative
// time lock parameters.
type ChannelConfig struct {
// ChannelStateBounds is the set of constraints that must be
// upheld for the duration of the channel for the owner of this channel
// configuration. Constraints govern a number of flow control related
// parameters, also including the smallest HTLC that will be accepted
// by a participant.
ChannelStateBounds
// CommitmentParams is an embedding of the parameters
// required to render an abstract channel state into a concrete
// commitment transaction.
CommitmentParams
// MultiSigKey is the key to be used within the 2-of-2 output script
// for the owner of this channel config.
MultiSigKey keychain.KeyDescriptor
// RevocationBasePoint is the base public key to be used when deriving
// revocation keys for the remote node's commitment transaction. This
// will be combined along with a per commitment secret to derive a
// unique revocation key for each state.
RevocationBasePoint keychain.KeyDescriptor
// PaymentBasePoint is the base public key to be used when deriving
// the key used within the non-delayed pay-to-self output on the
// commitment transaction for a node. This will be combined with a
// tweak derived from the per-commitment point to ensure unique keys
// for each commitment transaction.
PaymentBasePoint keychain.KeyDescriptor
// DelayBasePoint is the base public key to be used when deriving the
// key used within the delayed pay-to-self output on the commitment
// transaction for a node. This will be combined with a tweak derived
// from the per-commitment point to ensure unique keys for each
// commitment transaction.
DelayBasePoint keychain.KeyDescriptor
// HtlcBasePoint is the base public key to be used when deriving the
// local HTLC key. The derived key (combined with the tweak derived
// from the per-commitment point) is used within the "to self" clause
// within any HTLC output scripts.
HtlcBasePoint keychain.KeyDescriptor
}
// commitTlvData stores all the optional data that may be stored as a TLV stream
// at the _end_ of the normal serialized commit on disk.
type commitTlvData struct {
// customBlob is a custom blob that may store extra data for custom
// channels.
customBlob tlv.OptionalRecordT[tlv.TlvType1, tlv.Blob]
}
// encode encodes the aux data into the passed io.Writer.
func (c *commitTlvData) encode(w io.Writer) error {
var tlvRecords []tlv.Record
c.customBlob.WhenSome(func(blob tlv.RecordT[tlv.TlvType1, tlv.Blob]) {
tlvRecords = append(tlvRecords, blob.Record())
})
// Create the tlv stream.
tlvStream, err := tlv.NewStream(tlvRecords...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
// decode attempts to decode the aux data from the passed io.Reader.
func (c *commitTlvData) decode(r io.Reader) error {
blob := c.customBlob.Zero()
tlvStream, err := tlv.NewStream(
blob.Record(),
)
if err != nil {
return err
}
tlvs, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil {
return err
}
if _, ok := tlvs[c.customBlob.TlvType()]; ok {
c.customBlob = tlv.SomeRecordT(blob)
}
return nil
}
// ChannelCommitment is a snapshot of the commitment state at a particular
// point in the commitment chain. With each state transition, a snapshot of the
// current state along with all non-settled HTLCs are recorded. These snapshots
// detail the state of the _remote_ party's commitment at a particular state
// number. For ourselves (the local node) we ONLY store our most recent
// (unrevoked) state for safety purposes.
type ChannelCommitment struct {
// CommitHeight is the update number of this commitment, representing
// the total number of commitment updates up to this point. This can be
// viewed as sort of a "commitment height" as this number is
// monotonically increasing.
CommitHeight uint64
// LocalLogIndex is the cumulative log index of the local node at
// this point in the commitment chain. This value will be incremented
// for each _update_ added to the local update log.
LocalLogIndex uint64
// LocalHtlcIndex is the current local running HTLC index. This value
// will be incremented for each outgoing HTLC the local node offers.
LocalHtlcIndex uint64
// RemoteLogIndex is the cumulative log index of the remote node
// at this point in the commitment chain. This value will be
// incremented for each _update_ added to the remote update log.
RemoteLogIndex uint64
// RemoteHtlcIndex is the current remote running HTLC index. This value
// will be incremented for each outgoing HTLC the remote node offers.
RemoteHtlcIndex uint64
// LocalBalance is the current available settled balance within the
// channel directly spendable by us.
//
// NOTE: This is the balance *after* subtracting any commitment fee,
// AND anchor output values.
LocalBalance lnwire.MilliSatoshi
// RemoteBalance is the current available settled balance within the
// channel directly spendable by the remote node.
//
// NOTE: This is the balance *after* subtracting any commitment fee,
// AND anchor output values.
RemoteBalance lnwire.MilliSatoshi
// CommitFee is the amount calculated to be paid in fees for the
// current set of commitment transactions. The fee amount is persisted
// with the channel in order to allow the fee amount to be removed and
// recalculated with each channel state update, including updates that
// happen after a system restart.
CommitFee btcutil.Amount
// FeePerKw is the min satoshis/kilo-weight that should be paid within
// the commitment transaction for the entire duration of the channel's
// lifetime. This field may be updated during normal operation of the
// channel as on-chain conditions change.
//
// TODO(halseth): make this SatPerKWeight. Cannot be done atm because
// this will cause the import cycle lnwallet<->channeldb. Fee
// estimation stuff should be in its own package.
FeePerKw btcutil.Amount
// CommitTx is the latest version of the commitment state,
// broadcastable by us.
CommitTx *wire.MsgTx
// CustomBlob is an optional blob that can be used to store information
// specific to a custom channel type. This may track some custom
// specific state for this given commitment.
CustomBlob fn.Option[tlv.Blob]
// CommitSig is one half of the signature required to fully complete
// the script for the commitment transaction above. This is the
// signature signed by the remote party for our version of the
// commitment transactions.
CommitSig []byte
// Htlcs is the set of HTLC's that are pending at this particular
// commitment height.
Htlcs []HTLC
}
// amendTlvData updates the channel with the given auxiliary TLV data.
func (c *ChannelCommitment) amendTlvData(auxData commitTlvData) {
auxData.customBlob.WhenSomeV(func(blob tlv.Blob) {
c.CustomBlob = fn.Some(blob)
})
}
// extractTlvData creates a new commitTlvData from the given commitment.
func (c *ChannelCommitment) extractTlvData() commitTlvData {
var auxData commitTlvData
c.CustomBlob.WhenSome(func(blob tlv.Blob) {
auxData.customBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType1](blob),
)
})
return auxData
}
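// For example (illustrative only), extracting the aux data from a commitment
// and amending a fresh commitment with it round-trips the CustomBlob field:
//
//	aux := commit.extractTlvData()
//	var restored ChannelCommitment
//	restored.amendTlvData(aux) // restored.CustomBlob mirrors commit's.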
// ChannelStatus is a bit vector used to indicate whether an OpenChannel is in
// the default usable state, or a state where it shouldn't be used.
type ChannelStatus uint64
var (
// ChanStatusDefault is the normal state of an open channel.
ChanStatusDefault ChannelStatus
// ChanStatusBorked indicates that the channel has entered an
// irreconcilable state, triggered by a state desynchronization or
// channel breach. Channels in this state should never be added to the
// htlc switch.
ChanStatusBorked ChannelStatus = 1
// ChanStatusCommitBroadcasted indicates that a commitment for this
// channel has been broadcasted.
ChanStatusCommitBroadcasted ChannelStatus = 1 << 1
// ChanStatusLocalDataLoss indicates that we have lost channel state
// for this channel, and broadcasting our latest commitment might be
// considered a breach.
//
// TODO(halseth): actually enforce that we are not force closing such a
// channel.
ChanStatusLocalDataLoss ChannelStatus = 1 << 2
// ChanStatusRestored is a status flag that signals that the channel
// has been restored, and doesn't have all the fields a typical channel
// will have.
ChanStatusRestored ChannelStatus = 1 << 3
// ChanStatusCoopBroadcasted indicates that a cooperative close for
// this channel has been broadcasted. Older cooperatively closed
// channels will only have this status set. Newer ones will also have
// close initiator information stored using the local/remote initiator
// status. This status is set in conjunction with the initiator status
// so that we do not need to check multiple channel statuses for
// cooperative closes.
ChanStatusCoopBroadcasted ChannelStatus = 1 << 4
// ChanStatusLocalCloseInitiator indicates that we initiated closing
// the channel.
ChanStatusLocalCloseInitiator ChannelStatus = 1 << 5
// ChanStatusRemoteCloseInitiator indicates that the remote node
// initiated closing the channel.
ChanStatusRemoteCloseInitiator ChannelStatus = 1 << 6
)
// chanStatusStrings maps a ChannelStatus to a human friendly string that
// describes that status.
var chanStatusStrings = map[ChannelStatus]string{
ChanStatusDefault: "ChanStatusDefault",
ChanStatusBorked: "ChanStatusBorked",
ChanStatusCommitBroadcasted: "ChanStatusCommitBroadcasted",
ChanStatusLocalDataLoss: "ChanStatusLocalDataLoss",
ChanStatusRestored: "ChanStatusRestored",
ChanStatusCoopBroadcasted: "ChanStatusCoopBroadcasted",
ChanStatusLocalCloseInitiator: "ChanStatusLocalCloseInitiator",
ChanStatusRemoteCloseInitiator: "ChanStatusRemoteCloseInitiator",
}
// orderedChanStatusFlags is an in-order list of all the channel status
// flags.
var orderedChanStatusFlags = []ChannelStatus{
ChanStatusBorked,
ChanStatusCommitBroadcasted,
ChanStatusLocalDataLoss,
ChanStatusRestored,
ChanStatusCoopBroadcasted,
ChanStatusLocalCloseInitiator,
ChanStatusRemoteCloseInitiator,
}
// String returns a human-readable representation of the ChannelStatus.
func (c ChannelStatus) String() string {
// If no flags are set, then this is the default case.
if c == ChanStatusDefault {
return chanStatusStrings[ChanStatusDefault]
}
// Add individual bit flags.
statusStr := ""
for _, flag := range orderedChanStatusFlags {
if c&flag == flag {
statusStr += chanStatusStrings[flag] + "|"
c -= flag
}
}
// Trim the trailing bar left over from the loop above.
statusStr = strings.TrimRight(statusStr, "|")
// Add any remaining flags which aren't accounted for as hex.
if c != 0 {
statusStr += "|0x" + strconv.FormatUint(uint64(c), 16)
}
// If this was purely an unknown flag, then remove the extra bar at the
// start of the string.
statusStr = strings.TrimLeft(statusStr, "|")
return statusStr
}
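// For example (illustrative only):
//
//	s := ChanStatusBorked | ChanStatusCommitBroadcasted
//	fmt.Println(s) // ChanStatusBorked|ChanStatusCommitBroadcasted
//
// Any remaining bits without a known string mapping are rendered in hex, so
// an unknown flag such as ChannelStatus(1<<40) prints as "0x10000000000".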
// FinalHtlcByte defines a byte type that encodes information about the final
// htlc resolution.
type FinalHtlcByte byte
const (
// FinalHtlcSettledBit is the bit that encodes whether the htlc was
// settled or failed.
FinalHtlcSettledBit FinalHtlcByte = 1 << 0
// FinalHtlcOffchainBit is the bit that encodes whether the htlc was
// resolved offchain or onchain.
FinalHtlcOffchainBit FinalHtlcByte = 1 << 1
)
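// Decoding a FinalHtlcByte is a matter of testing the individual bits
// (illustrative only):
//
//	settled := b&FinalHtlcSettledBit != 0
//	offchain := b&FinalHtlcOffchainBit != 0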
// OpenChannel encapsulates the persistent and dynamic state of an open channel
// with a remote node. An open channel supports several options for on-disk
// serialization depending on the exact context. Full (upon channel creation)
// state commitments, and partial (due to a commitment update) writes are
// supported. Each partial write due to a state update appends the new update
// to an on-disk log, which can then subsequently be queried in order to
// "time-travel" to a prior state.
type OpenChannel struct {
// ChanType denotes which type of channel this is.
ChanType ChannelType
// ChainHash is a hash which represents the blockchain that this
// channel will be opened within. This value is typically the genesis
// hash. In the case that the original chain went through a contentious
// hard-fork, then this value will be tweaked using the unique fork
// point on each branch.
ChainHash chainhash.Hash
// FundingOutpoint is the outpoint of the final funding transaction.
// This value uniquely and globally identifies the channel within the
// target blockchain as specified by the chain hash parameter.
FundingOutpoint wire.OutPoint
// ShortChannelID encodes the exact location in the chain in which the
// channel was initially confirmed. This includes: the block height,
// transaction index, and the output within the target transaction.
//
// If IsZeroConf(), then this will be the "base" (very first) alias
// SCID and the confirmed SCID will be stored in confirmedScid.
ShortChannelID lnwire.ShortChannelID
// IsPending indicates whether a channel's funding transaction has been
// confirmed.
IsPending bool
// IsInitiator is a bool which indicates if we were the original
// initiator for the channel. This value may affect how higher levels
// negotiate fees, or close the channel.
IsInitiator bool
// chanStatus is the current status of this channel. If it is not in
// the state Default, it should not be used for forwarding payments.
chanStatus ChannelStatus
// FundingBroadcastHeight is the height in which the funding
// transaction was broadcast. This value can be used by higher level
// sub-systems to determine if a channel is stale and/or should have
// been confirmed before a certain height.
FundingBroadcastHeight uint32
// NumConfsRequired is the number of confirmations a channel's funding
// transaction must have received in order to be considered available
// for normal transactional use.
NumConfsRequired uint16
// ChannelFlags holds the flags that were sent as part of the
// open_channel message.
ChannelFlags lnwire.FundingFlag
// IdentityPub is the identity public key of the remote node this
// channel has been established with.
IdentityPub *btcec.PublicKey
// Capacity is the total capacity of this channel.
Capacity btcutil.Amount
// TotalMSatSent is the total number of milli-satoshis we've sent
// within this channel.
TotalMSatSent lnwire.MilliSatoshi
// TotalMSatReceived is the total number of milli-satoshis we've
// received within this channel.
TotalMSatReceived lnwire.MilliSatoshi
// InitialLocalBalance is the balance we have during the channel
// opening. When we are not the initiator, this value represents the
// push amount.
InitialLocalBalance lnwire.MilliSatoshi
// InitialRemoteBalance is the balance they have during the channel
// opening.
InitialRemoteBalance lnwire.MilliSatoshi
// LocalChanCfg is the channel configuration for the local node.
LocalChanCfg ChannelConfig
// RemoteChanCfg is the channel configuration for the remote node.
RemoteChanCfg ChannelConfig
// LocalCommitment is the current local commitment state for the local
// party. This is stored distinct from the state of the remote party
// as there are certain asymmetric parameters which affect the
// structure of each commitment.
LocalCommitment ChannelCommitment
// RemoteCommitment is the current remote commitment state for the
// remote party. This is stored distinct from the state of the local
// party as there are certain asymmetric parameters which affect the
// structure of each commitment.
RemoteCommitment ChannelCommitment
// RemoteCurrentRevocation is the current revocation for their
// commitment transaction. However, since this is the derived public key,
// we don't yet have the private key so we aren't yet able to verify
// that it's actually in the hash chain.
RemoteCurrentRevocation *btcec.PublicKey
// RemoteNextRevocation is the revocation key to be used for the *next*
// commitment transaction we create for the remote party. Within the
// specification, this value is referred to as the
// per-commitment-point.
RemoteNextRevocation *btcec.PublicKey
// RevocationProducer is used to generate the revocation in such a way
// that remote side might store it efficiently and have the ability to
// restore the revocation by index if needed. Current implementation of
// secret producer is shachain producer.
RevocationProducer shachain.Producer
// RevocationStore is used to efficiently store the revocations for
// previous channels states sent to us by remote side. Current
// implementation of secret store is shachain store.
RevocationStore shachain.Store
// Packager is used to create and update forwarding packages for this
// channel, which encodes all necessary information to recover from
// failures and reforward HTLCs that were not fully processed.
Packager FwdPackager
// FundingTxn is the transaction containing this channel's funding
// outpoint. Upon restarts, this txn will be rebroadcast if the channel
// is found to be pending.
//
// NOTE: This value will only be populated for single-funder channels
// for which we are the initiator, and that we also have the funding
// transaction for. One can check this by using the HasFundingTx()
// method on the ChanType field.
FundingTxn *wire.MsgTx
// LocalShutdownScript is set to a pre-set script if the channel was opened
// by the local node with option_upfront_shutdown_script set. If the option
// was not set, the field is empty.
LocalShutdownScript lnwire.DeliveryAddress
// RemoteShutdownScript is set to a pre-set script if the channel was opened
// by the remote node with option_upfront_shutdown_script set. If the option
// was not set, the field is empty.
RemoteShutdownScript lnwire.DeliveryAddress
// ThawHeight is the height when a frozen channel once again becomes a
// normal channel. If this is zero, then there are no restrictions on
// this channel. If the value is lower than 500,000, then it's
// interpreted as a relative height, or an absolute height otherwise.
ThawHeight uint32
// LastWasRevoke is a boolean that determines if the last update we sent
// was a revocation (true) or a commitment signature (false).
LastWasRevoke bool
// RevocationKeyLocator stores the KeyLocator information that we will
// need to derive the shachain root for this channel. This allows us to
// have private key isolation from lnd.
RevocationKeyLocator keychain.KeyLocator
// confirmedScid is the confirmed ShortChannelID for a zero-conf
// channel. If the channel is unconfirmed, then this will be the
// default ShortChannelID. This is only set for zero-conf channels.
confirmedScid lnwire.ShortChannelID
// Memo is any arbitrary information we wish to store locally about the
// channel that will be useful to our future selves.
Memo []byte
// TapscriptRoot is an optional tapscript root used to derive the MuSig2
// funding output.
TapscriptRoot fn.Option[chainhash.Hash]
// CustomBlob is an optional blob that can be used to store information
// specific to a custom channel type. This information is only created
// at channel funding time, and afterwards is to be considered
// immutable.
CustomBlob fn.Option[tlv.Blob]
// TODO(roasbeef): eww
Db *ChannelStateDB
// TODO(roasbeef): just need to store local and remote HTLC's?
sync.RWMutex
}
// String returns a string representation of the channel.
func (c *OpenChannel) String() string {
indexStr := "height=%v, local_htlc_index=%v, local_log_index=%v, " +
"remote_htlc_index=%v, remote_log_index=%v"
commit := c.LocalCommitment
local := fmt.Sprintf(indexStr, commit.CommitHeight,
commit.LocalHtlcIndex, commit.LocalLogIndex,
commit.RemoteHtlcIndex, commit.RemoteLogIndex,
)
commit = c.RemoteCommitment
remote := fmt.Sprintf(indexStr, commit.CommitHeight,
commit.LocalHtlcIndex, commit.LocalLogIndex,
commit.RemoteHtlcIndex, commit.RemoteLogIndex,
)
return fmt.Sprintf("SCID=%v, status=%v, initiator=%v, pending=%v, "+
"local commitment has %s, remote commitment has %s",
c.ShortChannelID, c.chanStatus, c.IsInitiator, c.IsPending,
local, remote,
)
}
// Initiator returns the ChannelParty that originally opened this channel.
func (c *OpenChannel) Initiator() lntypes.ChannelParty {
c.RLock()
defer c.RUnlock()
if c.IsInitiator {
return lntypes.Local
}
return lntypes.Remote
}
// ShortChanID returns the current ShortChannelID of this channel.
func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID {
c.RLock()
defer c.RUnlock()
return c.ShortChannelID
}
// ZeroConfRealScid returns the zero-conf channel's confirmed scid. This should
// only be called if IsZeroConf returns true.
func (c *OpenChannel) ZeroConfRealScid() lnwire.ShortChannelID {
c.RLock()
defer c.RUnlock()
return c.confirmedScid
}
// ZeroConfConfirmed returns whether the zero-conf channel has confirmed. This
// should only be called if IsZeroConf returns true.
func (c *OpenChannel) ZeroConfConfirmed() bool {
c.RLock()
defer c.RUnlock()
return c.confirmedScid != hop.Source
}
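// Note: hop.Source is the zero-value ShortChannelID, so the comparison above
// simply checks whether a real confirmed SCID has been set yet.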
// IsZeroConf returns whether the option_zeroconf channel type was negotiated.
func (c *OpenChannel) IsZeroConf() bool {
c.RLock()
defer c.RUnlock()
return c.ChanType.HasZeroConf()
}
// IsOptionScidAlias returns whether the option_scid_alias channel type was
// negotiated.
func (c *OpenChannel) IsOptionScidAlias() bool {
c.RLock()
defer c.RUnlock()
return c.ChanType.HasScidAliasChan()
}
// NegotiatedAliasFeature returns whether the option-scid-alias feature bit was
// negotiated.
func (c *OpenChannel) NegotiatedAliasFeature() bool {
c.RLock()
defer c.RUnlock()
return c.ChanType.HasScidAliasFeature()
}
// ChanStatus returns the current ChannelStatus of this channel.
func (c *OpenChannel) ChanStatus() ChannelStatus {
c.RLock()
defer c.RUnlock()
return c.chanStatus
}
// ApplyChanStatus allows the caller to modify the internal channel state in a
// thread-safe manner.
func (c *OpenChannel) ApplyChanStatus(status ChannelStatus) error {
c.Lock()
defer c.Unlock()
return c.putChanStatus(status)
}
// ClearChanStatus allows the caller to clear a particular channel status from
// the primary channel status bit field. After this method returns, a call to
// HasChanStatus(status) should return false.
func (c *OpenChannel) ClearChanStatus(status ChannelStatus) error {
c.Lock()
defer c.Unlock()
return c.clearChanStatus(status)
}
// HasChanStatus returns true if the internal bitfield channel status of the
// target channel has the specified status bit set.
func (c *OpenChannel) HasChanStatus(status ChannelStatus) bool {
c.RLock()
defer c.RUnlock()
return c.hasChanStatus(status)
}
func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool {
// Special case ChanStatusDefault since it isn't actually a flag, but a
// particular combination (or lack thereof) of flags.
if status == ChanStatusDefault {
return c.chanStatus == ChanStatusDefault
}
return c.chanStatus&status == status
}
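// Because ChanStatusDefault is the zero value rather than a distinct bit,
// the special case above means (illustrative only):
//
//	c.chanStatus = ChanStatusBorked
//	c.hasChanStatus(ChanStatusDefault) // false, since a flag is set.
//	c.hasChanStatus(ChanStatusBorked)  // true.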
// BroadcastHeight returns the height at which the funding tx was broadcast.
func (c *OpenChannel) BroadcastHeight() uint32 {
c.RLock()
defer c.RUnlock()
return c.FundingBroadcastHeight
}
// SetBroadcastHeight sets the FundingBroadcastHeight.
func (c *OpenChannel) SetBroadcastHeight(height uint32) {
c.Lock()
defer c.Unlock()
c.FundingBroadcastHeight = height
}
// amendTlvData updates the channel with the given auxiliary TLV data.
func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) {
c.RevocationKeyLocator = auxData.revokeKeyLoc.Val.KeyLocator
c.InitialLocalBalance = lnwire.MilliSatoshi(
auxData.initialLocalBalance.Val,
)
c.InitialRemoteBalance = lnwire.MilliSatoshi(
auxData.initialRemoteBalance.Val,
)
c.confirmedScid = auxData.realScid.Val
auxData.memo.WhenSomeV(func(memo []byte) {
c.Memo = memo
})
auxData.tapscriptRoot.WhenSomeV(func(h [32]byte) {
c.TapscriptRoot = fn.Some[chainhash.Hash](h)
})
auxData.customBlob.WhenSomeV(func(blob tlv.Blob) {
c.CustomBlob = fn.Some(blob)
})
}
// extractTlvData creates a new openChannelTlvData from the given channel.
func (c *OpenChannel) extractTlvData() openChannelTlvData {
auxData := openChannelTlvData{
revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1](
keyLocRecord{c.RevocationKeyLocator},
),
initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2](
uint64(c.InitialLocalBalance),
),
initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3](
uint64(c.InitialRemoteBalance),
),
realScid: tlv.NewRecordT[tlv.TlvType4](
c.confirmedScid,
),
}
if len(c.Memo) != 0 {
auxData.memo = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo),
)
}
c.TapscriptRoot.WhenSome(func(h chainhash.Hash) {
auxData.tapscriptRoot = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h),
)
})
c.CustomBlob.WhenSome(func(blob tlv.Blob) {
auxData.customBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType7](blob),
)
})
return auxData
}
// Refresh updates the in-memory channel state using the latest state observed
// on disk.
func (c *OpenChannel) Refresh() error {
c.Lock()
defer c.Unlock()
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
// We'll re-populate the in-memory channel with the info
// fetched from disk.
if err := fetchChanInfo(chanBucket, c); err != nil {
return fmt.Errorf("unable to fetch chan info: %w", err)
}
// Also populate the channel's commitment states for both sides
// of the channel.
if err := fetchChanCommitments(chanBucket, c); err != nil {
return fmt.Errorf("unable to fetch chan commitments: "+
"%v", err)
}
// Also retrieve the current revocation state.
if err := fetchChanRevocationState(chanBucket, c); err != nil {
return fmt.Errorf("unable to fetch chan revocations: "+
"%v", err)
}
return nil
}, func() {})
return err
}
// fetchChanBucket is a helper function that returns the bucket where a
// channel's data resides in given: the public key for the node, the outpoint,
// and the chainhash that the channel resides on.
func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey,
outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RBucket, error) {
// First fetch the top level bucket which stores all data related to
// current, active channels.
openChanBucket := tx.ReadBucket(openChannelBucket)
if openChanBucket == nil {
return nil, ErrNoChanDBExists
}
// TODO(roasbeef): CreateTopLevelBucket on the interface isn't like
// CreateIfNotExists, will return error
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
nodePub := nodeKey.SerializeCompressed()
nodeChanBucket := openChanBucket.NestedReadBucket(nodePub)
if nodeChanBucket == nil {
return nil, ErrNoActiveChannels
}
// We'll then recurse down an additional layer in order to fetch the
// bucket for this particular chain.
chainBucket := nodeChanBucket.NestedReadBucket(chainHash[:])
if chainBucket == nil {
return nil, ErrNoActiveChannels
}
// With the bucket for the node and chain fetched, we can now go down
// another level, for this channel itself.
var chanPointBuf bytes.Buffer
if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil {
return nil, err
}
chanBucket := chainBucket.NestedReadBucket(chanPointBuf.Bytes())
if chanBucket == nil {
return nil, ErrChannelNotFound
}
return chanBucket, nil
}
// fetchChanBucketRw is a helper function that returns the bucket where a
// channel's data resides in given: the public key for the node, the outpoint,
// and the chainhash that the channel resides on. This differs from
// fetchChanBucket in that it returns a writeable bucket.
func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey,
outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket,
error) {
// First fetch the top level bucket which stores all data related to
// current, active channels.
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
if openChanBucket == nil {
return nil, ErrNoChanDBExists
}
// TODO(roasbeef): CreateTopLevelBucket on the interface isn't like
// CreateIfNotExists, will return error
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
nodePub := nodeKey.SerializeCompressed()
nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub)
if nodeChanBucket == nil {
return nil, ErrNoActiveChannels
}
// We'll then recurse down an additional layer in order to fetch the
// bucket for this particular chain.
chainBucket := nodeChanBucket.NestedReadWriteBucket(chainHash[:])
if chainBucket == nil {
return nil, ErrNoActiveChannels
}
// With the bucket for the node and chain fetched, we can now go down
// another level, for this channel itself.
var chanPointBuf bytes.Buffer
if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil {
return nil, err
}
chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes())
if chanBucket == nil {
return nil, ErrChannelNotFound
}
return chanBucket, nil
}
func fetchFinalHtlcsBucketRw(tx kvdb.RwTx,
chanID lnwire.ShortChannelID) (kvdb.RwBucket, error) {
finalHtlcsBucket, err := tx.CreateTopLevelBucket(finalHtlcsBucket)
if err != nil {
return nil, err
}
var chanIDBytes [8]byte
byteOrder.PutUint64(chanIDBytes[:], chanID.ToUint64())
chanBucket, err := finalHtlcsBucket.CreateBucketIfNotExists(
chanIDBytes[:],
)
if err != nil {
return nil, err
}
return chanBucket, nil
}
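// The per-channel bucket above is keyed by the SCID serialized as a
// big-endian uint64, e.g. (illustrative only):
//
//	scid := lnwire.NewShortChanIDFromInt(42)
//	var key [8]byte
//	byteOrder.PutUint64(key[:], scid.ToUint64()) // key holds 42, big-endian.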
// fullSync syncs the contents of an OpenChannel while re-using an existing
// database transaction.
func (c *OpenChannel) fullSync(tx kvdb.RwTx) error {
// Fetch the outpoint bucket and check if the outpoint already exists.
opBucket := tx.ReadWriteBucket(outpointBucket)
if opBucket == nil {
return ErrNoChanDBExists
}
cidBucket := tx.ReadWriteBucket(chanIDBucket)
if cidBucket == nil {
return ErrNoChanDBExists
}
var chanPointBuf bytes.Buffer
err := graphdb.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint)
if err != nil {
return err
}
// Now, check if the outpoint exists in our index.
if opBucket.Get(chanPointBuf.Bytes()) != nil {
return ErrChanAlreadyExists
}
cid := lnwire.NewChanIDFromOutPoint(c.FundingOutpoint)
if cidBucket.Get(cid[:]) != nil {
return ErrChanAlreadyExists
}
status := uint8(outpointOpen)
// Write the status of this outpoint as the first entry in a tlv
// stream.
statusRecord := tlv.MakePrimitiveRecord(indexStatusType, &status)
opStream, err := tlv.NewStream(statusRecord)
if err != nil {
return err
}
var b bytes.Buffer
if err := opStream.Encode(&b); err != nil {
return err
}
// Add the outpoint to our outpoint index with the tlv stream.
if err := opBucket.Put(chanPointBuf.Bytes(), b.Bytes()); err != nil {
return err
}
if err := cidBucket.Put(cid[:], []byte{}); err != nil {
return err
}
// First fetch the top level bucket which stores all data related to
// current, active channels.
openChanBucket, err := tx.CreateTopLevelBucket(openChannelBucket)
if err != nil {
return err
}
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
nodePub := c.IdentityPub.SerializeCompressed()
nodeChanBucket, err := openChanBucket.CreateBucketIfNotExists(nodePub)
if err != nil {
return err
}
// We'll then recurse down an additional layer in order to fetch the
// bucket for this particular chain.
chainBucket, err := nodeChanBucket.CreateBucketIfNotExists(c.ChainHash[:])
if err != nil {
return err
}
// With the bucket for the node fetched, we can now go down another
// level, creating the bucket for this channel itself.
chanBucket, err := chainBucket.CreateBucket(
chanPointBuf.Bytes(),
)
switch {
case err == kvdb.ErrBucketExists:
// If this channel already exists, then in order to avoid
// overriding it, we'll return an error back up to the caller.
return ErrChanAlreadyExists
case err != nil:
return err
}
return putOpenChannel(chanBucket, c)
}
// MarkAsOpen marks a channel as fully open given a locator that uniquely
// describes its location within the chain.
func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
c.Lock()
defer c.Unlock()
if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
if err != nil {
return err
}
channel.IsPending = false
channel.ShortChannelID = openLoc
return putOpenChannel(chanBucket, channel)
}, func() {}); err != nil {
return err
}
c.IsPending = false
c.ShortChannelID = openLoc
c.Packager = NewChannelPackager(openLoc)
return nil
}
// MarkRealScid marks the zero-conf channel's confirmed ShortChannelID. This
// should only be done if IsZeroConf returns true.
func (c *OpenChannel) MarkRealScid(realScid lnwire.ShortChannelID) error {
c.Lock()
defer c.Unlock()
if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
channel, err := fetchOpenChannel(
chanBucket, &c.FundingOutpoint,
)
if err != nil {
return err
}
channel.confirmedScid = realScid
return putOpenChannel(chanBucket, channel)
}, func() {}); err != nil {
return err
}
c.confirmedScid = realScid
return nil
}
// MarkScidAliasNegotiated adds ScidAliasFeatureBit to ChanType in-memory and
// in the database.
func (c *OpenChannel) MarkScidAliasNegotiated() error {
c.Lock()
defer c.Unlock()
if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
channel, err := fetchOpenChannel(
chanBucket, &c.FundingOutpoint,
)
if err != nil {
return err
}
channel.ChanType |= ScidAliasFeatureBit
return putOpenChannel(chanBucket, channel)
}, func() {}); err != nil {
return err
}
c.ChanType |= ScidAliasFeatureBit
return nil
}
// MarkDataLoss sets the channel status to ChanStatusLocalDataLoss and stores
// the passed commitPoint for use to retrieve funds in case the remote force
// closes the channel.
func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error {
c.Lock()
defer c.Unlock()
var b bytes.Buffer
if err := WriteElement(&b, commitPoint); err != nil {
return err
}
putCommitPoint := func(chanBucket kvdb.RwBucket) error {
return chanBucket.Put(dataLossCommitPointKey, b.Bytes())
}
return c.putChanStatus(ChanStatusLocalDataLoss, putCommitPoint)
}
// DataLossCommitPoint retrieves the stored commit point set during
// MarkDataLoss. If not found ErrNoCommitPoint is returned.
func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) {
var commitPoint *btcec.PublicKey
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
switch err {
case nil:
case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound:
return ErrNoCommitPoint
default:
return err
}
bs := chanBucket.Get(dataLossCommitPointKey)
if bs == nil {
return ErrNoCommitPoint
}
r := bytes.NewReader(bs)
if err := ReadElements(r, &commitPoint); err != nil {
return err
}
return nil
}, func() {
commitPoint = nil
})
if err != nil {
return nil, err
}
return commitPoint, nil
}
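// MarkDataLoss and DataLossCommitPoint are intended to be used as a pair
// (illustrative only):
//
//	_ = c.MarkDataLoss(commitPoint)
//	point, err := c.DataLossCommitPoint()
//	// err is ErrNoCommitPoint if MarkDataLoss was never called.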
// MarkBorked marks the channel as having reached an irreconcilable
// state, such as a channel breach or state desynchronization. Borked channels
// should never be added to the switch.
func (c *OpenChannel) MarkBorked() error {
c.Lock()
defer c.Unlock()
return c.putChanStatus(ChanStatusBorked)
}
// SecondCommitmentPoint returns the second per-commitment-point for use in the
// channel_ready message.
func (c *OpenChannel) SecondCommitmentPoint() (*btcec.PublicKey, error) {
c.RLock()
defer c.RUnlock()
// Since we start at commitment height = 0, the second per commitment
// point is actually at the 1st index.
revocation, err := c.RevocationProducer.AtIndex(1)
if err != nil {
return nil, err
}
return input.ComputeCommitmentPoint(revocation[:]), nil
}
var (
// taprootRevRootKey is the key used to derive the revocation root for
// the taproot nonces. This is done via HMAC of the existing revocation
// root.
taprootRevRootKey = []byte("taproot-rev-root")
)
// DeriveMusig2Shachain derives a shachain producer for the taproot channel
// from normal shachain revocation root.
func DeriveMusig2Shachain(revRoot shachain.Producer) (shachain.Producer, error) { //nolint:ll
// In order to obtain the revocation root hash to create the taproot
// revocation, we'll encode the producer into a buffer, then use that
// to derive the shachain root needed.
var rootHashBuf bytes.Buffer
if err := revRoot.Encode(&rootHashBuf); err != nil {
return nil, fmt.Errorf("unable to encode producer: %w", err)
}
revRootHash := chainhash.HashH(rootHashBuf.Bytes())
// For taproot channel types, we'll also generate a distinct shachain
// root using the same seed information. We'll use this to generate
// verification nonces for the channel. We'll bind with this a simple
// hmac.
taprootRevHmac := hmac.New(sha256.New, taprootRevRootKey)
if _, err := taprootRevHmac.Write(revRootHash[:]); err != nil {
return nil, err
}
taprootRevRoot := taprootRevHmac.Sum(nil)
// Once we have the root, we can then generate our shachain producer
// and from that generate the per-commitment point.
return shachain.NewRevocationProducerFromBytes(
taprootRevRoot,
)
}
// NewMusigVerificationNonce generates the local or verification nonce for
// another musig2 session. In order to permit our implementation to not have to
// write any secret nonce state to disk, we'll use the _next_ shachain
// pre-image as our primary randomness source. When used to generate the nonce
// again to broadcast our commitment, the current height will be used.
func NewMusigVerificationNonce(pubKey *btcec.PublicKey, targetHeight uint64,
shaGen shachain.Producer) (*musig2.Nonces, error) {
// Now that we know what height we need, we'll grab the shachain
// pre-image at the target destination.
nextPreimage, err := shaGen.AtIndex(targetHeight)
if err != nil {
return nil, err
}
shaChainRand := musig2.WithCustomRand(bytes.NewBuffer(nextPreimage[:]))
pubKeyOpt := musig2.WithPublicKey(pubKey)
return musig2.GenNonces(pubKeyOpt, shaChainRand)
}
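// A typical pairing of the two helpers above, mirroring their use in
// ChanSyncMsg below (illustrative only; nextHeight stands for the next local
// commitment height):
//
//	taprootProducer, _ := DeriveMusig2Shachain(c.RevocationProducer)
//	nonces, _ := NewMusigVerificationNonce(
//		c.LocalChanCfg.MultiSigKey.PubKey, nextHeight, taprootProducer,
//	)
//	_ = nonces.PubNonce // Sent to the remote party.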
// ChanSyncMsg returns the ChannelReestablish message that should be sent upon
// reconnection with the remote peer that we're maintaining this channel with.
// The information contained within this message is necessary to re-sync our
// commitment chains in the case of a lost or only partially processed message.
// When the remote party receives this message one of three things may happen:
//
// 1. We're fully synced and no messages need to be sent.
// 2. We didn't get the last CommitSig message they sent, so they'll re-send
// it.
// 3. We didn't get the last RevokeAndAck message they sent, so they'll
// re-send it.
//
// If this is a restored channel, having status ChanStatusRestored, then we'll
// modify our typical chan sync message to ensure they force close even if
// we're on the very first state.
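//
// As a concrete example: if our local commitment is at height 5 and the
// remote commitment is at height 7, we send NextLocalCommitHeight=6 and
// RemoteCommitTailHeight=7, along with the last per-commitment secret the
// remote party revealed to us (looked up at index 7-1=6).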
func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) {
c.Lock()
defer c.Unlock()
// The remote commitment height that we'll send in the
// ChannelReestablish message is our current commitment height plus
// one. If the receiver thinks that our commitment height is actually
// *equal* to this value, then they'll re-send the last commitment that
// they sent but we never fully processed.
localHeight := c.LocalCommitment.CommitHeight
nextLocalCommitHeight := localHeight + 1
// The second value we'll send is the height of the remote commitment
// from our PoV. If the receiver thinks that their height is actually
// *one plus* this value, then they'll re-send their last revocation.
remoteChainTipHeight := c.RemoteCommitment.CommitHeight
// If this channel has undergone a commitment update, then in order to
// prove to the remote party our knowledge of their prior commitment
// state, we'll also send over the last commitment secret that the
// remote party sent.
var lastCommitSecret [32]byte
if remoteChainTipHeight != 0 {
remoteSecret, err := c.RevocationStore.LookUp(
remoteChainTipHeight - 1,
)
if err != nil {
return nil, err
}
lastCommitSecret = [32]byte(*remoteSecret)
}
// Additionally, we'll send over the current unrevoked commitment on
// our local commitment transaction.
currentCommitSecret, err := c.RevocationProducer.AtIndex(
localHeight,
)
if err != nil {
return nil, err
}
// If we've restored this channel, then we'll purposefully give them an
// invalid LocalUnrevokedCommitPoint so they'll force close the channel
// allowing us to sweep our funds.
if c.hasChanStatus(ChanStatusRestored) {
currentCommitSecret[0] ^= 1
// If this is a tweakless channel, then we'll purposefully send
// a next local height that's invalid to trigger a force close
// on their end. We do this as tweakless channels don't require
// that the commitment point is valid, only that it's present.
if c.ChanType.IsTweakless() {
nextLocalCommitHeight = 0
}
}
// If this is a taproot channel, then we'll need to generate our next
// verification nonce to send to the remote party. They'll use this to
// sign the next update to our commitment transaction.
var nextTaprootNonce lnwire.OptMusig2NonceTLV
if c.ChanType.IsTaproot() {
taprootRevProducer, err := DeriveMusig2Shachain(
c.RevocationProducer,
)
if err != nil {
return nil, err
}
nextNonce, err := NewMusigVerificationNonce(
c.LocalChanCfg.MultiSigKey.PubKey,
nextLocalCommitHeight, taprootRevProducer,
)
if err != nil {
return nil, fmt.Errorf("unable to gen next "+
"nonce: %w", err)
}
nextTaprootNonce = lnwire.SomeMusig2Nonce(nextNonce.PubNonce)
}
return &lnwire.ChannelReestablish{
ChanID: lnwire.NewChanIDFromOutPoint(
c.FundingOutpoint,
),
NextLocalCommitHeight: nextLocalCommitHeight,
RemoteCommitTailHeight: remoteChainTipHeight,
LastRemoteCommitSecret: lastCommitSecret,
LocalUnrevokedCommitPoint: input.ComputeCommitmentPoint(
currentCommitSecret[:],
),
LocalNonce: nextTaprootNonce,
}, nil
}
// MarkShutdownSent serialises and persists the given ShutdownInfo for this
// channel. Persisting this info represents the fact that we have sent the
// Shutdown message to the remote side and hence that we should re-transmit the
// same Shutdown message on re-establish.
func (c *OpenChannel) MarkShutdownSent(info *ShutdownInfo) error {
c.Lock()
defer c.Unlock()
return c.storeShutdownInfo(info)
}
// storeShutdownInfo serialises the ShutdownInfo and persists it under the
// shutdownInfoKey.
func (c *OpenChannel) storeShutdownInfo(info *ShutdownInfo) error {
var b bytes.Buffer
err := info.encode(&b)
if err != nil {
return err
}
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
return chanBucket.Put(shutdownInfoKey, b.Bytes())
}, func() {})
}
// ShutdownInfo decodes the shutdown info stored for this channel and returns
// the result. If no shutdown info has been persisted for this channel then the
// ErrNoShutdownInfo error is returned.
func (c *OpenChannel) ShutdownInfo() (fn.Option[ShutdownInfo], error) {
c.RLock()
defer c.RUnlock()
var shutdownInfo *ShutdownInfo
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
switch {
case err == nil:
case errors.Is(err, ErrNoChanDBExists),
errors.Is(err, ErrNoActiveChannels),
errors.Is(err, ErrChannelNotFound):
return ErrNoShutdownInfo
default:
return err
}
shutdownInfoBytes := chanBucket.Get(shutdownInfoKey)
if shutdownInfoBytes == nil {
return ErrNoShutdownInfo
}
shutdownInfo, err = decodeShutdownInfo(shutdownInfoBytes)
return err
}, func() {
shutdownInfo = nil
})
if err != nil {
return fn.None[ShutdownInfo](), err
}
return fn.Some[ShutdownInfo](*shutdownInfo), nil
}
// isBorked returns true if the channel has been marked as borked in the
// database. This requires an existing database transaction to already be
// active.
//
// NOTE: The primary mutex should already be held before this method is called.
func (c *OpenChannel) isBorked(chanBucket kvdb.RBucket) (bool, error) {
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
if err != nil {
return false, err
}
return channel.chanStatus != ChanStatusDefault, nil
}
// MarkCommitmentBroadcasted marks the channel to indicate that a commitment
// transaction has been broadcast, either our own or the remote's, and that we
// should watch the chain for it to confirm before taking any further action.
// It takes as argument the
// closing tx _we believe_ will appear in the chain. This is only used to
// republish this tx at startup to ensure propagation, and we should still
// handle the case where a different tx actually hits the chain.
func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx,
closer lntypes.ChannelParty) error {
return c.markBroadcasted(
ChanStatusCommitBroadcasted, forceCloseTxKey, closeTx,
closer,
)
}
// MarkCoopBroadcasted marks the channel to indicate that a cooperative close
// transaction has been broadcast, either our own or the remote, and that we
// should watch the chain for it to confirm before taking further action. It
// takes as argument a cooperative close tx that could appear on chain, and
// should be rebroadcast upon startup. This is only used to republish and
// ensure propagation, and we should still handle the case where a different tx
// actually hits the chain.
func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx,
closer lntypes.ChannelParty) error {
return c.markBroadcasted(
ChanStatusCoopBroadcasted, coopCloseTxKey, closeTx,
closer,
)
}
// markBroadcasted is a helper function which modifies the channel status of the
// receiving channel and inserts a close transaction under the requested key,
// which should specify either a coop or force close. It adds a status which
// indicates the party that initiated the channel close.
func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte,
closeTx *wire.MsgTx, closer lntypes.ChannelParty) error {
c.Lock()
defer c.Unlock()
// If a closing tx is provided, we'll generate a closure to write the
// transaction in the appropriate bucket under the given key.
var putClosingTx func(kvdb.RwBucket) error
if closeTx != nil {
var b bytes.Buffer
if err := WriteElement(&b, closeTx); err != nil {
return err
}
putClosingTx = func(chanBucket kvdb.RwBucket) error {
return chanBucket.Put(key, b.Bytes())
}
}
// Add the initiator status to the status provided. These statuses are
// set in addition to the broadcast status so that we do not need to
// migrate the original logic which does not store initiator.
if closer.IsLocal() {
status |= ChanStatusLocalCloseInitiator
} else {
status |= ChanStatusRemoteCloseInitiator
}
return c.putChanStatus(status, putClosingTx)
}
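// For example, MarkCoopBroadcasted with a local closer results in the
// combined status bits (illustrative only):
//
//	ChanStatusCoopBroadcasted | ChanStatusLocalCloseInitiator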
// BroadcastedCommitment retrieves the stored unilateral closing tx set during
// MarkCommitmentBroadcasted. If not found ErrNoCloseTx is returned.
func (c *OpenChannel) BroadcastedCommitment() (*wire.MsgTx, error) {
return c.getClosingTx(forceCloseTxKey)
}
// BroadcastedCooperative retrieves the stored cooperative closing tx set during
// MarkCoopBroadcasted. If not found ErrNoCloseTx is returned.
func (c *OpenChannel) BroadcastedCooperative() (*wire.MsgTx, error) {
return c.getClosingTx(coopCloseTxKey)
}
// getClosingTx is a helper method which returns the stored closing transaction
// for key. The caller should use either the force or coop closing keys.
func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) {
var closeTx *wire.MsgTx
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
switch err {
case nil:
case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound:
return ErrNoCloseTx
default:
return err
}
bs := chanBucket.Get(key)
if bs == nil {
return ErrNoCloseTx
}
r := bytes.NewReader(bs)
return ReadElement(r, &closeTx)
}, func() {
closeTx = nil
})
if err != nil {
return nil, err
}
return closeTx, nil
}
// putChanStatus appends the given status to the channel. fs is an optional
// list of closures that are given the chanBucket in order to atomically add
// extra information together with the new status.
func (c *OpenChannel) putChanStatus(status ChannelStatus,
fs ...func(kvdb.RwBucket) error) error {
if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
if err != nil {
return err
}
// Add this status to the existing bitvector found in the DB.
status = channel.chanStatus | status
channel.chanStatus = status
if err := putOpenChannel(chanBucket, channel); err != nil {
return err
}
for _, f := range fs {
// Skip execution of nil closures.
if f == nil {
continue
}
if err := f(chanBucket); err != nil {
return err
}
}
return nil
}, func() {}); err != nil {
return err
}
// Update the in-memory representation to keep it in sync with the DB.
c.chanStatus = status
return nil
}
func (c *OpenChannel) clearChanStatus(status ChannelStatus) error {
if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
if err != nil {
return err
}
// Unset this bit in the bitvector on disk.
status = channel.chanStatus & ^status
channel.chanStatus = status
return putOpenChannel(chanBucket, channel)
}, func() {}); err != nil {
return err
}
// Update the in-memory representation to keep it in sync with the DB.
c.chanStatus = status
return nil
}
// putOpenChannel serializes and stores the current state of the channel in
// its entirety.
func putOpenChannel(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
// First, we'll write out all the relatively static fields, that are
// decided upon initial channel creation.
if err := putChanInfo(chanBucket, channel); err != nil {
return fmt.Errorf("unable to store chan info: %w", err)
}
// With the static channel info written out, we'll now write out the
// current commitment state for both parties.
if err := putChanCommitments(chanBucket, channel); err != nil {
return fmt.Errorf("unable to store chan commitments: %w", err)
}
// Next, if this is a frozen channel, we'll add in the auxiliary
// information we need to store.
if channel.ChanType.IsFrozen() || channel.ChanType.HasLeaseExpiration() {
err := storeThawHeight(
chanBucket, channel.ThawHeight,
)
if err != nil {
return fmt.Errorf("unable to store thaw height: %w",
err)
}
}
// Finally, we'll write out the revocation state for both parties
// within a distinct key space.
if err := putChanRevocationState(chanBucket, channel); err != nil {
return fmt.Errorf("unable to store chan revocations: %w", err)
}
return nil
}
// fetchOpenChannel retrieves and deserializes (including decrypting any
// sensitive material) the complete channel stored under the given bucket for
// the passed channel point.
func fetchOpenChannel(chanBucket kvdb.RBucket,
chanPoint *wire.OutPoint) (*OpenChannel, error) {
channel := &OpenChannel{
FundingOutpoint: *chanPoint,
}
// First, we'll read all the static information that changes less
// frequently from disk.
if err := fetchChanInfo(chanBucket, channel); err != nil {
return nil, fmt.Errorf("unable to fetch chan info: %w", err)
}
// With the static information read, we'll now read the current
// commitment state for both sides of the channel.
if err := fetchChanCommitments(chanBucket, channel); err != nil {
return nil, fmt.Errorf("unable to fetch chan commitments: %w",
err)
}
// Next, if this is a frozen channel, we'll also fetch the auxiliary
// thaw height that was stored for it.
if channel.ChanType.IsFrozen() || channel.ChanType.HasLeaseExpiration() {
thawHeight, err := fetchThawHeight(chanBucket)
if err != nil {
return nil, fmt.Errorf("unable to store thaw "+
"height: %v", err)
}
channel.ThawHeight = thawHeight
}
// Finally, we'll retrieve the current revocation state so we can
// properly restore it on the returned channel.
if err := fetchChanRevocationState(chanBucket, channel); err != nil {
return nil, fmt.Errorf("unable to fetch chan revocations: %w",
err)
}
channel.Packager = NewChannelPackager(channel.ShortChannelID)
return channel, nil
}
// SyncPending writes the contents of the channel to the database while it's in
// the pending (waiting for funding confirmation) state. The IsPending flag
// will be set to true. When the channel's funding transaction is confirmed,
// the channel should be marked as "open" and the IsPending flag set to false.
// Note that this function also creates a LinkNode relationship between this
// newly created channel and a new LinkNode instance. This allows listing all
// channels in the database globally, or according to the LinkNode they were
// created with.
//
// TODO(roasbeef): addr param should eventually be an lnwire.NetAddress type
// that includes service bits.
func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error {
c.Lock()
defer c.Unlock()
c.FundingBroadcastHeight = pendingHeight
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return syncNewChannel(tx, c, []net.Addr{addr})
}, func() {})
}
// syncNewChannel will write the passed channel to disk, and also create a
// LinkNode (if needed) for the channel peer.
func syncNewChannel(tx kvdb.RwTx, c *OpenChannel, addrs []net.Addr) error {
// First, sync all the persistent channel state to disk.
if err := c.fullSync(tx); err != nil {
return err
}
nodeInfoBucket, err := tx.CreateTopLevelBucket(nodeInfoBucket)
if err != nil {
return err
}
// If a LinkNode for this identity public key already exists,
// then we can exit early.
nodePub := c.IdentityPub.SerializeCompressed()
if nodeInfoBucket.Get(nodePub) != nil {
return nil
}
// Next, we need to establish a (possibly) new LinkNode relationship
// for this channel. The LinkNode metadata contains reachability,
// up-time, and service bits related information.
linkNode := NewLinkNode(
&LinkNodeDB{backend: c.Db.backend},
wire.MainNet, c.IdentityPub, addrs...,
)
// TODO(roasbeef): do away with link node all together?
return putLinkNode(nodeInfoBucket, linkNode)
}
// UpdateCommitment updates the local commitment state. It locks in the pending
// local updates that were received by us from the remote party. The commitment
// state completely describes the balance state at this point in the commitment
// chain. In addition to that, it persists all the remote log updates that we
// have acked, but not signed a remote commitment for yet. These need to be
// persisted to be able to produce a valid commit signature if a restart would
// occur. This method is to be called when we revoke our prior commitment
// state.
//
// A map is returned of all the htlc resolutions that were locked in this
// commitment. Keys correspond to htlc indices and values indicate whether the
// htlc was settled or failed.
func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment,
unsignedAckedUpdates []LogUpdate) (map[uint64]bool, error) {
c.Lock()
defer c.Unlock()
// If this is a restored channel, then we want to avoid mutating the
// state at all, as it's impossible to do so in a protocol compliant
// manner.
if c.hasChanStatus(ChanStatusRestored) {
return nil, ErrNoRestoredChannelMutation
}
var finalHtlcs = make(map[uint64]bool)
err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
// If the channel is marked as borked, then for safety reasons,
// we shouldn't attempt any further updates.
isBorked, err := c.isBorked(chanBucket)
if err != nil {
return err
}
if isBorked {
return ErrChanBorked
}
if err = putChanInfo(chanBucket, c); err != nil {
return fmt.Errorf("unable to store chan info: %w", err)
}
// With the proper bucket fetched, we'll now write the latest
// commitment state to disk for the target party.
err = putChanCommitment(
chanBucket, newCommitment, true,
)
if err != nil {
return fmt.Errorf("unable to store chan "+
"revocations: %v", err)
}
// Persist unsigned but acked remote updates that need to be
// restored after a restart.
var b bytes.Buffer
err = serializeLogUpdates(&b, unsignedAckedUpdates)
if err != nil {
return err
}
err = chanBucket.Put(unsignedAckedUpdatesKey, b.Bytes())
if err != nil {
return fmt.Errorf("unable to store dangline remote "+
"updates: %v", err)
}
// Since we have just sent the counterparty a revocation, store true
// under lastWasRevokeKey.
var b2 bytes.Buffer
if err := WriteElements(&b2, true); err != nil {
return err
}
if err := chanBucket.Put(lastWasRevokeKey, b2.Bytes()); err != nil {
return err
}
// Persist the remote unsigned local updates that are not included
// in our new commitment.
updateBytes := chanBucket.Get(remoteUnsignedLocalUpdatesKey)
if updateBytes == nil {
return nil
}
r := bytes.NewReader(updateBytes)
updates, err := deserializeLogUpdates(r)
if err != nil {
return err
}
// Get the bucket where settled htlcs are recorded if the user
// opted in to storing this information.
var finalHtlcsBucket kvdb.RwBucket
if c.Db.parent.storeFinalHtlcResolutions {
bucket, err := fetchFinalHtlcsBucketRw(
tx, c.ShortChannelID,
)
if err != nil {
return err
}
finalHtlcsBucket = bucket
}
var unsignedUpdates []LogUpdate
for _, upd := range updates {
// Gather updates that are not on our local commitment.
if upd.LogIndex >= newCommitment.LocalLogIndex {
unsignedUpdates = append(unsignedUpdates, upd)
continue
}
// The update was locked in. If the update was a
// resolution, then store it in the database.
err := processFinalHtlc(
finalHtlcsBucket, upd, finalHtlcs,
)
if err != nil {
return err
}
}
var b3 bytes.Buffer
err = serializeLogUpdates(&b3, unsignedUpdates)
if err != nil {
return fmt.Errorf("unable to serialize log updates: %w",
err)
}
err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b3.Bytes())
if err != nil {
return fmt.Errorf("unable to restore chanbucket: %w",
err)
}
return nil
}, func() {
finalHtlcs = make(map[uint64]bool)
})
if err != nil {
return nil, err
}
c.LocalCommitment = *newCommitment
return finalHtlcs, nil
}
// processFinalHtlc stores a final htlc outcome in the database if signaled via
// the supplied log update. An in-memory htlcs map is updated too.
func processFinalHtlc(finalHtlcsBucket walletdb.ReadWriteBucket, upd LogUpdate,
finalHtlcs map[uint64]bool) error {
var (
settled bool
id uint64
)
switch msg := upd.UpdateMsg.(type) {
case *lnwire.UpdateFulfillHTLC:
settled = true
id = msg.ID
case *lnwire.UpdateFailHTLC:
settled = false
id = msg.ID
case *lnwire.UpdateFailMalformedHTLC:
settled = false
id = msg.ID
default:
return nil
}
// Store the final resolution in the database if a bucket is provided.
if finalHtlcsBucket != nil {
err := putFinalHtlc(
finalHtlcsBucket, id,
FinalHtlcInfo{
Settled: settled,
Offchain: true,
},
)
if err != nil {
return err
}
}
finalHtlcs[id] = settled
return nil
}
// ActiveHtlcs returns a slice of HTLC's which are currently active on *both*
// commitment transactions.
func (c *OpenChannel) ActiveHtlcs() []HTLC {
c.RLock()
defer c.RUnlock()
// We'll only return HTLC's that are locked into *both* commitment
// transactions. So we'll iterate through their set of HTLC's to note
// which ones are present on their commitment.
remoteHtlcs := make(map[[32]byte]struct{})
for _, htlc := range c.RemoteCommitment.Htlcs {
log.Tracef("RemoteCommitment has htlc: id=%v, update=%v "+
"incoming=%v", htlc.HtlcIndex, htlc.LogIndex,
htlc.Incoming)
onionHash := sha256.Sum256(htlc.OnionBlob[:])
remoteHtlcs[onionHash] = struct{}{}
}
// Now that we know which HTLC's they have, we'll only mark the HTLC's
// as active if *we* know them as well.
activeHtlcs := make([]HTLC, 0, len(remoteHtlcs))
for _, htlc := range c.LocalCommitment.Htlcs {
log.Tracef("LocalCommitment has htlc: id=%v, update=%v "+
"incoming=%v", htlc.HtlcIndex, htlc.LogIndex,
htlc.Incoming)
onionHash := sha256.Sum256(htlc.OnionBlob[:])
if _, ok := remoteHtlcs[onionHash]; !ok {
log.Tracef("Skipped htlc due to onion mismatched: "+
"id=%v, update=%v incoming=%v",
htlc.HtlcIndex, htlc.LogIndex, htlc.Incoming)
continue
}
activeHtlcs = append(activeHtlcs, htlc)
}
return activeHtlcs
}
// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are
// contained within ChannelDeltas which encode the current state of the
// commitment between state updates.
//
// TODO(roasbeef): save space by using smaller ints at tail end?
type HTLC struct {
// TODO(yy): can embed an HTLCEntry here.
// Signature is the signature for the second level covenant transaction
// for this HTLC. The second level transaction is a timeout tx in the
// case that this is an outgoing HTLC, and a success tx in the case
// that this is an incoming HTLC.
//
// TODO(roasbeef): make [64]byte instead?
Signature []byte
// RHash is the payment hash of the HTLC.
RHash [32]byte
// Amt is the amount of milli-satoshis this HTLC escrows.
Amt lnwire.MilliSatoshi
// RefundTimeout is the absolute timeout on the HTLC that the sender
// must wait before reclaiming the funds in limbo.
RefundTimeout uint32
// OutputIndex is the output index for this particular HTLC output
// within the commitment transaction.
OutputIndex int32
// Incoming denotes whether we're the receiver or the sender of this
// HTLC.
Incoming bool
// OnionBlob is an opaque blob which is used to complete multi-hop
// routing.
OnionBlob [lnwire.OnionPacketSize]byte
// HtlcIndex is the HTLC counter index of this active, outstanding
// HTLC. This differs from the LogIndex, as the HtlcIndex is only
// incremented for each offered HTLC, while the LogIndex is
// incremented for each update (including settles and fails).
HtlcIndex uint64
// LogIndex is the cumulative log index of this HTLC. This differs
// from the HtlcIndex as this will be incremented for each new log
// update added.
LogIndex uint64
// ExtraData contains any additional information that was transmitted
// with the HTLC via TLVs. This data *must* already be encoded as a
// TLV stream, and may be empty. The length of this data is naturally
// limited by the space available to TLVs in update_add_htlc:
// = 65535 bytes (bolt 8 maximum message size):
// - 2 bytes (bolt 1 message_type)
// - 32 bytes (channel_id)
// - 8 bytes (id)
// - 8 bytes (amount_msat)
// - 32 bytes (payment_hash)
// - 4 bytes (cltv_expiry)
// - 1366 bytes (onion_routing_packet)
// = 64083 bytes maximum possible TLV stream
//
// Note that this extra data is stored inline with the OnionBlob for
// legacy reasons, see serialization/deserialization functions for
// detail.
ExtraData lnwire.ExtraOpaqueData
// BlindingPoint is an optional blinding point included with the HTLC.
//
// Note: this field is not a part of on-disk representation of the
// HTLC. It is stored in the ExtraData field, which is used to store
// a TLV stream of additional information associated with the HTLC.
BlindingPoint lnwire.BlindingPointRecord
// CustomRecords is a set of custom TLV records that are associated with
// this HTLC. These records are used to store additional information
// about the HTLC that is not part of the standard HTLC fields. This
// field is encoded within the ExtraData field.
CustomRecords lnwire.CustomRecords
}
// serializeExtraData encodes a TLV stream of extra data to be stored with a
// HTLC. It uses the update_add_htlc TLV types, because this is where extra
// data is passed with a HTLC. At present, blinding points and custom records
// are the extra data that we will store, and the function is a no-op if
// neither is provided.
//
// This function MUST be called to persist all HTLC values when they are
// serialized.
func (h *HTLC) serializeExtraData() error {
var records []tlv.RecordProducer
h.BlindingPoint.WhenSome(func(b tlv.RecordT[lnwire.BlindingPointTlvType,
*btcec.PublicKey]) {
records = append(records, &b)
})
records, err := h.CustomRecords.ExtendRecordProducers(records)
if err != nil {
return err
}
return h.ExtraData.PackRecords(records...)
}
// deserializeExtraData extracts TLVs from the extra data persisted for the
// htlc and populates values in the struct accordingly.
//
// This function MUST be called to populate the struct properly when HTLCs
// are deserialized.
func (h *HTLC) deserializeExtraData() error {
if len(h.ExtraData) == 0 {
return nil
}
blindingPoint := h.BlindingPoint.Zero()
tlvMap, err := h.ExtraData.ExtractRecords(&blindingPoint)
if err != nil {
return err
}
if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil {
h.BlindingPoint = tlv.SomeRecordT(blindingPoint)
// Remove the entry from the TLV map. Anything left in the map
// will be included in the custom records field.
delete(tlvMap, h.BlindingPoint.TlvType())
}
// Set the custom records field to the remaining TLV records.
customRecords, err := lnwire.NewCustomRecords(tlvMap)
if err != nil {
return err
}
h.CustomRecords = customRecords
return nil
}
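// extraDataRoundTripSketch is a minimal illustrative sketch (not part of the
// original API): the two helpers above are intended to be used as a pair,
// with serializeExtraData called before every write to disk and
// deserializeExtraData after every read. Round-tripping an HTLC through them
// restores BlindingPoint and CustomRecords from the packed ExtraData stream.
func extraDataRoundTripSketch(h *HTLC) error {
// Pack BlindingPoint/CustomRecords into h.ExtraData.
if err := h.serializeExtraData(); err != nil {
return err
}
// Unpack them again from h.ExtraData.
return h.deserializeExtraData()
}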
// SerializeHtlcs writes out the passed set of HTLC's into the passed writer
// using the current default on-disk serialization format.
//
// This inline serialization has been extended to allow storage of extra data
// associated with a HTLC in the following way:
// - The known-length onion blob (1366 bytes) is serialized as var bytes in
// WriteElements (ie, the length 1366 was written, followed by the 1366
// onion bytes).
// - To include extra data, we append any extra data present to this one
// variable length of data. Since we know that the onion is strictly 1366
// bytes, any length after that should be considered to be extra data.
//
// NOTE: This API is NOT stable, the on-disk format will likely change in the
// future.
func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error {
numHtlcs := uint16(len(htlcs))
if err := WriteElement(b, numHtlcs); err != nil {
return err
}
for _, htlc := range htlcs {
// Populate TLV stream for any additional fields contained
// in the TLV.
if err := htlc.serializeExtraData(); err != nil {
return err
}
// The onion blob and htlc extra data are stored as a single
// var bytes blob.
onionAndExtraData := make(
[]byte, lnwire.OnionPacketSize+len(htlc.ExtraData),
)
copy(onionAndExtraData, htlc.OnionBlob[:])
copy(onionAndExtraData[lnwire.OnionPacketSize:], htlc.ExtraData)
if err := WriteElements(b,
htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout,
htlc.OutputIndex, htlc.Incoming, onionAndExtraData,
htlc.HtlcIndex, htlc.LogIndex,
); err != nil {
return err
}
}
return nil
}
// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed
// io.Reader. The bytes within the passed reader MUST have been previously
// written to using the SerializeHtlcs function.
//
// This inline deserialization has been extended to allow storage of extra data
// associated with a HTLC in the following way:
// - The known-length onion blob (1366 bytes) and any additional data present
// are read out as a single blob of variable byte data.
// - They are stored like this to take advantage of the variable space
// available for extension without migration (see SerializeHtlcs).
// - The first 1366 bytes are interpreted as the onion blob, and any remaining
// bytes as extra HTLC data.
// - This extra HTLC data is expected to be serialized as a TLV stream, and
// its parsing is left to higher layers.
//
// NOTE: This API is NOT stable, the on-disk format will likely change in the
// future.
func DeserializeHtlcs(r io.Reader) ([]HTLC, error) {
var numHtlcs uint16
if err := ReadElement(r, &numHtlcs); err != nil {
return nil, err
}
var htlcs []HTLC
if numHtlcs == 0 {
return htlcs, nil
}
htlcs = make([]HTLC, numHtlcs)
for i := uint16(0); i < numHtlcs; i++ {
var onionAndExtraData []byte
if err := ReadElements(r,
&htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt,
&htlcs[i].RefundTimeout, &htlcs[i].OutputIndex,
&htlcs[i].Incoming, &onionAndExtraData,
&htlcs[i].HtlcIndex, &htlcs[i].LogIndex,
); err != nil {
return htlcs, err
}
// Sanity check that we have at least the onion blob size we
// expect.
if len(onionAndExtraData) < lnwire.OnionPacketSize {
return nil, ErrOnionBlobLength
}
// First OnionPacketSize bytes are our fixed length onion
// packet.
copy(
htlcs[i].OnionBlob[:],
onionAndExtraData[0:lnwire.OnionPacketSize],
)
// Any additional bytes belong to extra data. The computed
// extraDataLen will be >= 0, because we verified above that we
// have at least a full fixed-length onion packet.
extraDataLen := len(onionAndExtraData) - lnwire.OnionPacketSize
if extraDataLen > 0 {
htlcs[i].ExtraData = make([]byte, extraDataLen)
copy(
htlcs[i].ExtraData,
onionAndExtraData[lnwire.OnionPacketSize:],
)
}
// Finally, deserialize any TLVs contained in that extra data
// if they are present.
if err := htlcs[i].deserializeExtraData(); err != nil {
return nil, err
}
}
return htlcs, nil
}
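// htlcsRoundTripSketch is a minimal illustrative sketch (not part of the
// original API) showing that SerializeHtlcs and DeserializeHtlcs are
// symmetric: writing a set of HTLCs to a buffer and reading them back yields
// an equivalent slice, with the onion blob and any extra TLV data preserved.
func htlcsRoundTripSketch(htlcs []HTLC) ([]HTLC, error) {
var buf bytes.Buffer
if err := SerializeHtlcs(&buf, htlcs...); err != nil {
return nil, err
}
return DeserializeHtlcs(&buf)
}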
// Copy returns a full copy of the target HTLC.
func (h *HTLC) Copy() HTLC {
clone := HTLC{
Incoming: h.Incoming,
Amt: h.Amt,
RefundTimeout: h.RefundTimeout,
OutputIndex: h.OutputIndex,
HtlcIndex: h.HtlcIndex,
LogIndex: h.LogIndex,
}
// Signature and ExtraData are slices, so they must be allocated before
// copying, otherwise copying into a nil destination is a no-op.
clone.Signature = make([]byte, len(h.Signature))
copy(clone.Signature, h.Signature)
copy(clone.RHash[:], h.RHash[:])
clone.OnionBlob = h.OnionBlob
clone.ExtraData = make(lnwire.ExtraOpaqueData, len(h.ExtraData))
copy(clone.ExtraData, h.ExtraData)
clone.BlindingPoint = h.BlindingPoint
clone.CustomRecords = h.CustomRecords.Copy()
return clone
}
// LogUpdate represents a pending update to the remote commitment chain. The
// log update may be an add, fail, or settle entry. We maintain this data in
// order to be able to properly retransmit our proposed state if necessary.
type LogUpdate struct {
// LogIndex is the log index of this proposed commitment update entry.
LogIndex uint64
// UpdateMsg is the update message that was included within our
// local update log. The LogIndex value denotes the log index of this
// update which will be used when restoring our local update log if
// we're left with a dangling update on restart.
UpdateMsg lnwire.Message
}
// serializeLogUpdate writes a log update to the provided io.Writer.
func serializeLogUpdate(w io.Writer, l *LogUpdate) error {
return WriteElements(w, l.LogIndex, l.UpdateMsg)
}
// deserializeLogUpdate reads a log update from the provided io.Reader.
func deserializeLogUpdate(r io.Reader) (*LogUpdate, error) {
l := &LogUpdate{}
if err := ReadElements(r, &l.LogIndex, &l.UpdateMsg); err != nil {
return nil, err
}
return l, nil
}
// CommitDiff represents the delta needed to apply the state transition between
// two subsequent commitment states. Given state N and state N+1, one is able
// to apply the set of messages contained within the CommitDiff to N to arrive
// at state N+1. Each time a new commitment is extended, we'll write a new
// commitment (along with the full commitment state) to disk so we can
// re-transmit the state in the case of a connection loss or message drop.
type CommitDiff struct {
// ChannelCommitment is the full commitment state that one would arrive
// at by applying the set of messages contained in the UpdateDiff to
// the prior accepted commitment.
Commitment ChannelCommitment
// LogUpdates is the set of messages sent prior to the commitment state
// transition in question. Upon reconnection, if we detect that they
// don't have the commitment, then we re-send this along with the
// proper signature.
LogUpdates []LogUpdate
// CommitSig is the exact CommitSig message that should be sent after
// the set of LogUpdates above has been retransmitted. The signatures
// within this message should properly cover the new commitment state
// and also the HTLC's within the new commitment state.
CommitSig *lnwire.CommitSig
// OpenedCircuitKeys is a set of unique identifiers for any downstream
// Add packets included in this commitment txn. After a restart, this
// set of htlcs is acked from the link's incoming mailbox to ensure
// there isn't an attempt to re-add them to this commitment txn.
OpenedCircuitKeys []models.CircuitKey
// ClosedCircuitKeys records the unique identifiers for any settle/fail
// packets that were resolved by this commitment txn. After a restart,
// this is used to ensure those circuits are removed from the circuit
// map, and the downstream packets in the link's mailbox are removed.
ClosedCircuitKeys []models.CircuitKey
// AddAcks specifies the locations (commit height, pkg index) of any
// Adds that were failed/settled in this commit diff. This will ack
// entries in *this* channel's forwarding packages.
//
// NOTE: This value is not serialized, it is used to atomically mark the
// resolution of adds, such that they will not be reprocessed after a
// restart.
AddAcks []AddRef
// SettleFailAcks specifies the locations (chan id, commit height, pkg
// index) of any Settles or Fails that were locked into this commit
// diff, and originate from *another* channel, i.e. the outgoing link.
//
// NOTE: This value is not serialized, it is used to atomically ack
// settles and fails from the forwarding packages of other channels,
// such that they will not be reforwarded internally after a restart.
SettleFailAcks []SettleFailRef
}
// serializeLogUpdates serializes provided list of updates to a stream.
func serializeLogUpdates(w io.Writer, logUpdates []LogUpdate) error {
numUpdates := uint16(len(logUpdates))
if err := binary.Write(w, byteOrder, numUpdates); err != nil {
return err
}
for _, diff := range logUpdates {
err := WriteElements(w, diff.LogIndex, diff.UpdateMsg)
if err != nil {
return err
}
}
return nil
}
// deserializeLogUpdates deserializes a list of updates from a stream.
func deserializeLogUpdates(r io.Reader) ([]LogUpdate, error) {
var numUpdates uint16
if err := binary.Read(r, byteOrder, &numUpdates); err != nil {
return nil, err
}
logUpdates := make([]LogUpdate, numUpdates)
for i := 0; i < int(numUpdates); i++ {
err := ReadElements(r,
&logUpdates[i].LogIndex, &logUpdates[i].UpdateMsg,
)
if err != nil {
return nil, err
}
}
return logUpdates, nil
}
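// For reference, the on-disk encoding shared by serializeLogUpdates and
// deserializeLogUpdates above is simply a big-endian uint16 count followed by
// the individual updates (a sketch of the layout, not additional API):
//
//	[2 bytes: numUpdates]
//	[update_0: LogIndex || UpdateMsg]
//	...
//	[update_{numUpdates-1}: LogIndex || UpdateMsg]

// serializeCommitDiff writes the passed CommitDiff to the provided io.Writer,
// appending the commitment's aux TLV data at the end of the stream.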
func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl
if err := serializeChanCommit(w, &diff.Commitment); err != nil {
return err
}
if err := WriteElements(w, diff.CommitSig); err != nil {
return err
}
if err := serializeLogUpdates(w, diff.LogUpdates); err != nil {
return err
}
numOpenRefs := uint16(len(diff.OpenedCircuitKeys))
if err := binary.Write(w, byteOrder, numOpenRefs); err != nil {
return err
}
for _, openRef := range diff.OpenedCircuitKeys {
err := WriteElements(w, openRef.ChanID, openRef.HtlcID)
if err != nil {
return err
}
}
numClosedRefs := uint16(len(diff.ClosedCircuitKeys))
if err := binary.Write(w, byteOrder, numClosedRefs); err != nil {
return err
}
for _, closedRef := range diff.ClosedCircuitKeys {
err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID)
if err != nil {
return err
}
}
// We'll also encode the commit aux data stream here. We do this here
// rather than above (at the call to serializeChanCommit), to ensure
// backwards compat for reads to existing non-custom channels.
auxData := diff.Commitment.extractTlvData()
if err := auxData.encode(w); err != nil {
return fmt.Errorf("unable to write aux data: %w", err)
}
return nil
}
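// deserializeCommitDiff reads a CommitDiff from the provided io.Reader,
// mirroring the layout written by serializeCommitDiff.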
func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
var (
d CommitDiff
err error
)
d.Commitment, err = deserializeChanCommit(r)
if err != nil {
return nil, err
}
var msg lnwire.Message
if err := ReadElements(r, &msg); err != nil {
return nil, err
}
commitSig, ok := msg.(*lnwire.CommitSig)
if !ok {
return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+
"read: %T", msg)
}
d.CommitSig = commitSig
d.LogUpdates, err = deserializeLogUpdates(r)
if err != nil {
return nil, err
}
var numOpenRefs uint16
if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil {
return nil, err
}
d.OpenedCircuitKeys = make([]models.CircuitKey, numOpenRefs)
for i := 0; i < int(numOpenRefs); i++ {
err := ReadElements(r,
&d.OpenedCircuitKeys[i].ChanID,
&d.OpenedCircuitKeys[i].HtlcID)
if err != nil {
return nil, err
}
}
var numClosedRefs uint16
if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil {
return nil, err
}
d.ClosedCircuitKeys = make([]models.CircuitKey, numClosedRefs)
for i := 0; i < int(numClosedRefs); i++ {
err := ReadElements(r,
&d.ClosedCircuitKeys[i].ChanID,
&d.ClosedCircuitKeys[i].HtlcID)
if err != nil {
return nil, err
}
}
// As a final step, we'll read out any aux commit data that we have at
// the end of this byte stream. We do this here to ensure backward
// compatibility, as otherwise we risk erroneously reading into the
// wrong field.
var auxData commitTlvData
if err := auxData.decode(r); err != nil {
return nil, fmt.Errorf("unable to decode aux data: %w", err)
}
d.Commitment.amendTlvData(auxData)
return &d, nil
}
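// For reference, serializeCommitDiff and deserializeCommitDiff agree on the
// following stream layout, with the aux commit TLV data deliberately placed
// last for backwards compatibility (a sketch of the layout, not additional
// API):
//
//	ChannelCommitment || CommitSig || LogUpdates ||
//	numOpenRefs || openRefs... || numClosedRefs || closedRefs... ||
//	aux commit TLV data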
// AppendRemoteCommitChain appends a new CommitDiff to the end of the
// commitment chain for the remote party. This method is to be used once we
// have prepared a new commitment state for the remote party, but before we
// transmit it to the remote party. The contents of the argument should be
// sufficient to retransmit the updates and signature needed to reconstruct the
// state in full, in the case that we need to retransmit.
func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error {
c.Lock()
defer c.Unlock()
// If this is a restored channel, then we want to avoid mutating the
// state at all, as it's impossible to do so in a protocol compliant
// manner.
if c.hasChanStatus(ChanStatusRestored) {
return ErrNoRestoredChannelMutation
}
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
// First, we'll grab the writable bucket where this channel's
// data resides.
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
// If the channel is marked as borked, then for safety reasons,
// we shouldn't attempt any further updates.
isBorked, err := c.isBorked(chanBucket)
if err != nil {
return err
}
if isBorked {
return ErrChanBorked
}
// Any outgoing settles and fails necessarily have a
// corresponding adds in this channel's forwarding packages.
// Mark all of these as being fully processed in our forwarding
// package, which prevents us from reprocessing them after
// startup.
err = c.Packager.AckAddHtlcs(tx, diff.AddAcks...)
if err != nil {
return err
}
// Additionally, we ack from any fails or settles that are
// persisted in another channel's forwarding package. This
// prevents the same fails and settles from being retransmitted
// after restarts. The actual fail or settle we need to
// propagate to the remote party is now in the commit diff.
err = c.Packager.AckSettleFails(tx, diff.SettleFailAcks...)
if err != nil {
return err
}
// We are sending a commitment signature, so lastWasRevokeKey
// should store false.
var b bytes.Buffer
if err := WriteElements(&b, false); err != nil {
return err
}
if err := chanBucket.Put(lastWasRevokeKey, b.Bytes()); err != nil {
return err
}
// TODO(roasbeef): use seqno to derive key for later LCP
// With the bucket retrieved, we'll now serialize the commit
// diff itself, and write it to disk.
var b2 bytes.Buffer
if err := serializeCommitDiff(&b2, diff); err != nil {
return err
}
return chanBucket.Put(commitDiffKey, b2.Bytes())
}, func() {})
}
// RemoteCommitChainTip returns the "tip" of the current remote commitment
// chain. This value will be non-nil iff we've created a new commitment for
// the remote party that they haven't yet ACK'd. In this case, their commitment
// chain will have a length of two: their current unrevoked commitment, and
// this new pending commitment. Once they revoke their prior state, we'll swap
// these pointers, causing the tip and the tail to point to the same entry.
func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
var cd *CommitDiff
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
switch err {
case nil:
case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound:
return ErrNoPendingCommit
default:
return err
}
tipBytes := chanBucket.Get(commitDiffKey)
if tipBytes == nil {
return ErrNoPendingCommit
}
tipReader := bytes.NewReader(tipBytes)
dcd, err := deserializeCommitDiff(tipReader)
if err != nil {
return err
}
cd = dcd
return nil
}, func() {
cd = nil
})
if err != nil {
return nil, err
}
return cd, nil
}
// UnsignedAckedUpdates retrieves the persisted unsigned acked remote log
// updates that still need to be signed for.
func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) {
var updates []LogUpdate
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
switch err {
case nil:
case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound:
return nil
default:
return err
}
updateBytes := chanBucket.Get(unsignedAckedUpdatesKey)
if updateBytes == nil {
return nil
}
r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r)
return err
}, func() {
updates = nil
})
if err != nil {
return nil, err
}
return updates, nil
}
// RemoteUnsignedLocalUpdates retrieves the persisted, unsigned local log
// updates that the remote still needs to sign for.
func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) {
var updates []LogUpdate
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
switch err {
case nil:
case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound:
return nil
default:
return err
}
updateBytes := chanBucket.Get(remoteUnsignedLocalUpdatesKey)
if updateBytes == nil {
return nil
}
r := bytes.NewReader(updateBytes)
updates, err = deserializeLogUpdates(r)
return err
}, func() {
updates = nil
})
if err != nil {
return nil, err
}
return updates, nil
}
// InsertNextRevocation inserts the _next_ commitment point (revocation) into
// the database, and also modifies the internal RemoteNextRevocation attribute
// to point to the passed key. This method is to be used during final channel
// set up, _after_ the channel has been fully confirmed.
//
// NOTE: If this method isn't called, then the target channel won't be able to
// propose new states for the commitment state of the remote party.
func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error {
c.Lock()
defer c.Unlock()
c.RemoteNextRevocation = revKey
err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
return putChanRevocationState(chanBucket, c)
}, func() {})
if err != nil {
return err
}
return nil
}
// AdvanceCommitChainTail records the new state transition within an on-disk
// append-only log which records all state transitions by the remote peer. In
// the case of an uncooperative broadcast of a prior state by the remote peer,
// this log can be consulted in order to reconstruct the state needed to
// rectify the situation. This method will add the current commitment for the
// remote party to the revocation log, and promote the current pending
// commitment to the current remote commitment. The updates parameter is the
// set of local updates that the peer still needs to send us a signature for.
// We store this set of updates in case we go down.
func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg,
updates []LogUpdate, ourOutputIndex, theirOutputIndex uint32) error {
c.Lock()
defer c.Unlock()
// If this is a restored channel, then we want to avoid mutating the
// state at all, as it's impossible to do so in a protocol compliant
// manner.
if c.hasChanStatus(ChanStatusRestored) {
return ErrNoRestoredChannelMutation
}
var newRemoteCommit *ChannelCommitment
err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
// If the channel is marked as borked, then for safety reasons,
// we shouldn't attempt any further updates.
isBorked, err := c.isBorked(chanBucket)
if err != nil {
return err
}
if isBorked {
return ErrChanBorked
}
// Persist the latest preimage state to disk as the remote peer
// has just added to our local preimage store, and given us a
// new pending revocation key.
if err := putChanRevocationState(chanBucket, c); err != nil {
return err
}
// With the current preimage producer/store state updated,
// append a new log entry recording the delta of this state
// transition.
//
// TODO(roasbeef): could make the deltas relative, would save
// space, but then tradeoff for more disk-seeks to recover the
// full state.
logKey := revocationLogBucket
logBucket, err := chanBucket.CreateBucketIfNotExists(logKey)
if err != nil {
return err
}
// Before we append this revoked state to the revocation log,
// we'll swap out what's currently the tail of the commit tip,
// with the current locked-in commitment for the remote party.
tipBytes := chanBucket.Get(commitDiffKey)
tipReader := bytes.NewReader(tipBytes)
newCommit, err := deserializeCommitDiff(tipReader)
if err != nil {
return err
}
err = putChanCommitment(
chanBucket, &newCommit.Commitment, false,
)
if err != nil {
return err
}
if err := chanBucket.Delete(commitDiffKey); err != nil {
return err
}
// With the commitment pointer swapped, we can now add the
// revoked (prior) state to the revocation log.
err = putRevocationLog(
logBucket, &c.RemoteCommitment, ourOutputIndex,
theirOutputIndex, c.Db.parent.noRevLogAmtData,
)
if err != nil {
return err
}
// Lastly, we write the forwarding package to disk so that we
// can properly recover from failures and reforward HTLCs that
// have not received a corresponding settle/fail.
if err := c.Packager.AddFwdPkg(tx, fwdPkg); err != nil {
return err
}
// Persist the unsigned acked updates that are not included
// in their new commitment.
updateBytes := chanBucket.Get(unsignedAckedUpdatesKey)
if updateBytes == nil {
// This shouldn't normally happen as we always store
// the number of updates, but could still be
// encountered by nodes that are upgrading.
newRemoteCommit = &newCommit.Commitment
return nil
}
r := bytes.NewReader(updateBytes)
unsignedUpdates, err := deserializeLogUpdates(r)
if err != nil {
return err
}
var validUpdates []LogUpdate
for _, upd := range unsignedUpdates {
lIdx := upd.LogIndex
// Filter for updates that are not on the remote
// commitment.
if lIdx >= newCommit.Commitment.RemoteLogIndex {
validUpdates = append(validUpdates, upd)
}
}
var b bytes.Buffer
err = serializeLogUpdates(&b, validUpdates)
if err != nil {
return fmt.Errorf("unable to serialize log updates: %w",
err)
}
err = chanBucket.Put(unsignedAckedUpdatesKey, b.Bytes())
if err != nil {
return fmt.Errorf("unable to store under "+
"unsignedAckedUpdatesKey: %w", err)
}
// Persist the local updates the peer hasn't yet signed so they
// can be restored after restart.
var b2 bytes.Buffer
err = serializeLogUpdates(&b2, updates)
if err != nil {
return err
}
err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b2.Bytes())
if err != nil {
return fmt.Errorf("unable to restore remote unsigned "+
"local updates: %v", err)
}
newRemoteCommit = &newCommit.Commitment
return nil
}, func() {
newRemoteCommit = nil
})
if err != nil {
return err
}
// With the db transaction complete, we'll swap over the in-memory
// pointer of the new remote commitment, which was previously the tip
// of the commit chain.
c.RemoteCommitment = *newRemoteCommit
return nil
}
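// Taken together, AppendRemoteCommitChain, RemoteCommitChainTip and
// AdvanceCommitChainTail model the two-element remote commitment chain. A
// minimal sketch of the intended call order (illustrative only; error
// handling elided):
//
//	// 1. Persist a new remote commitment before transmitting it.
//	err := c.AppendRemoteCommitChain(diff)
//
//	// 2. After a restart, the un-ACK'd tip can be re-read in order to
//	//    retransmit the updates and signature.
//	tip, err := c.RemoteCommitChainTip()
//
//	// 3. Once the remote party revokes their prior state, the tip is
//	//    promoted to the current remote commitment and the revoked state
//	//    is appended to the revocation log.
//	err = c.AdvanceCommitChainTail(fwdPkg, updates, ourIdx, theirIdx)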
// FinalHtlcInfo contains information about the final outcome of an htlc.
type FinalHtlcInfo struct {
// Settled is true if the htlc was settled. If false, the htlc was
// failed.
Settled bool
// Offchain indicates whether the htlc was resolved off-chain or
// on-chain.
Offchain bool
}
// putFinalHtlc writes the final htlc outcome to the database. Additionally it
// records whether the htlc was resolved off-chain or on-chain.
func putFinalHtlc(finalHtlcsBucket kvdb.RwBucket, id uint64,
info FinalHtlcInfo) error {
var key [8]byte
byteOrder.PutUint64(key[:], id)
var finalHtlcByte FinalHtlcByte
if info.Settled {
finalHtlcByte |= FinalHtlcSettledBit
}
if info.Offchain {
finalHtlcByte |= FinalHtlcOffchainBit
}
return finalHtlcsBucket.Put(key[:], []byte{byte(finalHtlcByte)})
}
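// Note that the single byte written above packs both outcome bits: a settled,
// off-chain htlc is stored with FinalHtlcSettledBit and FinalHtlcOffchainBit
// set, while a failed on-chain htlc is stored with neither bit set.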
// NextLocalHtlcIndex returns the next unallocated local htlc index. To ensure
// this always returns the next index that has not yet been allocated, this
// will first examine any pending commitments, before falling back to the
// last locked-in remote commitment.
func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) {
// First, load the most recent commit diff that we initiated for the
// remote party. If no pending commit is found, this is not treated as
// a critical error, since we can always fall back.
pendingRemoteCommit, err := c.RemoteCommitChainTip()
if err != nil && err != ErrNoPendingCommit {
return 0, err
}
// If a pending commit was found, its local htlc index will be at least
// as large as the one on our local commitment.
if pendingRemoteCommit != nil {
return pendingRemoteCommit.Commitment.LocalHtlcIndex, nil
}
// Otherwise, fall back to using the local htlc index of their commitment.
return c.RemoteCommitment.LocalHtlcIndex, nil
}
// LoadFwdPkgs scans the forwarding log for any packages that haven't been
// processed, and returns their deserialized log updates in map indexed by the
// remote commitment height at which the updates were locked in.
func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
c.RLock()
defer c.RUnlock()
var fwdPkgs []*FwdPkg
if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
var err error
fwdPkgs, err = c.Packager.LoadFwdPkgs(tx)
return err
}, func() {
fwdPkgs = nil
}); err != nil {
return nil, err
}
return fwdPkgs, nil
}
// AckAddHtlcs updates the AckAddFilter with any of the provided AddRefs,
// indicating that a response to each such Add has been committed to the
// remote party. Doing so will prevent these Add HTLCs from being reforwarded
// internally.
func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error {
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return c.Packager.AckAddHtlcs(tx, addRefs...)
}, func() {})
}
// AckSettleFails updates the SettleFailFilter containing any of the provided
// SettleFailRefs, indicating that the response has been delivered to the
// incoming link, corresponding to a particular AddRef. Doing so will prevent
// the responses from being retransmitted internally.
func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error {
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return c.Packager.AckSettleFails(tx, settleFailRefs...)
}, func() {})
}
// SetFwdFilter atomically sets the forwarding filter for the forwarding package
// identified by `height`.
func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error {
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return c.Packager.SetFwdFilter(tx, height, fwdFilter)
}, func() {})
}
// RemoveFwdPkgs atomically removes forwarding packages specified by the remote
// commitment heights. If one of the intermediate RemovePkg calls fails, then the
// later packages won't be removed.
//
// NOTE: This method should only be called on packages marked FwdStateCompleted.
func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error {
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
for _, height := range heights {
err := c.Packager.RemovePkg(tx, height)
if err != nil {
return err
}
}
return nil
}, func() {})
}
// revocationLogTailCommitHeight returns the commit height at the end of the
// revocation log. This entry represents the last previous state for the remote
// node's commitment chain. The height returned by this method will always lag
// one state behind the most current (unrevoked) state of the remote node's
// commitment chain.
//
// NOTE: used in unit tests only.
func (c *OpenChannel) revocationLogTailCommitHeight() (uint64, error) {
c.RLock()
defer c.RUnlock()
var height uint64
// If we haven't created any state updates yet, then we'll exit early as
// there's nothing to be found on disk in the revocation bucket.
if c.RemoteCommitment.CommitHeight == 0 {
return height, nil
}
if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
logBucket, err := fetchLogBucket(chanBucket)
if err != nil {
return err
}
// Once we have the bucket that stores the revocation log from
// this channel, we'll jump to the _last_ key in bucket. Since
// the key is the commit height, we'll decode the bytes and
// return it.
cursor := logBucket.ReadCursor()
rawHeight, _ := cursor.Last()
height = byteOrder.Uint64(rawHeight)
return nil
}, func() {}); err != nil {
return height, err
}
return height, nil
}
// CommitmentHeight returns the current commitment height. The commitment
// height represents the number of updates to the commitment state to date.
// This value is always monotonically increasing. This method is provided in
// order to allow multiple instances of a particular open channel to obtain a
// consistent view of the number of channel updates to date.
func (c *OpenChannel) CommitmentHeight() (uint64, error) {
c.RLock()
defer c.RUnlock()
var height uint64
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
// Get the bucket dedicated to storing the metadata for open
// channels.
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
commit, err := fetchChanCommitment(chanBucket, true)
if err != nil {
return err
}
height = commit.CommitHeight
return nil
}, func() {
height = 0
})
if err != nil {
return 0, err
}
return height, nil
}
// FindPreviousState scans through the append-only log in an attempt to recover
// the previous channel state indicated by the update number. This method is
// intended to be used for obtaining the relevant data needed to claim all
// funds rightfully spendable in the case of an on-chain broadcast of the
// commitment transaction.
func (c *OpenChannel) FindPreviousState(
updateNum uint64) (*RevocationLog, *ChannelCommitment, error) {
c.RLock()
defer c.RUnlock()
commit := &ChannelCommitment{}
rl := &RevocationLog{}
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
// Find the revocation log from both the new and the old
// bucket.
r, c, err := fetchRevocationLogCompatible(chanBucket, updateNum)
if err != nil {
return err
}
rl = r
commit = c
return nil
}, func() {})
if err != nil {
return nil, nil, err
}
// Either `rl` or `commit` is nil here, depending on which log format
// the entry was found in. We return both as-is and leave it to the
// caller to decide how to proceed.
return rl, commit, nil
}
// ClosureType is an enum-like structure that details exactly _how_ a channel
// was closed. Six closure types are currently possible: cooperative, local
// force close, remote force close, (remote) breach, funding canceled, and
// abandoned.
type ClosureType uint8
const (
// CooperativeClose indicates that a channel has been closed
// cooperatively. This means that both channel peers were online and
// signed a new transaction paying out the settled balance of the
// contract.
CooperativeClose ClosureType = 0
// LocalForceClose indicates that we have unilaterally broadcast our
// current commitment state on-chain.
LocalForceClose ClosureType = 1
// RemoteForceClose indicates that the remote peer has unilaterally
// broadcast their current commitment state on-chain.
RemoteForceClose ClosureType = 4
// BreachClose indicates that the remote peer attempted to broadcast a
// prior _revoked_ channel state.
BreachClose ClosureType = 2
// FundingCanceled indicates that the channel never was fully opened
// before it was marked as closed in the database. This can happen if
// we or the remote fail at some point during the opening workflow, or
// we timeout waiting for the funding transaction to be confirmed.
FundingCanceled ClosureType = 3
// Abandoned indicates that the channel state was removed without
// any further actions. This is intended to clean up unusable
// channels during development.
Abandoned ClosureType = 5
)
// ChannelCloseSummary contains the final state of a channel at the point it
// was closed. Once a channel is closed, all the information pertaining to that
// channel within the openChannelBucket is deleted, and a compact summary is
// put in place instead.
type ChannelCloseSummary struct {
// ChanPoint is the outpoint for this channel's funding transaction,
// and is used as a unique identifier for the channel.
ChanPoint wire.OutPoint
// ShortChanID encodes the exact location in the chain in which the
// channel was initially confirmed. This includes: the block height,
// transaction index, and the output within the target transaction.
ShortChanID lnwire.ShortChannelID
// ChainHash is the hash of the genesis block that this channel resides
// within.
ChainHash chainhash.Hash
// ClosingTXID is the txid of the transaction which ultimately closed
// this channel.
ClosingTXID chainhash.Hash
// RemotePub is the public key of the remote peer that we formerly had
// a channel with.
RemotePub *btcec.PublicKey
// Capacity was the total capacity of the channel.
Capacity btcutil.Amount
// CloseHeight is the height at which the funding transaction was
// spent.
CloseHeight uint32
// SettledBalance is our total settled balance at the time of
// channel closure. This _does not_ include the sum of any outputs that
// have been time-locked as a result of the unilateral channel closure.
SettledBalance btcutil.Amount
// TimeLockedBalance is the sum of all the time-locked outputs at the
// time of channel closure. If we triggered the force closure of this
// channel, then this value will be non-zero if our settled output is
// above the dust limit. If we were on the receiving side of a channel
// force closure, then this value will be non-zero if we had any
// outstanding outgoing HTLC's at the time of channel closure.
TimeLockedBalance btcutil.Amount
// CloseType details exactly _how_ the channel was closed. Six closure
// types are possible: cooperative, local force, remote force, breach,
// funding canceled, and abandoned.
CloseType ClosureType
// IsPending indicates whether this channel is in the 'pending close'
// state, which means the channel closing transaction has been
// confirmed, but not yet fully resolved. In the case of a channel
// that has been cooperatively closed, it will go straight into the
// fully resolved state as soon as the closing transaction has been
// confirmed. Channels that have been force closed, however, will stay
// marked as "pending" until _all_ the pending funds have been swept.
IsPending bool
// RemoteCurrentRevocation is the current revocation for their
// commitment transaction. However, since this is the derived public key,
// we don't yet have the private key so we aren't yet able to verify
// that it's actually in the hash chain.
RemoteCurrentRevocation *btcec.PublicKey
// RemoteNextRevocation is the revocation key to be used for the *next*
// commitment transaction we create for the local node. Within the
// specification, this value is referred to as the
// per-commitment-point.
RemoteNextRevocation *btcec.PublicKey
// LocalChanConfig is the channel configuration for the local node.
LocalChanConfig ChannelConfig
// LastChanSyncMsg is the ChannelReestablish message for this channel
// for the state at the point where it was closed.
LastChanSyncMsg *lnwire.ChannelReestablish
}
// CloseChannel closes a previously active Lightning channel. Closing a channel
// entails deleting all saved state within the database concerning this
// channel. This method also takes a struct that summarizes the state of the
// channel at closing, this compact representation will be the only component
// of a channel left over after a full closing. It takes an optional set of
// channel statuses which will be written to the historical channel bucket.
// These statuses are used to record close initiators.
func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary,
statuses ...ChannelStatus) error {
c.Lock()
defer c.Unlock()
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
if openChanBucket == nil {
return ErrNoChanDBExists
}
nodePub := c.IdentityPub.SerializeCompressed()
nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub)
if nodeChanBucket == nil {
return ErrNoActiveChannels
}
chainBucket := nodeChanBucket.NestedReadWriteBucket(c.ChainHash[:])
if chainBucket == nil {
return ErrNoActiveChannels
}
var chanPointBuf bytes.Buffer
err := graphdb.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint)
if err != nil {
return err
}
chanKey := chanPointBuf.Bytes()
chanBucket := chainBucket.NestedReadWriteBucket(
chanKey,
)
if chanBucket == nil {
return ErrNoActiveChannels
}
// Before we delete the channel state, we'll read out the full
// details, as we'll also store portions of this information
// for record keeping.
chanState, err := fetchOpenChannel(
chanBucket, &c.FundingOutpoint,
)
if err != nil {
return err
}
// Delete all the forwarding packages stored for this particular
// channel.
if err = chanState.Packager.Wipe(tx); err != nil {
return err
}
// Now that the index to this channel has been deleted, purge
// the remaining channel metadata from the database.
err = deleteOpenChannel(chanBucket)
if err != nil {
return err
}
// We'll also remove the channel from the frozen channel bucket
// if we need to.
if c.ChanType.IsFrozen() || c.ChanType.HasLeaseExpiration() {
err := deleteThawHeight(chanBucket)
if err != nil {
return err
}
}
// With the base channel data deleted, attempt to delete the
// information stored within the revocation log.
if err := deleteLogBucket(chanBucket); err != nil {
return err
}
err = chainBucket.DeleteNestedBucket(chanPointBuf.Bytes())
if err != nil {
return err
}
// Fetch the outpoint bucket to see if the outpoint exists or
// not.
opBucket := tx.ReadWriteBucket(outpointBucket)
if opBucket == nil {
return ErrNoChanDBExists
}
// Add the closed outpoint to our outpoint index. This should
// replace an open outpoint in the index.
if opBucket.Get(chanPointBuf.Bytes()) == nil {
return ErrMissingIndexEntry
}
status := uint8(outpointClosed)
// Write the IndexStatus of this outpoint as the first entry in a tlv
// stream.
statusRecord := tlv.MakePrimitiveRecord(indexStatusType, &status)
opStream, err := tlv.NewStream(statusRecord)
if err != nil {
return err
}
var b bytes.Buffer
if err := opStream.Encode(&b); err != nil {
return err
}
// Finally add the closed outpoint and tlv stream to the index.
if err := opBucket.Put(chanPointBuf.Bytes(), b.Bytes()); err != nil {
return err
}
// Add channel state to the historical channel bucket.
historicalBucket, err := tx.CreateTopLevelBucket(
historicalChannelBucket,
)
if err != nil {
return err
}
historicalChanBucket, err :=
historicalBucket.CreateBucketIfNotExists(chanKey)
if err != nil {
return err
}
// Apply any additional statuses to the channel state.
for _, status := range statuses {
chanState.chanStatus |= status
}
err = putOpenChannel(historicalChanBucket, chanState)
if err != nil {
return err
}
// Finally, create a summary of this channel in the closed
// channel bucket for this node.
return putChannelCloseSummary(
tx, chanPointBuf.Bytes(), summary, chanState,
)
}, func() {})
}
// ChannelSnapshot is a frozen snapshot of the current channel state. A
// snapshot is detached from the original channel that generated it, providing
// read-only access to the current or prior state of an active channel.
//
// TODO(roasbeef): remove all together? pretty much just commitment
type ChannelSnapshot struct {
// RemoteIdentity is the identity public key of the remote node that we
// are maintaining the open channel with.
RemoteIdentity btcec.PublicKey
// ChannelPoint is the outpoint that created the channel. This output
// is found within the funding transaction and uniquely identifies the
// channel on the resident chain.
ChannelPoint wire.OutPoint
// ChainHash is the genesis hash of the chain that the channel resides
// within.
ChainHash chainhash.Hash
// Capacity is the total capacity of the channel.
Capacity btcutil.Amount
// TotalMSatSent is the total number of milli-satoshis we've sent
// within this channel.
TotalMSatSent lnwire.MilliSatoshi
// TotalMSatReceived is the total number of milli-satoshis we've
// received within this channel.
TotalMSatReceived lnwire.MilliSatoshi
// ChannelCommitment is the current up-to-date commitment for the
// target channel.
ChannelCommitment
}
// Snapshot returns a read-only snapshot of the current channel state. This
// snapshot includes information concerning the current settled balance within
// the channel, metadata detailing total flows, and any outstanding HTLCs.
func (c *OpenChannel) Snapshot() *ChannelSnapshot {
c.RLock()
defer c.RUnlock()
localCommit := c.LocalCommitment
snapshot := &ChannelSnapshot{
RemoteIdentity: *c.IdentityPub,
ChannelPoint: c.FundingOutpoint,
Capacity: c.Capacity,
TotalMSatSent: c.TotalMSatSent,
TotalMSatReceived: c.TotalMSatReceived,
ChainHash: c.ChainHash,
ChannelCommitment: ChannelCommitment{
LocalBalance: localCommit.LocalBalance,
RemoteBalance: localCommit.RemoteBalance,
CommitHeight: localCommit.CommitHeight,
CommitFee: localCommit.CommitFee,
},
}
localCommit.CustomBlob.WhenSome(func(blob tlv.Blob) {
blobCopy := make([]byte, len(blob))
copy(blobCopy, blob)
snapshot.ChannelCommitment.CustomBlob = fn.Some(blobCopy)
})
// Copy over the current set of HTLCs to ensure the caller can't mutate
// our internal state.
snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs))
for i, h := range localCommit.Htlcs {
snapshot.Htlcs[i] = h.Copy()
}
return snapshot
}
// LatestCommitments returns the two latest commitments for both the local and
// remote party. These commitments are read from disk to ensure that only the
// latest fully committed state is returned. The first commitment returned is
// the local commitment, and the second returned is the remote commitment.
func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
return fetchChanCommitments(chanBucket, c)
}, func() {})
if err != nil {
return nil, nil, err
}
return &c.LocalCommitment, &c.RemoteCommitment, nil
}
// RemoteRevocationStore returns the most up to date commitment version of the
// revocation storage tree for the remote party. This method can be used when
// acting on a possible contract breach to ensure that the caller has the most
// up to date information required to deliver justice.
func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) {
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
return fetchChanRevocationState(chanBucket, c)
}, func() {})
if err != nil {
return nil, err
}
return c.RevocationStore, nil
}
// AbsoluteThawHeight determines a frozen channel's absolute thaw height. If the
// channel is not frozen, then 0 is returned.
func (c *OpenChannel) AbsoluteThawHeight() (uint32, error) {
// Only frozen channels have a thaw height.
if !c.ChanType.IsFrozen() && !c.ChanType.HasLeaseExpiration() {
return 0, nil
}
// If the channel has the frozen bit set and its thaw height is below
// the absolute threshold, then it's interpreted as a height relative
// to the channel's confirmation height.
if c.ChanType.IsFrozen() && c.ThawHeight < AbsoluteThawHeightThreshold {
// We'll only know the channel's short ID once it's
// confirmed.
if c.IsPending {
return 0, errors.New("cannot use relative thaw " +
"height for unconfirmed channel")
}
// For non-zero-conf channels, this is the base height to use.
blockHeightBase := c.ShortChannelID.BlockHeight
// If this is a zero-conf channel, the ShortChannelID will be
// an alias.
if c.IsZeroConf() {
if !c.ZeroConfConfirmed() {
return 0, errors.New("cannot use relative " +
"height for unconfirmed zero-conf " +
"channel")
}
// Use the confirmed SCID's BlockHeight.
blockHeightBase = c.confirmedScid.BlockHeight
}
return blockHeightBase + c.ThawHeight, nil
}
return c.ThawHeight, nil
}
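// As a worked example (illustrative numbers): a frozen channel confirmed at
// block height 700_000 with a ThawHeight of 144 below the absolute threshold
// thaws at height 700_000 + 144 = 700_144, whereas any ThawHeight at or above
// AbsoluteThawHeightThreshold is returned as the absolute thaw height itself.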
// DeriveHeightHint derives a block height hint for the channel opening.
func (c *OpenChannel) DeriveHeightHint() uint32 {
// As a height hint, we'll try to use the opening height, but if the
// channel isn't yet open, then we'll use the height it was broadcast
// at. This may be an unconfirmed zero-conf channel.
heightHint := c.ShortChanID().BlockHeight
if heightHint == 0 {
heightHint = c.BroadcastHeight()
}
// Since no zero-conf state is stored in a channel backup, the below
// logic will not be triggered for restored, zero-conf channels. Set
// the height hint for zero-conf channels.
if c.IsZeroConf() {
if c.ZeroConfConfirmed() {
// If the zero-conf channel is confirmed, we'll use the
// confirmed SCID's block height.
heightHint = c.ZeroConfRealScid().BlockHeight
} else {
// The zero-conf channel is unconfirmed. We'll need to
// use the FundingBroadcastHeight.
heightHint = c.BroadcastHeight()
}
}
return heightHint
}
func putChannelCloseSummary(tx kvdb.RwTx, chanID []byte,
summary *ChannelCloseSummary, lastChanState *OpenChannel) error {
closedChanBucket, err := tx.CreateTopLevelBucket(closedChannelBucket)
if err != nil {
return err
}
summary.RemoteCurrentRevocation = lastChanState.RemoteCurrentRevocation
summary.RemoteNextRevocation = lastChanState.RemoteNextRevocation
summary.LocalChanConfig = lastChanState.LocalChanCfg
var b bytes.Buffer
if err := serializeChannelCloseSummary(&b, summary); err != nil {
return err
}
return closedChanBucket.Put(chanID, b.Bytes())
}
func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
err := WriteElements(w,
cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID,
cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance,
cs.TimeLockedBalance, cs.CloseType, cs.IsPending,
)
if err != nil {
return err
}
// If this is a close channel summary created before the addition of
// the new fields, then we can exit here.
if cs.RemoteCurrentRevocation == nil {
return WriteElements(w, false)
}
// If fields are present, write boolean to indicate this, and continue.
if err := WriteElements(w, true); err != nil {
return err
}
if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil {
return err
}
if err := writeChanConfig(w, &cs.LocalChanConfig); err != nil {
return err
}
// The RemoteNextRevocation field is optional, as it's possible for a
// channel to be closed before we learn of the next unrevoked
// revocation point for the remote party. Write a boolean indicating
// whether this field is present or not.
if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil {
return err
}
// Write the field, if present.
if cs.RemoteNextRevocation != nil {
if err = WriteElements(w, cs.RemoteNextRevocation); err != nil {
return err
}
}
// Write whether the channel sync message is present.
if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil {
return err
}
// Write the channel sync message, if present.
if cs.LastChanSyncMsg != nil {
if err := WriteElements(w, cs.LastChanSyncMsg); err != nil {
return err
}
}
return nil
}
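// For reference, the close summary encoding above is versioned via presence
// booleans rather than an explicit version byte (a sketch of the layout, not
// additional API):
//
//	<base fields> || hasNewFields(bool) ||
//	[ RemoteCurrentRevocation || LocalChanConfig ||
//	  hasRemoteNextRevocation(bool) || [RemoteNextRevocation] ||
//	  hasChanSyncMsg(bool) || [LastChanSyncMsg] ]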
func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
c := &ChannelCloseSummary{}
err := ReadElements(r,
&c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID,
&c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance,
&c.TimeLockedBalance, &c.CloseType, &c.IsPending,
)
if err != nil {
return nil, err
}
// We'll now check to see if the channel close summary was encoded with
// any of the additional optional fields.
var hasNewFields bool
err = ReadElements(r, &hasNewFields)
if err != nil {
return nil, err
}
// If fields are not present, we can return.
if !hasNewFields {
return c, nil
}
// Otherwise read the new fields.
if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil {
return nil, err
}
if err := readChanConfig(r, &c.LocalChanConfig); err != nil {
return nil, err
}
// Finally, we'll attempt to read the next unrevoked commitment point
// for the remote party. If we closed the channel before receiving a
// channel_ready message then this might not be present. A boolean
// indicating whether the field is present will come first.
var hasRemoteNextRevocation bool
err = ReadElements(r, &hasRemoteNextRevocation)
if err != nil {
return nil, err
}
// If this field was written, read it.
if hasRemoteNextRevocation {
err = ReadElements(r, &c.RemoteNextRevocation)
if err != nil {
return nil, err
}
}
// Check if we have a channel sync message to read.
var hasChanSyncMsg bool
err = ReadElements(r, &hasChanSyncMsg)
if err == io.EOF {
return c, nil
} else if err != nil {
return nil, err
}
// If a chan sync message is present, read it.
if hasChanSyncMsg {
// We must pass in reference to a lnwire.Message for the codec
// to support it.
var msg lnwire.Message
if err := ReadElements(r, &msg); err != nil {
return nil, err
}
chanSync, ok := msg.(*lnwire.ChannelReestablish)
if !ok {
return nil, errors.New("unable cast db Message to " +
"ChannelReestablish")
}
c.LastChanSyncMsg = chanSync
}
return c, nil
}
func writeChanConfig(b io.Writer, c *ChannelConfig) error {
return WriteElements(b,
c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC,
c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey,
c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint,
c.HtlcBasePoint,
)
}
// fundingTxPresent returns true if we expect the funding transaction to be
// found
// on disk or already populated within the passed open channel struct.
func fundingTxPresent(channel *OpenChannel) bool {
chanType := channel.ChanType
return chanType.IsSingleFunder() && chanType.HasFundingTx() &&
channel.IsInitiator &&
!channel.hasChanStatus(ChanStatusRestored)
}
func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
var w bytes.Buffer
if err := WriteElements(&w,
channel.ChanType, channel.ChainHash, channel.FundingOutpoint,
channel.ShortChannelID, channel.IsPending, channel.IsInitiator,
channel.chanStatus, channel.FundingBroadcastHeight,
channel.NumConfsRequired, channel.ChannelFlags,
channel.IdentityPub, channel.Capacity, channel.TotalMSatSent,
channel.TotalMSatReceived,
); err != nil {
return err
}
// For single funder channels that we initiated and for which we
// have the funding transaction, write the funding txn.
if fundingTxPresent(channel) {
if err := WriteElement(&w, channel.FundingTxn); err != nil {
return err
}
}
if err := writeChanConfig(&w, &channel.LocalChanCfg); err != nil {
return err
}
if err := writeChanConfig(&w, &channel.RemoteChanCfg); err != nil {
return err
}
auxData := channel.extractTlvData()
if err := auxData.encode(&w); err != nil {
return fmt.Errorf("unable to encode aux data: %w", err)
}
if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil {
return err
}
// Finally, add optional shutdown scripts for the local and remote peer if
// they are present.
if err := putOptionalUpfrontShutdownScript(
chanBucket, localUpfrontShutdownKey, channel.LocalShutdownScript,
); err != nil {
return err
}
return putOptionalUpfrontShutdownScript(
chanBucket, remoteUpfrontShutdownKey, channel.RemoteShutdownScript,
)
}
// putOptionalUpfrontShutdownScript adds a shutdown script under the key
// provided if it has a non-zero length.
func putOptionalUpfrontShutdownScript(chanBucket kvdb.RwBucket, key []byte,
script []byte) error {
// If the script is empty, we do not need to add anything.
if len(script) == 0 {
return nil
}
var w bytes.Buffer
if err := WriteElement(&w, script); err != nil {
return err
}
return chanBucket.Put(key, w.Bytes())
}
// getOptionalUpfrontShutdownScript reads the shutdown script stored under the
// key provided if it is present. Upfront shutdown scripts are optional, so the
// function returns with no error if the key is not present.
func getOptionalUpfrontShutdownScript(chanBucket kvdb.RBucket, key []byte,
script *lnwire.DeliveryAddress) error {
// Return early if the key does not exist, as a shutdown script was not set.
bs := chanBucket.Get(key)
if bs == nil {
return nil
}
var tempScript []byte
r := bytes.NewReader(bs)
if err := ReadElement(r, &tempScript); err != nil {
return err
}
*script = tempScript
return nil
}
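// serializeChanCommit writes the static fields of the passed
// ChannelCommitment to the provided io.Writer, followed by its full set of
// HTLCs.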
func serializeChanCommit(w io.Writer, c *ChannelCommitment) error {
if err := WriteElements(w,
c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex,
c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance,
c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx,
c.CommitSig,
); err != nil {
return err
}
return SerializeHtlcs(w, c.Htlcs...)
}
func putChanCommitment(chanBucket kvdb.RwBucket, c *ChannelCommitment,
local bool) error {
var commitKey []byte
if local {
commitKey = append(chanCommitmentKey, byte(0x00))
} else {
commitKey = append(chanCommitmentKey, byte(0x01))
}
var b bytes.Buffer
if err := serializeChanCommit(&b, c); err != nil {
return err
}
// Before we write to disk, we'll also write our aux data as well.
auxData := c.extractTlvData()
if err := auxData.encode(&b); err != nil {
return fmt.Errorf("unable to write aux data: %w", err)
}
return chanBucket.Put(commitKey, b.Bytes())
}
func putChanCommitments(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
// If this is a restored channel, then we don't have any commitments to
// write.
if channel.hasChanStatus(ChanStatusRestored) {
return nil
}
err := putChanCommitment(
chanBucket, &channel.LocalCommitment, true,
)
if err != nil {
return err
}
return putChanCommitment(
chanBucket, &channel.RemoteCommitment, false,
)
}
func putChanRevocationState(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
var b bytes.Buffer
err := WriteElements(
&b, channel.RemoteCurrentRevocation, channel.RevocationProducer,
channel.RevocationStore,
)
if err != nil {
return err
}
// If the next revocation is present, which is only the case after the
// ChannelReady message has been sent, then we'll write it to disk.
if channel.RemoteNextRevocation != nil {
err = WriteElements(&b, channel.RemoteNextRevocation)
if err != nil {
return err
}
}
return chanBucket.Put(revocationStateKey, b.Bytes())
}
func readChanConfig(b io.Reader, c *ChannelConfig) error {
return ReadElements(b,
&c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve,
&c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay,
&c.MultiSigKey, &c.RevocationBasePoint,
&c.PaymentBasePoint, &c.DelayBasePoint,
&c.HtlcBasePoint,
)
}
func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error {
infoBytes := chanBucket.Get(chanInfoKey)
if infoBytes == nil {
return ErrNoChanInfoFound
}
r := bytes.NewReader(infoBytes)
if err := ReadElements(r,
&channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint,
&channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator,
&channel.chanStatus, &channel.FundingBroadcastHeight,
&channel.NumConfsRequired, &channel.ChannelFlags,
&channel.IdentityPub, &channel.Capacity, &channel.TotalMSatSent,
&channel.TotalMSatReceived,
); err != nil {
return err
}
// For single funder channels that we initiated, and for which we have
// the funding transaction, read the funding txn.
if fundingTxPresent(channel) {
if err := ReadElement(r, &channel.FundingTxn); err != nil {
return err
}
}
if err := readChanConfig(r, &channel.LocalChanCfg); err != nil {
return err
}
if err := readChanConfig(r, &channel.RemoteChanCfg); err != nil {
return err
}
// Retrieve the boolean stored under lastWasRevokeKey.
lastWasRevokeBytes := chanBucket.Get(lastWasRevokeKey)
if lastWasRevokeBytes == nil {
// If nothing has been stored under this key, we store false in the
// OpenChannel struct.
channel.LastWasRevoke = false
} else {
// Otherwise, read the value into the LastWasRevoke field.
revokeReader := bytes.NewReader(lastWasRevokeBytes)
err := ReadElements(revokeReader, &channel.LastWasRevoke)
if err != nil {
return err
}
}
var auxData openChannelTlvData
if err := auxData.decode(r); err != nil {
return fmt.Errorf("unable to decode aux data: %w", err)
}
// Assign all the relevant fields from the aux data into the actual
// open channel.
channel.amendTlvData(auxData)
channel.Packager = NewChannelPackager(channel.ShortChannelID)
// Finally, read the optional shutdown scripts.
if err := getOptionalUpfrontShutdownScript(
chanBucket, localUpfrontShutdownKey, &channel.LocalShutdownScript,
); err != nil {
return err
}
return getOptionalUpfrontShutdownScript(
chanBucket, remoteUpfrontShutdownKey, &channel.RemoteShutdownScript,
)
}
func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) {
var c ChannelCommitment
err := ReadElements(r,
&c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex,
&c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance,
&c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig,
)
if err != nil {
return c, err
}
c.Htlcs, err = DeserializeHtlcs(r)
if err != nil {
return c, err
}
return c, nil
}
func fetchChanCommitment(chanBucket kvdb.RBucket,
local bool) (ChannelCommitment, error) {
var commitKey []byte
if local {
commitKey = append(chanCommitmentKey, byte(0x00))
} else {
commitKey = append(chanCommitmentKey, byte(0x01))
}
commitBytes := chanBucket.Get(commitKey)
if commitBytes == nil {
return ChannelCommitment{}, ErrNoCommitmentsFound
}
r := bytes.NewReader(commitBytes)
chanCommit, err := deserializeChanCommit(r)
if err != nil {
return ChannelCommitment{}, fmt.Errorf("unable to decode "+
"chan commit: %w", err)
}
// We'll also check to see if we have any aux data stored at the end of
// the stream.
var auxData commitTlvData
if err := auxData.decode(r); err != nil {
return ChannelCommitment{}, fmt.Errorf("unable to decode "+
"chan aux data: %w", err)
}
chanCommit.amendTlvData(auxData)
return chanCommit, nil
}
func fetchChanCommitments(chanBucket kvdb.RBucket, channel *OpenChannel) error {
var err error
// If this is a restored channel, then we don't have any commitments to
// read.
if channel.hasChanStatus(ChanStatusRestored) {
return nil
}
channel.LocalCommitment, err = fetchChanCommitment(chanBucket, true)
if err != nil {
return err
}
channel.RemoteCommitment, err = fetchChanCommitment(chanBucket, false)
if err != nil {
return err
}
return nil
}
func fetchChanRevocationState(chanBucket kvdb.RBucket, channel *OpenChannel) error {
revBytes := chanBucket.Get(revocationStateKey)
if revBytes == nil {
return ErrNoRevocationsFound
}
r := bytes.NewReader(revBytes)
err := ReadElements(
r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer,
&channel.RevocationStore,
)
if err != nil {
return err
}
// If there aren't any bytes left in the buffer, then we don't yet have
// the next remote revocation, so we can exit early here.
if r.Len() == 0 {
return nil
}
// Otherwise we'll read the next revocation for the remote party which
// is always the last item within the buffer.
return ReadElements(r, &channel.RemoteNextRevocation)
}
func deleteOpenChannel(chanBucket kvdb.RwBucket) error {
if err := chanBucket.Delete(chanInfoKey); err != nil {
return err
}
err := chanBucket.Delete(append(chanCommitmentKey, byte(0x00)))
if err != nil {
return err
}
err = chanBucket.Delete(append(chanCommitmentKey, byte(0x01)))
if err != nil {
return err
}
if err := chanBucket.Delete(revocationStateKey); err != nil {
return err
}
if diff := chanBucket.Get(commitDiffKey); diff != nil {
return chanBucket.Delete(commitDiffKey)
}
return nil
}
// makeLogKey converts a uint64 into an 8-byte array.
func makeLogKey(updateNum uint64) [8]byte {
var key [8]byte
byteOrder.PutUint64(key[:], updateNum)
return key
}
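// exampleLogKeyOrdering is an illustrative sketch (not part of the original
// file): because makeLogKey emits big-endian bytes, lexicographic ordering of
// the resulting keys matches numeric ordering of the update numbers, which is
// what makes cursor scans over the log iterate in update order.
func exampleLogKeyOrdering() bool {
	first := makeLogKey(1)
	second := makeLogKey(256)

	// []byte{0, 0, 0, 0, 0, 0, 0, 1} sorts before
	// []byte{0, 0, 0, 0, 0, 0, 1, 0}.
	return bytes.Compare(first[:], second[:]) < 0
}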
func fetchThawHeight(chanBucket kvdb.RBucket) (uint32, error) {
var height uint32
heightBytes := chanBucket.Get(frozenChanKey)
heightReader := bytes.NewReader(heightBytes)
if err := ReadElements(heightReader, &height); err != nil {
return 0, err
}
return height, nil
}
func storeThawHeight(chanBucket kvdb.RwBucket, height uint32) error {
var heightBuf bytes.Buffer
if err := WriteElements(&heightBuf, height); err != nil {
return err
}
return chanBucket.Put(frozenChanKey, heightBuf.Bytes())
}
func deleteThawHeight(chanBucket kvdb.RwBucket) error {
return chanBucket.Delete(frozenChanKey)
}
// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the
// tlv.RecordProducer interface.
type keyLocRecord struct {
keychain.KeyLocator
}
// Record creates a Record out of a KeyLocator using the EKeyLocator and
// DKeyLocator functions. The size will always be 8 as both the KeyFamily
// and the Index are uint32 values.
//
// NOTE: This is part of the tlv.RecordProducer interface.
func (k *keyLocRecord) Record() tlv.Record {
// Note that we set the type here as zero, as when used with a
// tlv.RecordT, the type param will be used as the type.
return tlv.MakeStaticRecord(
0, &k.KeyLocator, 8, EKeyLocator, DKeyLocator,
)
}
// EKeyLocator is an encoder for keychain.KeyLocator.
func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*keychain.KeyLocator); ok {
err := tlv.EUint32T(w, uint32(v.Family), buf)
if err != nil {
return err
}
return tlv.EUint32T(w, v.Index, buf)
}
return tlv.NewTypeForEncodingErr(val, "keychain.KeyLocator")
}
// DKeyLocator is a decoder for keychain.KeyLocator.
func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*keychain.KeyLocator); ok {
var family uint32
err := tlv.DUint32(r, &family, buf, 4)
if err != nil {
return err
}
v.Family = keychain.KeyFamily(family)
return tlv.DUint32(r, &v.Index, buf, 4)
}
return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8)
}
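// exampleKeyLocatorCodec is an illustrative sketch (not part of the original
// file) showing that EKeyLocator and DKeyLocator are symmetric: the locator
// is written as two big-endian uint32 values (8 bytes total) and decoded
// back into an identical struct.
func exampleKeyLocatorCodec(loc keychain.KeyLocator) (keychain.KeyLocator,
	error) {

	var (
		b   bytes.Buffer
		buf [8]byte
	)
	if err := EKeyLocator(&b, &loc, &buf); err != nil {
		return keychain.KeyLocator{}, err
	}

	// The length hint of 8 matches the static record size used by
	// keyLocRecord.Record.
	var decoded keychain.KeyLocator
	err := DKeyLocator(&b, &decoded, &buf, 8)

	return decoded, err
}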
// ShutdownInfo contains various info about the shutdown initiation of a
// channel.
type ShutdownInfo struct {
// DeliveryScript is the address that we have included in any previous
// Shutdown message for a particular channel and so should include in
// any future re-sends of the Shutdown message.
DeliveryScript tlv.RecordT[tlv.TlvType0, lnwire.DeliveryAddress]
// LocalInitiator is true if we sent a Shutdown message before ever
// receiving a Shutdown message from the remote peer.
LocalInitiator tlv.RecordT[tlv.TlvType1, bool]
}
// NewShutdownInfo constructs a new ShutdownInfo object.
func NewShutdownInfo(deliveryScript lnwire.DeliveryAddress,
locallyInitiated bool) *ShutdownInfo {
return &ShutdownInfo{
DeliveryScript: tlv.NewRecordT[tlv.TlvType0](deliveryScript),
LocalInitiator: tlv.NewPrimitiveRecord[tlv.TlvType1](
locallyInitiated,
),
}
}
// Closer identifies the ChannelParty that initiated the coop-closure process.
func (s ShutdownInfo) Closer() lntypes.ChannelParty {
if s.LocalInitiator.Val {
return lntypes.Local
}
return lntypes.Remote
}
// encode serialises the ShutdownInfo to the given io.Writer.
func (s *ShutdownInfo) encode(w io.Writer) error {
records := []tlv.Record{
s.DeliveryScript.Record(),
s.LocalInitiator.Record(),
}
stream, err := tlv.NewStream(records...)
if err != nil {
return err
}
return stream.Encode(w)
}
// decodeShutdownInfo constructs a ShutdownInfo struct by decoding the given
// byte slice.
func decodeShutdownInfo(b []byte) (*ShutdownInfo, error) {
tlvStream := lnwire.ExtraOpaqueData(b)
var info ShutdownInfo
records := []tlv.RecordProducer{
&info.DeliveryScript,
&info.LocalInitiator,
}
_, err := tlvStream.ExtractRecords(records...)
return &info, err
}
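// exampleShutdownInfoRoundTrip is an illustrative sketch (not part of the
// original file) tying the pieces above together: a ShutdownInfo is built
// with NewShutdownInfo, encoded into a TLV stream, and decoded back.
func exampleShutdownInfoRoundTrip(
	deliveryScript lnwire.DeliveryAddress) (*ShutdownInfo, error) {

	// Construct an info record marking us as the closure initiator.
	info := NewShutdownInfo(deliveryScript, true)

	var b bytes.Buffer
	if err := info.encode(&b); err != nil {
		return nil, err
	}

	// The decoded copy should report lntypes.Local from Closer().
	return decodeShutdownInfo(b.Bytes())
}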
package channeldb
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
)
// UnknownElementType is an error returned when the codec is unable to encode or
// decode a particular type.
type UnknownElementType struct {
method string
element interface{}
}
// NewUnknownElementType creates a new UnknownElementType error from the passed
// method name and element.
func NewUnknownElementType(method string, el interface{}) UnknownElementType {
return UnknownElementType{method: method, element: el}
}
// Error returns the name of the method that encountered the error, as well as
// the type that was unsupported.
func (e UnknownElementType) Error() string {
return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element)
}
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for storage on disk. The passed
// io.Writer should be backed by an appropriately sized byte slice, or be able
// to dynamically expand to accommodate additional data.
func WriteElement(w io.Writer, element interface{}) error {
switch e := element.(type) {
case keychain.KeyDescriptor:
if err := binary.Write(w, byteOrder, e.Family); err != nil {
return err
}
if err := binary.Write(w, byteOrder, e.Index); err != nil {
return err
}
if e.PubKey != nil {
if err := binary.Write(w, byteOrder, true); err != nil {
return fmt.Errorf("error writing serialized "+
"element: %w", err)
}
return WriteElement(w, e.PubKey)
}
return binary.Write(w, byteOrder, false)
case ChannelType:
var buf [8]byte
if err := tlv.WriteVarInt(w, uint64(e), &buf); err != nil {
return err
}
case chainhash.Hash:
if _, err := w.Write(e[:]); err != nil {
return err
}
case wire.OutPoint:
return graphdb.WriteOutpoint(w, &e)
case lnwire.ShortChannelID:
if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil {
return err
}
case lnwire.ChannelID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case int64, uint64:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case int32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint16:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint8:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case bool:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case btcutil.Amount:
if err := binary.Write(w, byteOrder, uint64(e)); err != nil {
return err
}
case lnwire.MilliSatoshi:
if err := binary.Write(w, byteOrder, uint64(e)); err != nil {
return err
}
case *btcec.PrivateKey:
b := e.Serialize()
if _, err := w.Write(b); err != nil {
return err
}
case *btcec.PublicKey:
b := e.SerializeCompressed()
if _, err := w.Write(b); err != nil {
return err
}
case shachain.Producer:
return e.Encode(w)
case shachain.Store:
return e.Encode(w)
case *wire.MsgTx:
return e.Serialize(w)
case [32]byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case []byte:
if err := wire.WriteVarBytes(w, 0, e); err != nil {
return err
}
case lnwire.Message:
var msgBuf bytes.Buffer
if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil {
return err
}
msgLen := uint16(len(msgBuf.Bytes()))
if err := WriteElements(w, msgLen); err != nil {
return err
}
if _, err := w.Write(msgBuf.Bytes()); err != nil {
return err
}
case ChannelStatus:
var buf [8]byte
if err := tlv.WriteVarInt(w, uint64(e), &buf); err != nil {
return err
}
case ClosureType:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case paymentIndexType:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case lnwire.FundingFlag:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case net.Addr:
if err := graphdb.SerializeAddr(w, e); err != nil {
return err
}
case []net.Addr:
if err := WriteElement(w, uint32(len(e))); err != nil {
return err
}
for _, addr := range e {
if err := graphdb.SerializeAddr(w, addr); err != nil {
return err
}
}
default:
return UnknownElementType{"WriteElement", e}
}
return nil
}
// WriteElements writes each element in the elements slice to the passed
// io.Writer using WriteElement.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
err := WriteElement(w, element)
if err != nil {
return err
}
}
return nil
}
// ReadElement is a one-stop utility function to deserialize any data
// structure encoded using the serialization format of the database.
func ReadElement(r io.Reader, element interface{}) error {
switch e := element.(type) {
case *keychain.KeyDescriptor:
if err := binary.Read(r, byteOrder, &e.Family); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &e.Index); err != nil {
return err
}
var hasPubKey bool
if err := binary.Read(r, byteOrder, &hasPubKey); err != nil {
return err
}
if hasPubKey {
return ReadElement(r, &e.PubKey)
}
case *ChannelType:
var buf [8]byte
ctype, err := tlv.ReadVarInt(r, &buf)
if err != nil {
return err
}
*e = ChannelType(ctype)
case *chainhash.Hash:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *wire.OutPoint:
return graphdb.ReadOutpoint(r, e)
case *lnwire.ShortChannelID:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = lnwire.NewShortChanIDFromInt(a)
case *lnwire.ChannelID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *int64, *uint64:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *int32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint16:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint8:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *bool:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *btcutil.Amount:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = btcutil.Amount(a)
case *lnwire.MilliSatoshi:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = lnwire.MilliSatoshi(a)
case **btcec.PrivateKey:
var b [btcec.PrivKeyBytesLen]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
priv, _ := btcec.PrivKeyFromBytes(b[:])
*e = priv
case **btcec.PublicKey:
var b [btcec.PubKeyBytesLenCompressed]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
pubKey, err := btcec.ParsePubKey(b[:])
if err != nil {
return err
}
*e = pubKey
case *shachain.Producer:
var root [32]byte
if _, err := io.ReadFull(r, root[:]); err != nil {
return err
}
// TODO(roasbeef): remove
producer, err := shachain.NewRevocationProducerFromBytes(root[:])
if err != nil {
return err
}
*e = producer
case *shachain.Store:
store, err := shachain.NewRevocationStoreFromBytes(r)
if err != nil {
return err
}
*e = store
case **wire.MsgTx:
tx := wire.NewMsgTx(2)
if err := tx.Deserialize(r); err != nil {
return err
}
*e = tx
case *[32]byte:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *[]byte:
bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte")
if err != nil {
return err
}
*e = bytes
case *lnwire.Message:
var msgLen uint16
if err := ReadElement(r, &msgLen); err != nil {
return err
}
msgReader := io.LimitReader(r, int64(msgLen))
msg, err := lnwire.ReadMessage(msgReader, 0)
if err != nil {
return err
}
*e = msg
case *ChannelStatus:
var buf [8]byte
status, err := tlv.ReadVarInt(r, &buf)
if err != nil {
return err
}
*e = ChannelStatus(status)
case *ClosureType:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *paymentIndexType:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *lnwire.FundingFlag:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *net.Addr:
addr, err := graphdb.DeserializeAddr(r)
if err != nil {
return err
}
*e = addr
case *[]net.Addr:
var numAddrs uint32
if err := ReadElement(r, &numAddrs); err != nil {
return err
}
*e = make([]net.Addr, numAddrs)
for i := uint32(0); i < numAddrs; i++ {
addr, err := graphdb.DeserializeAddr(r)
if err != nil {
return err
}
(*e)[i] = addr
}
default:
return UnknownElementType{"ReadElement", e}
}
return nil
}
// ReadElements deserializes a variable number of elements from the passed
// io.Reader, with each element being deserialized according to the
// ReadElement function.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
err := ReadElement(r, element)
if err != nil {
return err
}
}
return nil
}
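// exampleElementsRoundTrip is an illustrative sketch (not part of the
// original file) of the symmetric codec above: values written with
// WriteElements must be read back with ReadElements in the same order, using
// pointers of the matching types.
func exampleElementsRoundTrip() error {
	var b bytes.Buffer

	// Write a uint32, a bool and a variable-length byte slice using the
	// big-endian database codec.
	err := WriteElements(&b, uint32(21), true, []byte{0xde, 0xad})
	if err != nil {
		return err
	}

	// Read the values back in the order they were written.
	var (
		num  uint32
		flag bool
		blob []byte
	)

	return ReadElements(&b, &num, &flag, &blob)
}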
package channeldb
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"os"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/go-errors/errors"
mig "github.com/lightningnetwork/lnd/channeldb/migration"
"github.com/lightningnetwork/lnd/channeldb/migration12"
"github.com/lightningnetwork/lnd/channeldb/migration13"
"github.com/lightningnetwork/lnd/channeldb/migration16"
"github.com/lightningnetwork/lnd/channeldb/migration20"
"github.com/lightningnetwork/lnd/channeldb/migration21"
"github.com/lightningnetwork/lnd/channeldb/migration23"
"github.com/lightningnetwork/lnd/channeldb/migration24"
"github.com/lightningnetwork/lnd/channeldb/migration25"
"github.com/lightningnetwork/lnd/channeldb/migration26"
"github.com/lightningnetwork/lnd/channeldb/migration27"
"github.com/lightningnetwork/lnd/channeldb/migration29"
"github.com/lightningnetwork/lnd/channeldb/migration30"
"github.com/lightningnetwork/lnd/channeldb/migration31"
"github.com/lightningnetwork/lnd/channeldb/migration32"
"github.com/lightningnetwork/lnd/channeldb/migration33"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/clock"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
)
const (
dbName = "channel.db"
)
var (
// ErrDryRunMigrationOK signals that a migration executed successfully,
// but we intentionally did not commit the result.
ErrDryRunMigrationOK = errors.New("dry run migration successful")
// ErrFinalHtlcsBucketNotFound signals that the top-level final htlcs
// bucket does not exist.
ErrFinalHtlcsBucketNotFound = errors.New("final htlcs bucket not " +
"found")
// ErrFinalChannelBucketNotFound signals that the channel bucket for
// final htlc outcomes does not exist.
ErrFinalChannelBucketNotFound = errors.New("final htlcs channel " +
"bucket not found")
)
// migration is a function which takes a prior outdated version of the database
// instances and mutates the key/bucket structure to arrive at a more
// up-to-date version of the database.
type migration func(tx kvdb.RwTx) error
// mandatoryVersion defines a db version that must be applied before lnd
// starts.
type mandatoryVersion struct {
number uint32
migration migration
}
// MigrationConfig is an interface that combines the config interfaces of
// all optional migrations.
type MigrationConfig interface {
migration30.MigrateRevLogConfig
}
// MigrationConfigImpl is a superset of all the various migration configs
// and an implementation of MigrationConfig.
type MigrationConfigImpl struct {
migration30.MigrateRevLogConfigImpl
}
// optionalMigration defines an optional migration function. When a migration
// is optional, it usually involves a large number of changes that might touch
// millions of keys. Due to OOM concerns, the update cannot be safely done
// within one db transaction. Thus, optional migrations must take the db
// backend and construct transactions as needed.
type optionalMigration func(db kvdb.Backend, cfg MigrationConfig) error
// optionalVersion defines a db version that can be optionally applied. When
// applying migrations, we must apply all the mandatory migrations first before
// attempting optional ones.
type optionalVersion struct {
name string
migration optionalMigration
}
var (
// dbVersions stores all mandatory versions of the database. If the current
// version of the database doesn't match the latest version, this list is
// used to retrieve all the migration functions that need to be applied to
// the current db.
dbVersions = []mandatoryVersion{
{
// The base DB version requires no migration.
number: 0,
migration: nil,
},
{
// The version of the database where two new indexes
// for the update time of node and channel updates were
// added.
number: 1,
migration: migration_01_to_11.MigrateNodeAndEdgeUpdateIndex,
},
{
// The DB version that added the invoice event time
// series.
number: 2,
migration: migration_01_to_11.MigrateInvoiceTimeSeries,
},
{
// The DB version that updated the embedded invoice in
// outgoing payments to match the new format.
number: 3,
migration: migration_01_to_11.MigrateInvoiceTimeSeriesOutgoingPayments,
},
{
// The version of the database where every channel
// always has two entries in the edges bucket. If
// a policy is unknown, this will be represented
// by a special byte sequence.
number: 4,
migration: migration_01_to_11.MigrateEdgePolicies,
},
{
// The DB version where we persist each attempt to send
// an HTLC to a payment hash, and track whether the
// payment is in-flight, succeeded, or failed.
number: 5,
migration: migration_01_to_11.PaymentStatusesMigration,
},
{
// The DB version that properly prunes stale entries
// from the edge update index.
number: 6,
migration: migration_01_to_11.MigratePruneEdgeUpdateIndex,
},
{
// The DB version that migrates the ChannelCloseSummary
// to a format where optional fields are indicated with
// boolean flags.
number: 7,
migration: migration_01_to_11.MigrateOptionalChannelCloseSummaryFields,
},
{
// The DB version that changes the gossiper's message
// store keys to account for the message's type and
// ShortChannelID.
number: 8,
migration: migration_01_to_11.MigrateGossipMessageStoreKeys,
},
{
// The DB version where the payments and payment
// statuses are moved to being stored in a combined
// bucket.
number: 9,
migration: migration_01_to_11.MigrateOutgoingPayments,
},
{
// The DB version where we started to store legacy
// payload information for all routes, as well as the
// optional TLV records.
number: 10,
migration: migration_01_to_11.MigrateRouteSerialization,
},
{
// Add invoice htlc and cltv delta fields.
number: 11,
migration: migration_01_to_11.MigrateInvoices,
},
{
// Migrate to TLV invoice bodies, add payment address
// and features, remove receipt.
number: 12,
migration: migration12.MigrateInvoiceTLV,
},
{
// Migrate to multi-path payments.
number: 13,
migration: migration13.MigrateMPP,
},
{
// Initialize payment address index and begin using it
// as the default index, falling back to payment hash
// index.
number: 14,
migration: mig.CreateTLB(payAddrIndexBucket),
},
{
// Initialize payment index bucket which will be used
// to index payments by sequence number. This index will
// be used to allow more efficient ListPayments queries.
number: 15,
migration: mig.CreateTLB(paymentsIndexBucket),
},
{
// Add our existing payments to the index bucket created
// in migration 15.
number: 16,
migration: migration16.MigrateSequenceIndex,
},
{
// Create a top level bucket which will store extra
// information about channel closes.
number: 17,
migration: mig.CreateTLB(closeSummaryBucket),
},
{
// Create a top level bucket which holds information
// about our peers.
number: 18,
migration: mig.CreateTLB(peersBucket),
},
{
// Create a top level bucket which holds outpoint
// information.
number: 19,
migration: mig.CreateTLB(outpointBucket),
},
{
// Migrate some data to the outpoint index.
number: 20,
migration: migration20.MigrateOutpointIndex,
},
{
// Migrate to length prefixed wire messages everywhere
// in the database.
number: 21,
migration: migration21.MigrateDatabaseWireMessages,
},
{
// Initialize set id index so that invoices can be
// queried by individual htlc sets.
number: 22,
migration: mig.CreateTLB(setIDIndexBucket),
},
{
number: 23,
migration: migration23.MigrateHtlcAttempts,
},
{
// Remove old forwarding packages of closed channels.
number: 24,
migration: migration24.MigrateFwdPkgCleanup,
},
{
// Save the initial local/remote balances in channel
// info.
number: 25,
migration: migration25.MigrateInitialBalances,
},
{
// Migrate the initial local/remote balance fields into
// tlv records.
number: 26,
migration: migration26.MigrateBalancesToTlvRecords,
},
{
// Patch the initial local/remote balance fields with
// empty values for historical channels.
number: 27,
migration: migration27.MigrateHistoricalBalances,
},
{
number: 28,
migration: mig.CreateTLB(chanIDBucket),
},
{
number: 29,
migration: migration29.MigrateChanID,
},
{
// Removes the "sweeper-last-tx" bucket. Although we
// do not have a mandatory version 30 we skip this
// version because its naming is already used for the
// first optional migration.
number: 31,
migration: migration31.DeleteLastPublishedTxTLB,
},
{
number: 32,
migration: migration32.MigrateMCRouteSerialisation,
},
{
number: 33,
migration: migration33.MigrateMCStoreNameSpacedResults,
},
}
// optionalVersions stores all optional migrations that are applied
// after dbVersions.
//
// NOTE: optional migrations must be fault-tolerant, and re-running them
// on already migrated data must be a noop, which means each migration
// must be able to determine its own state.
optionalVersions = []optionalVersion{
{
name: "prune revocation log",
migration: func(db kvdb.Backend,
cfg MigrationConfig) error {
return migration30.MigrateRevocationLog(db, cfg)
},
},
}
// Big endian is the preferred byte order, due to cursor scans over
// integer keys iterating in order.
byteOrder = binary.BigEndian
// channelOpeningStateBucket is the database bucket used to store the
// channelOpeningState for each channel that is currently in the process
// of being opened.
channelOpeningStateBucket = []byte("channelOpeningState")
)
// DB is the primary datastore for the lnd daemon. The database stores
// information related to nodes, routing data, open/closed channels, fee
// schedules, and reputation data.
type DB struct {
kvdb.Backend
// channelStateDB separates all DB operations on channel state.
channelStateDB *ChannelStateDB
dbPath string
clock clock.Clock
dryRun bool
keepFailedPaymentAttempts bool
storeFinalHtlcResolutions bool
// noRevLogAmtData if true, means that commitment transaction amount
// data should not be stored in the revocation log.
noRevLogAmtData bool
}
// OpenForTesting opens or creates a channeldb to be used for tests. Any
// necessary schema migrations will take place as needed.
func OpenForTesting(t testing.TB, dbPath string,
modifiers ...OptionModifier) *DB {
backend, err := kvdb.GetBoltBackend(&kvdb.BoltBackendConfig{
DBPath: dbPath,
DBFileName: dbName,
NoFreelistSync: true,
AutoCompact: false,
AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge,
DBTimeout: kvdb.DefaultDBTimeout,
})
require.NoError(t, err)
db, err := CreateWithBackend(backend, modifiers...)
require.NoError(t, err)
db.dbPath = dbPath
t.Cleanup(func() {
require.NoError(t, db.Close())
})
return db
}
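// exampleOpenForTesting is an illustrative sketch (not part of the original
// file) of typical OpenForTesting usage: t.TempDir isolates the database
// file per test and the registered cleanup closes the DB when the test ends.
func exampleOpenForTesting(t *testing.T) {
	dbPath := t.TempDir()
	db := OpenForTesting(t, dbPath)

	// The path the backend was created under is retained on the DB.
	require.Equal(t, dbPath, db.Path())
}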
// CreateWithBackend creates a channeldb instance using the passed
// kvdb.Backend. Any necessary schema migrations will take place as needed.
func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB,
error) {
opts := DefaultOptions()
for _, modifier := range modifiers {
modifier(&opts)
}
if !opts.NoMigration {
if err := initChannelDB(backend); err != nil {
return nil, err
}
}
chanDB := &DB{
Backend: backend,
channelStateDB: &ChannelStateDB{
linkNodeDB: &LinkNodeDB{
backend: backend,
},
backend: backend,
},
clock: opts.clock,
dryRun: opts.dryRun,
keepFailedPaymentAttempts: opts.keepFailedPaymentAttempts,
storeFinalHtlcResolutions: opts.storeFinalHtlcResolutions,
noRevLogAmtData: opts.NoRevLogAmtData,
}
// Set the parent pointer (only used in tests).
chanDB.channelStateDB.parent = chanDB
// Synchronize the version of database and apply migrations if needed.
if !opts.NoMigration {
if err := chanDB.syncVersions(dbVersions); err != nil {
backend.Close()
return nil, err
}
// Grab the optional migration config.
omc := opts.OptionalMiragtionConfig
if err := chanDB.applyOptionalVersions(omc); err != nil {
backend.Close()
return nil, err
}
}
return chanDB, nil
}
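// exampleCreateWithBackend is an illustrative sketch (not part of the
// original file) showing how a channeldb instance is created on top of an
// explicitly constructed bolt backend; the path argument is hypothetical.
func exampleCreateWithBackend(dbPath string) (*DB, error) {
	backend, err := kvdb.GetBoltBackend(&kvdb.BoltBackendConfig{
		DBPath:     dbPath,
		DBFileName: dbName,
		DBTimeout:  kvdb.DefaultDBTimeout,
	})
	if err != nil {
		return nil, err
	}

	// CreateWithBackend initializes the top-level buckets if needed and
	// runs any pending mandatory and optional migrations.
	return CreateWithBackend(backend)
}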
// Path returns the file path to the channel database.
func (d *DB) Path() string {
return d.dbPath
}
var dbTopLevelBuckets = [][]byte{
openChannelBucket,
closedChannelBucket,
forwardingLogBucket,
fwdPackagesKey,
invoiceBucket,
payAddrIndexBucket,
setIDIndexBucket,
paymentsIndexBucket,
peersBucket,
nodeInfoBucket,
metaBucket,
closeSummaryBucket,
outpointBucket,
chanIDBucket,
historicalChannelBucket,
}
// Wipe completely deletes all saved state within all used buckets within the
// database. The deletion is done in a single transaction, therefore this
// operation is fully atomic.
func (d *DB) Wipe() error {
err := kvdb.Update(d, func(tx kvdb.RwTx) error {
for _, tlb := range dbTopLevelBuckets {
err := tx.DeleteTopLevelBucket(tlb)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
}
return nil
}, func() {})
if err != nil {
return err
}
return initChannelDB(d.Backend)
}
// initChannelDB creates and initializes a fresh version of channeldb. In
// case the target path doesn't yet exist, it is created. Additionally, all
// required top-level buckets used within the database are created.
func initChannelDB(db kvdb.Backend) error {
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
// Check if DB was marked as inactive with a tomb stone.
if err := EnsureNoTombstone(tx); err != nil {
return err
}
meta := &Meta{}
// Check if DB is already initialized.
err := FetchMeta(meta, tx)
if err == nil {
return nil
}
for _, tlb := range dbTopLevelBuckets {
if _, err := tx.CreateTopLevelBucket(tlb); err != nil {
return err
}
}
meta.DbVersionNumber = getLatestDBVersion(dbVersions)
return putMeta(meta, tx)
}, func() {})
if err != nil {
return fmt.Errorf("unable to create new channeldb: %w", err)
}
return nil
}
// fileExists returns true if the file exists, and false otherwise.
func fileExists(path string) bool {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
// ChannelStateDB is a database that keeps track of all channel state.
type ChannelStateDB struct {
// linkNodeDB separates all DB operations on LinkNodes.
linkNodeDB *LinkNodeDB
// parent holds a pointer to the "main" channeldb.DB object. This is
// only used for testing and should never be used in production code.
// For testing use the ChannelStateDB.GetParentDB() function to retrieve
// this pointer.
parent *DB
// backend points to the actual backend holding the channel state
// database. This may be a real backend or a cache middleware.
backend kvdb.Backend
}
// GetParentDB returns the "main" channeldb.DB object that is the owner of this
// ChannelStateDB instance. Use this function only in tests where passing around
// pointers makes testing less readable. Never to be used in production code!
func (c *ChannelStateDB) GetParentDB() *DB {
return c.parent
}
// LinkNodeDB returns the current instance of the link node database.
func (c *ChannelStateDB) LinkNodeDB() *LinkNodeDB {
return c.linkNodeDB
}
// FetchOpenChannels starts a new database transaction and returns all stored
// currently active/open channels associated with the target nodeID. In the case
// that no active channels are known to have been created with this node, then a
// zero-length slice is returned.
func (c *ChannelStateDB) FetchOpenChannels(nodeID *btcec.PublicKey) (
[]*OpenChannel, error) {
var channels []*OpenChannel
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
var err error
channels, err = c.fetchOpenChannels(tx, nodeID)
return err
}, func() {
channels = nil
})
return channels, err
}
// fetchOpenChannels uses an existing database transaction and returns all
// stored currently active/open channels associated with the target nodeID. In
// the case that no active channels are known to have been created with this
// node, then a zero-length slice is returned.
func (c *ChannelStateDB) fetchOpenChannels(tx kvdb.RTx,
nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
// Get the bucket dedicated to storing the metadata for open channels.
openChanBucket := tx.ReadBucket(openChannelBucket)
if openChanBucket == nil {
return nil, nil
}
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
pub := nodeID.SerializeCompressed()
nodeChanBucket := openChanBucket.NestedReadBucket(pub)
if nodeChanBucket == nil {
return nil, nil
}
// Next, we'll need to go down an additional layer in order to retrieve
// the channels for each chain the node knows of.
var channels []*OpenChannel
err := nodeChanBucket.ForEach(func(chainHash, v []byte) error {
// If there's a value, it's not a bucket so ignore it.
if v != nil {
return nil
}
// If we've found a valid chainhash bucket, then we'll retrieve
// that so we can extract all the channels.
chainBucket := nodeChanBucket.NestedReadBucket(chainHash)
if chainBucket == nil {
return fmt.Errorf("unable to read bucket for chain=%x",
chainHash[:])
}
// Finally, with both of the necessary buckets retrieved, fetch all
// the active channels related to this node.
nodeChannels, err := c.fetchNodeChannels(chainBucket)
if err != nil {
return fmt.Errorf("unable to read channel for "+
"chain_hash=%x, node_key=%x: %v",
chainHash[:], pub, err)
}
channels = append(channels, nodeChannels...)
return nil
})
return channels, err
}
// fetchNodeChannels retrieves all active channels from the target chainBucket
// which is under a node's dedicated channel bucket. This function is typically
// used to fetch all the active channels related to a particular node.
func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) (
[]*OpenChannel, error) {
var channels []*OpenChannel
// A node may have channels on several chains, so for each known chain,
// we'll extract all the channels.
err := chainBucket.ForEach(func(chanPoint, v []byte) error {
// If there's a value, it's not a bucket so ignore it.
if v != nil {
return nil
}
// Once we've found a valid channel bucket, we'll extract it
// from the node's chain bucket.
chanBucket := chainBucket.NestedReadBucket(chanPoint)
var outPoint wire.OutPoint
err := graphdb.ReadOutpoint(
bytes.NewReader(chanPoint), &outPoint,
)
if err != nil {
return err
}
oChannel, err := fetchOpenChannel(chanBucket, &outPoint)
if err != nil {
return fmt.Errorf("unable to read channel data for "+
"chan_point=%v: %w", outPoint, err)
}
oChannel.Db = c
channels = append(channels, oChannel)
return nil
})
if err != nil {
return nil, err
}
return channels, nil
}
// FetchChannel attempts to locate a channel specified by the passed channel
// point. If the channel cannot be found, then an error will be returned.
func (c *ChannelStateDB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel,
error) {
var targetChanPoint bytes.Buffer
err := graphdb.WriteOutpoint(&targetChanPoint, &chanPoint)
if err != nil {
return nil, err
}
targetChanPointBytes := targetChanPoint.Bytes()
selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error) {
return targetChanPointBytes, &chanPoint, nil
}
return c.channelScanner(nil, selector)
}
// FetchChannelByID attempts to locate a channel specified by the passed channel
// ID. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied.
func (c *ChannelStateDB) FetchChannelByID(tx kvdb.RTx, id lnwire.ChannelID) (
*OpenChannel, error) {
selector := func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error) {
var (
targetChanPointBytes []byte
targetChanPoint *wire.OutPoint
// errChanFound is used to signal that the channel has
// been found so that iteration through the DB buckets
// can stop.
errChanFound = errors.New("channel found")
)
err := chainBkt.ForEach(func(k, _ []byte) error {
var outPoint wire.OutPoint
err := graphdb.ReadOutpoint(
bytes.NewReader(k), &outPoint,
)
if err != nil {
return err
}
chanID := lnwire.NewChanIDFromOutPoint(outPoint)
if chanID != id {
return nil
}
targetChanPoint = &outPoint
targetChanPointBytes = k
return errChanFound
})
if err != nil && !errors.Is(err, errChanFound) {
return nil, nil, err
}
if targetChanPoint == nil {
return nil, nil, ErrChannelNotFound
}
return targetChanPointBytes, targetChanPoint, nil
}
return c.channelScanner(tx, selector)
}
// ChanCount is used by the server in determining access control.
type ChanCount struct {
HasOpenOrClosedChan bool
PendingOpenCount uint64
}
// FetchPermAndTempPeers returns a map where the key is the remote node's
// public key and the value is a struct that has a tally of the pending-open
// channels and whether the peer has an open or closed channel with us.
func (c *ChannelStateDB) FetchPermAndTempPeers(
chainHash []byte) (map[string]ChanCount, error) {
peerCounts := make(map[string]ChanCount)
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
openChanBucket := tx.ReadBucket(openChannelBucket)
if openChanBucket == nil {
return ErrNoChanDBExists
}
openChanErr := openChanBucket.ForEach(func(nodePub,
v []byte) error {
// If there is a value, this is not a bucket.
if v != nil {
return nil
}
nodeChanBucket := openChanBucket.NestedReadBucket(
nodePub,
)
if nodeChanBucket == nil {
return nil
}
chainBucket := nodeChanBucket.NestedReadBucket(
chainHash,
)
if chainBucket == nil {
return fmt.Errorf("no chain bucket exists")
}
var isPermPeer bool
var pendingOpenCount uint64
internalErr := chainBucket.ForEach(func(chanPoint,
val []byte) error {
// If there is a value, this is not a bucket.
if val != nil {
return nil
}
chanBucket := chainBucket.NestedReadBucket(
chanPoint,
)
if chanBucket == nil {
return nil
}
var op wire.OutPoint
readErr := graphdb.ReadOutpoint(
bytes.NewReader(chanPoint), &op,
)
if readErr != nil {
return readErr
}
// We need to go through each channel and look
// at the IsPending status.
openChan, err := fetchOpenChannel(
chanBucket, &op,
)
if err != nil {
return err
}
if openChan.IsPending {
// Add to the pending-open count since
// this is a temp peer.
pendingOpenCount++
return nil
}
// Since IsPending is false, this is a perm
// peer.
isPermPeer = true
return nil
})
if internalErr != nil {
return internalErr
}
peerCount := ChanCount{
HasOpenOrClosedChan: isPermPeer,
PendingOpenCount: pendingOpenCount,
}
peerCounts[string(nodePub)] = peerCount
return nil
})
if openChanErr != nil {
return openChanErr
}
// Now check the closed channel bucket.
historicalChanBucket := tx.ReadBucket(historicalChannelBucket)
if historicalChanBucket == nil {
return ErrNoHistoricalBucket
}
historicalErr := historicalChanBucket.ForEach(func(chanPoint,
v []byte) error {
// Parse each nested bucket and the chanInfoKey to get
// the IsPending bool. This determines whether the
// peer is protected or not.
if v != nil {
// This is not a bucket. This is currently not
// possible.
return nil
}
chanBucket := historicalChanBucket.NestedReadBucket(
chanPoint,
)
if chanBucket == nil {
// This is not possible.
return fmt.Errorf("no historical channel " +
"bucket exists")
}
var op wire.OutPoint
readErr := graphdb.ReadOutpoint(
bytes.NewReader(chanPoint), &op,
)
if readErr != nil {
return readErr
}
// This channel is closed, but the structure of the
// historical bucket is the same. This is by design,
// which means we can call fetchOpenChannel.
channel, fetchErr := fetchOpenChannel(chanBucket, &op)
if fetchErr != nil {
return fetchErr
}
// Only include this peer in the protected class if
// the closing transaction confirmed. Note that
// CloseChannel can be called in the funding manager
// while IsPending is true which is why we need this
// special-casing to not count premature funding
// manager calls to CloseChannel.
if !channel.IsPending {
// Fetch the public key of the remote node. We
// need to use the string-ified serialized,
// compressed bytes as the key.
remotePub := channel.IdentityPub
remoteSer := remotePub.SerializeCompressed()
remoteKey := string(remoteSer)
count, exists := peerCounts[remoteKey]
if exists {
count.HasOpenOrClosedChan = true
peerCounts[remoteKey] = count
} else {
peerCount := ChanCount{
HasOpenOrClosedChan: true,
}
peerCounts[remoteKey] = peerCount
}
}
return nil
})
if historicalErr != nil {
return historicalErr
}
return nil
}, func() {
clear(peerCounts)
})
return peerCounts, err
}
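// exampleClassifyPeer is an illustrative sketch (not part of the original
// file) of how a caller might interpret the ChanCount tallies produced by
// FetchPermAndTempPeers: peers with an open or closed channel are treated
// as protected, while peers with only pending-open channels are temporary.
// The peerKey parameter is the string-ified compressed pubkey used as the
// map key above.
func exampleClassifyPeer(counts map[string]ChanCount, peerKey string) string {
	count, ok := counts[peerKey]
	switch {
	case !ok:
		return "unknown"

	case count.HasOpenOrClosedChan:
		return "protected"

	case count.PendingOpenCount > 0:
		return "temporary"

	default:
		return "unknown"
	}
}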
// channelSelector describes a function that takes a chain-hash bucket from
// within the open-channel DB and returns the wanted channel point bytes, and
// channel point. It must return the ErrChannelNotFound error if the wanted
// channel is not in the given bucket.
type channelSelector func(chainBkt walletdb.ReadBucket) ([]byte, *wire.OutPoint,
error)
// channelScanner will traverse the DB to each chain-hash bucket of each node
// pub-key bucket in the open-channel-bucket. The chanSelector will then be used
// to fetch the wanted channel outpoint from the chain bucket.
func (c *ChannelStateDB) channelScanner(tx kvdb.RTx,
chanSelect channelSelector) (*OpenChannel, error) {
var (
targetChan *OpenChannel
// errChanFound is used to signal that the channel has been
// found so that iteration through the DB buckets can stop.
errChanFound = errors.New("channel found")
)
// chanScan will traverse the following bucket structure:
// * nodePub => chainHash => chanPoint
//
// At each level we go one further, ensuring that we're traversing the
// proper key (that's actually a bucket). By only reading the bucket
// structure and skipping fully decoding each channel, we save a good
// bit of CPU as we don't need to do things like decompress public
// keys.
chanScan := func(tx kvdb.RTx) error {
// Get the bucket dedicated to storing the metadata for open
// channels.
openChanBucket := tx.ReadBucket(openChannelBucket)
if openChanBucket == nil {
return ErrNoActiveChannels
}
// Within the open channel bucket is the set of node pubkeys we
// have channels with. We don't know the entire set, so we'll
// check them all.
return openChanBucket.ForEach(func(nodePub, v []byte) error {
// Ensure that this is a key the same size as a pubkey,
// and also that it leads directly to a bucket.
if len(nodePub) != 33 || v != nil {
return nil
}
nodeChanBucket := openChanBucket.NestedReadBucket(
nodePub,
)
if nodeChanBucket == nil {
return nil
}
// The next layer down is all the chains that this node
// has channels on with us.
return nodeChanBucket.ForEach(func(chainHash,
v []byte) error {
// If there's a value, it's not a bucket so
// ignore it.
if v != nil {
return nil
}
chainBucket := nodeChanBucket.NestedReadBucket(
chainHash,
)
if chainBucket == nil {
return fmt.Errorf("unable to read "+
"bucket for chain=%x",
chainHash)
}
// Finally, we reach the leaf bucket that stores
// all the chanPoints for this node.
targetChanBytes, chanPoint, err := chanSelect(
chainBucket,
)
if errors.Is(err, ErrChannelNotFound) {
return nil
} else if err != nil {
return err
}
chanBucket := chainBucket.NestedReadBucket(
targetChanBytes,
)
if chanBucket == nil {
return nil
}
channel, err := fetchOpenChannel(
chanBucket, chanPoint,
)
if err != nil {
return err
}
targetChan = channel
targetChan.Db = c
return errChanFound
})
})
}
var err error
if tx == nil {
err = kvdb.View(c.backend, chanScan, func() {})
} else {
err = chanScan(tx)
}
if err != nil && !errors.Is(err, errChanFound) {
return nil, err
}
if targetChan != nil {
return targetChan, nil
}
// If we can't find the channel, then we return ErrChannelNotFound to
// the caller.
return nil, ErrChannelNotFound
}
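// exampleSelectFirstChannel is an illustrative sketch (not part of the
// original file) of a custom channelSelector: it picks the first channel
// point found in the chain bucket, demonstrating the contract channelScanner
// expects, i.e. return the raw key and decoded outpoint, or
// ErrChannelNotFound if nothing matches.
func exampleSelectFirstChannel(chainBkt walletdb.ReadBucket) ([]byte,
	*wire.OutPoint, error) {

	var (
		keyBytes []byte
		outPoint *wire.OutPoint
	)
	err := chainBkt.ForEach(func(k, v []byte) error {
		// Channel points lead to nested buckets, so skip plain
		// values, and stop looking after the first match.
		if v != nil || outPoint != nil {
			return nil
		}

		var op wire.OutPoint
		err := graphdb.ReadOutpoint(bytes.NewReader(k), &op)
		if err != nil {
			return err
		}

		keyBytes = k
		outPoint = &op

		return nil
	})
	if err != nil {
		return nil, nil, err
	}

	if outPoint == nil {
		return nil, nil, ErrChannelNotFound
	}

	return keyBytes, outPoint, nil
}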
// FetchAllChannels attempts to retrieve all open channels currently stored
// within the database, including pending open, fully open and channels waiting
// for a closing transaction to confirm.
func (c *ChannelStateDB) FetchAllChannels() ([]*OpenChannel, error) {
return fetchChannels(c)
}
// FetchAllOpenChannels will return all channels that have the funding
// transaction confirmed, and is not waiting for a closing transaction to be
// confirmed.
func (c *ChannelStateDB) FetchAllOpenChannels() ([]*OpenChannel, error) {
return fetchChannels(
c,
pendingChannelFilter(false),
waitingCloseFilter(false),
)
}
// FetchPendingChannels will return channels that have completed the process of
// generating and broadcasting funding transactions, but whose funding
// transactions have yet to be confirmed on the blockchain.
func (c *ChannelStateDB) FetchPendingChannels() ([]*OpenChannel, error) {
return fetchChannels(c,
pendingChannelFilter(true),
waitingCloseFilter(false),
)
}
// FetchWaitingCloseChannels will return all channels that have been opened,
// but are now waiting for a closing transaction to be confirmed.
//
// NOTE: This includes channels that are also pending to be opened.
func (c *ChannelStateDB) FetchWaitingCloseChannels() ([]*OpenChannel, error) {
return fetchChannels(
c, waitingCloseFilter(true),
)
}
// fetchChannelsFilter applies a filter to channels retrieved in fetchChannels.
// A set of filters can be combined to filter across multiple dimensions.
type fetchChannelsFilter func(channel *OpenChannel) bool
// pendingChannelFilter returns a filter based on whether channels are
// pending (i.e., their funding transaction still needs to confirm). If
// pending is false, channels with confirmed funding transactions are
// returned.
func pendingChannelFilter(pending bool) fetchChannelsFilter {
return func(channel *OpenChannel) bool {
return channel.IsPending == pending
}
}
// waitingCloseFilter returns a filter which filters channels based on whether
// they are awaiting the confirmation of their closing transaction. If waiting
// close is true, channels that have had their closing tx broadcast are
// included. If it is false, channels that are not awaiting confirmation of
// their close transaction are returned.
func waitingCloseFilter(waitingClose bool) fetchChannelsFilter {
return func(channel *OpenChannel) bool {
// If the channel is in any other state than Default,
// then it means it is waiting to be closed.
channelWaitingClose :=
channel.ChanStatus() != ChanStatusDefault
// Include the channel if it matches the value for
// waiting close that we are filtering on.
return channelWaitingClose == waitingClose
}
}
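// exampleFetchConfirmedWaitingClose is an illustrative sketch (not part of
// the original file) of filter composition: only channels that pass *all*
// filters are returned, so the query below yields channels whose funding
// transaction has confirmed and whose closing transaction has been
// broadcast.
func exampleFetchConfirmedWaitingClose(c *ChannelStateDB) ([]*OpenChannel,
	error) {

	return fetchChannels(
		c,
		pendingChannelFilter(false),
		waitingCloseFilter(true),
	)
}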
// fetchChannels attempts to retrieve channels currently stored in the
// database. It takes a set of filters which are applied to each channel to
// obtain a set of channels with the desired set of properties. Only channels
// which have a true value returned for *all* of the filters will be returned.
// If no filters are provided, every channel in the open channels bucket will
// be returned.
func fetchChannels(c *ChannelStateDB, filters ...fetchChannelsFilter) (
[]*OpenChannel, error) {
var channels []*OpenChannel
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
// Get the bucket dedicated to storing the metadata for open
// channels.
openChanBucket := tx.ReadBucket(openChannelBucket)
if openChanBucket == nil {
return ErrNoActiveChannels
}
// Next, fetch the bucket dedicated to storing metadata related
// to all nodes. All keys within this bucket are the serialized
// public keys of all our direct counterparties.
nodeMetaBucket := tx.ReadBucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return fmt.Errorf("node bucket not created")
}
// Finally for each node public key in the bucket, fetch all
// the channels related to this particular node.
return nodeMetaBucket.ForEach(func(k, v []byte) error {
nodeChanBucket := openChanBucket.NestedReadBucket(k)
if nodeChanBucket == nil {
return nil
}
return nodeChanBucket.ForEach(func(chainHash, v []byte) error {
// If there's a value, it's not a bucket so
// ignore it.
if v != nil {
return nil
}
// If we've found a valid chainhash bucket,
// then we'll retrieve that so we can extract
// all the channels.
chainBucket := nodeChanBucket.NestedReadBucket(
chainHash,
)
if chainBucket == nil {
return fmt.Errorf("unable to read "+
"bucket for chain=%x", chainHash[:])
}
nodeChans, err := c.fetchNodeChannels(chainBucket)
if err != nil {
return fmt.Errorf("unable to read "+
"channel for chain_hash=%x, "+
"node_key=%x: %v", chainHash[:], k, err)
}
for _, channel := range nodeChans {
// includeChannel indicates whether the channel
// meets the criteria specified by our filters.
includeChannel := true
// Run through each filter and check whether the
// channel should be included.
for _, f := range filters {
// If the channel fails the filter, set
// includeChannel to false and don't bother
// checking the remaining filters.
if !f(channel) {
includeChannel = false
break
}
}
// If the channel passed every filter, include it in
// our set of channels.
if includeChannel {
channels = append(channels, channel)
}
}
return nil
})
})
}, func() {
channels = nil
})
if err != nil {
return nil, err
}
return channels, nil
}
// FetchClosedChannels attempts to fetch all closed channels from the database.
// The pendingOnly bool toggles if channels that aren't yet fully closed should
// be returned in the response or not. When a channel was cooperatively closed,
// it becomes fully closed after a single confirmation. When a channel was
// forcibly closed, it will become fully closed after _all_ the pending funds
// (if any) have been swept.
func (c *ChannelStateDB) FetchClosedChannels(pendingOnly bool) (
[]*ChannelCloseSummary, error) {
var chanSummaries []*ChannelCloseSummary
if err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil {
return ErrNoClosedChannels
}
return closeBucket.ForEach(func(chanID []byte, summaryBytes []byte) error {
summaryReader := bytes.NewReader(summaryBytes)
chanSummary, err := deserializeCloseChannelSummary(summaryReader)
if err != nil {
return err
}
// If the query specified to only include pending
// channels, then we'll skip any channels which aren't
// currently pending.
if !chanSummary.IsPending && pendingOnly {
return nil
}
chanSummaries = append(chanSummaries, chanSummary)
return nil
})
}, func() {
chanSummaries = nil
}); err != nil {
return nil, err
}
return chanSummaries, nil
}
// ErrClosedChannelNotFound signals that a closed channel could not be found in
// the channeldb.
var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary")
// FetchClosedChannel queries for a channel close summary using the channel
// point of the channel in question.
func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) (
*ChannelCloseSummary, error) {
var chanSummary *ChannelCloseSummary
if err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil {
return ErrClosedChannelNotFound
}
var b bytes.Buffer
var err error
if err = graphdb.WriteOutpoint(&b, chanID); err != nil {
return err
}
summaryBytes := closeBucket.Get(b.Bytes())
if summaryBytes == nil {
return ErrClosedChannelNotFound
}
summaryReader := bytes.NewReader(summaryBytes)
chanSummary, err = deserializeCloseChannelSummary(summaryReader)
return err
}, func() {
chanSummary = nil
}); err != nil {
return nil, err
}
return chanSummary, nil
}
// FetchClosedChannelForID queries for a channel close summary using the
// channel ID of the channel in question.
func (c *ChannelStateDB) FetchClosedChannelForID(cid lnwire.ChannelID) (
*ChannelCloseSummary, error) {
var chanSummary *ChannelCloseSummary
if err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil {
return ErrClosedChannelNotFound
}
// The first 30 bytes of the channel ID and outpoint will be
// equal.
cursor := closeBucket.ReadCursor()
op, c := cursor.Seek(cid[:30])
// We scan over all possible candidates for this channel ID.
for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() {
var outPoint wire.OutPoint
err := graphdb.ReadOutpoint(
bytes.NewReader(op), &outPoint,
)
if err != nil {
return err
}
// If the found outpoint does not correspond to this
// channel ID, we continue.
if !cid.IsChanPoint(&outPoint) {
continue
}
// Deserialize the close summary and return.
r := bytes.NewReader(c)
chanSummary, err = deserializeCloseChannelSummary(r)
if err != nil {
return err
}
return nil
}
return ErrClosedChannelNotFound
}, func() {
chanSummary = nil
}); err != nil {
return nil, err
}
return chanSummary, nil
}
// MarkChanFullyClosed marks a channel as fully closed within the database. A
// channel should be marked as fully closed if the channel was initially
// cooperatively closed and it's reached a single confirmation, or after all
// the pending funds in a channel that has been forcibly closed have been
// swept.
func (c *ChannelStateDB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
var (
openChannels []*OpenChannel
pruneLinkNode *btcec.PublicKey
)
err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
var b bytes.Buffer
if err := graphdb.WriteOutpoint(&b, chanPoint); err != nil {
return err
}
chanID := b.Bytes()
closedChanBucket, err := tx.CreateTopLevelBucket(
closedChannelBucket,
)
if err != nil {
return err
}
chanSummaryBytes := closedChanBucket.Get(chanID)
if chanSummaryBytes == nil {
return fmt.Errorf("no closed channel for "+
"chan_point=%v found", chanPoint)
}
chanSummaryReader := bytes.NewReader(chanSummaryBytes)
chanSummary, err := deserializeCloseChannelSummary(
chanSummaryReader,
)
if err != nil {
return err
}
chanSummary.IsPending = false
var newSummary bytes.Buffer
err = serializeChannelCloseSummary(&newSummary, chanSummary)
if err != nil {
return err
}
err = closedChanBucket.Put(chanID, newSummary.Bytes())
if err != nil {
return err
}
// Now that the channel is closed, we'll check if we have any
// other open channels with this peer. If we don't we'll
// garbage collect it to ensure we don't establish persistent
// connections to peers without open channels.
pruneLinkNode = chanSummary.RemotePub
openChannels, err = c.fetchOpenChannels(
tx, pruneLinkNode,
)
if err != nil {
return fmt.Errorf("unable to fetch open channels for "+
"peer %x: %v",
pruneLinkNode.SerializeCompressed(), err)
}
return nil
}, func() {
openChannels = nil
pruneLinkNode = nil
})
if err != nil {
return err
}
// Decide whether we want to remove the link node, based upon the number
// of still open channels.
return c.pruneLinkNode(openChannels, pruneLinkNode)
}
// pruneLinkNode determines whether we should garbage collect a link node from
// the database due to no longer having any open channels with it. If there are
// any left, then this acts as a no-op.
func (c *ChannelStateDB) pruneLinkNode(openChannels []*OpenChannel,
remotePub *btcec.PublicKey) error {
if len(openChannels) > 0 {
return nil
}
log.Infof("Pruning link node %x with zero open channels from database",
remotePub.SerializeCompressed())
return c.linkNodeDB.DeleteLinkNode(remotePub)
}
// PruneLinkNodes attempts to prune all link nodes found within the database
// with whom we no longer have any open channels.
func (c *ChannelStateDB) PruneLinkNodes() error {
allLinkNodes, err := c.linkNodeDB.FetchAllLinkNodes()
if err != nil {
return err
}
for _, linkNode := range allLinkNodes {
var (
	openChannels []*OpenChannel

	// Re-declare the loop variable so the closure below captures
	// a stable per-iteration copy (pre-Go 1.22 loop semantics).
	linkNode = linkNode
)
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
var err error
openChannels, err = c.fetchOpenChannels(
tx, linkNode.IdentityPub,
)
return err
}, func() {
openChannels = nil
})
if err != nil {
return err
}
err = c.pruneLinkNode(openChannels, linkNode.IdentityPub)
if err != nil {
return err
}
}
return nil
}
// ChannelShell is a shell of a channel that is meant to be used for channel
// recovery purposes. It contains a minimal OpenChannel instance along with
// addresses for that target node.
type ChannelShell struct {
// NodeAddrs is the set of addresses that this node has been known to
// be reachable at in the past.
NodeAddrs []net.Addr
// Chan is a shell of an OpenChannel, it contains only the items
// required to restore the channel on disk.
Chan *OpenChannel
}
// RestoreChannelShells is a method that allows the caller to reconstruct the
// state of an OpenChannel from the ChannelShell. We'll attempt to write the
// new channel to disk, create a LinkNode instance with the passed node
// addresses, and finally create an edge within the graph for the channel as
// well. This method is idempotent, so repeated calls with the same set of
// channel shells won't modify the database after the initial call.
func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) error {
err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
for _, channelShell := range channelShells {
channel := channelShell.Chan
// When we make a channel, we mark that the channel has
// been restored, this will signal to other sub-systems
// to not attempt to use the channel as if it was a
// regular one.
channel.chanStatus |= ChanStatusRestored
// First, we'll attempt to create a new open channel
// and link node for this channel. If the channel
// already exists, then in order to ensure this method
// is idempotent, we'll continue to the next step.
channel.Db = c
err := syncNewChannel(
tx, channel, channelShell.NodeAddrs,
)
if err != nil {
return err
}
}
return nil
}, func() {})
if err != nil {
return err
}
return nil
}
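// Example usage (an illustrative sketch, not part of the original source):
// restoring a single channel shell, where restoredChan (*OpenChannel) and
// peerAddr (net.Addr) are assumed to have been reconstructed from a static
// channel backup beforehand.
//
//	shell := &ChannelShell{
//		NodeAddrs: []net.Addr{peerAddr},
//		Chan:      restoredChan,
//	}
//	if err := db.ChannelStateDB().RestoreChannelShells(shell); err != nil {
//		log.Errorf("unable to restore channel shell: %v", err)
//	}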
// AddrsForNode consults the channel database for all addresses known to the
// passed node public key. The returned boolean indicates whether the given
// node is known to the channel DB.
//
// NOTE: this is part of the AddrSource interface.
func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub)
// Only if the error is something other than ErrNodeNotFound do we
// return it.
switch {
case err != nil && !errors.Is(err, ErrNodeNotFound):
return false, nil, err
case errors.Is(err, ErrNodeNotFound):
return false, nil, nil
}
return true, linkNode.Addresses, nil
}
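// Example usage (an illustrative sketch, not part of the original source):
// distinguishing an unknown node from a hard failure when looking up peer
// addresses, with nodePub assumed to be the peer's identity key.
//
//	known, addrs, err := db.AddrsForNode(nodePub)
//	switch {
//	case err != nil:
//		return err
//	case !known:
//		// Fall back to other address sources, e.g. the graph.
//	default:
//		log.Debugf("found %d addresses for node", len(addrs))
//	}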
// AbandonChannel attempts to remove the target channel from the open channel
// database. If the channel was already removed (has a closed channel entry),
// then we'll return a nil error. Otherwise, we'll insert a new close summary
// into the database.
func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint,
bestHeight uint32) error {
// With the chanPoint constructed, we'll attempt to find the target
// channel in the database. If we can't find the channel, then we'll
// return the error back to the caller.
dbChan, err := c.FetchChannel(*chanPoint)
switch {
// If the channel wasn't found, then it's possible that it was already
// abandoned from the database.
case err == ErrChannelNotFound:
_, closedErr := c.FetchClosedChannel(chanPoint)
if closedErr != nil {
return closedErr
}
// If the channel was already closed, then we don't return an
// error as we'd like this step to be repeatable.
return nil
case err != nil:
return err
}
// Now that we've found the channel, we'll populate a close summary for
// the channel, so we can store as much information for this abandoned
// channel as possible. We also ensure that we set Pending to false, to
// indicate that this channel has been "fully" closed.
summary := &ChannelCloseSummary{
CloseType: Abandoned,
ChanPoint: *chanPoint,
ChainHash: dbChan.ChainHash,
CloseHeight: bestHeight,
RemotePub: dbChan.IdentityPub,
Capacity: dbChan.Capacity,
SettledBalance: dbChan.LocalCommitment.LocalBalance.ToSatoshis(),
ShortChanID: dbChan.ShortChanID(),
RemoteCurrentRevocation: dbChan.RemoteCurrentRevocation,
RemoteNextRevocation: dbChan.RemoteNextRevocation,
LocalChanConfig: dbChan.LocalChanCfg,
}
// Finally, we'll close the channel in the DB, and return back to the
// caller. We set ourselves as the close initiator because we abandoned
// the channel.
return dbChan.CloseChannel(summary, ChanStatusLocalCloseInitiator)
}
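// Example usage (an illustrative sketch, not part of the original source):
// abandoning a channel whose funding transaction is known to never confirm,
// with bestHeight assumed to come from the chain backend.
//
//	err := db.ChannelStateDB().AbandonChannel(&chanPoint, bestHeight)
//	if err != nil {
//		return fmt.Errorf("unable to abandon channel: %w", err)
//	}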
// SaveChannelOpeningState saves the serialized channel state for the provided
// chanPoint to the channelOpeningStateBucket.
func (c *ChannelStateDB) SaveChannelOpeningState(outPoint,
serializedState []byte) error {
return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket)
if err != nil {
return err
}
return bucket.Put(outPoint, serializedState)
}, func() {})
}
// GetChannelOpeningState fetches the serialized channel state for the provided
// outPoint from the database, or returns ErrChannelNotFound if the channel
// is not found.
func (c *ChannelStateDB) GetChannelOpeningState(outPoint []byte) ([]byte,
error) {
var serializedState []byte
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(channelOpeningStateBucket)
if bucket == nil {
// If the bucket does not exist, it means we never added
// a channel to the db, so return ErrChannelNotFound.
return ErrChannelNotFound
}
stateBytes := bucket.Get(outPoint)
if stateBytes == nil {
return ErrChannelNotFound
}
serializedState = append(serializedState, stateBytes...)
return nil
}, func() {
serializedState = nil
})
return serializedState, err
}
// DeleteChannelOpeningState removes any state for outPoint from the database.
func (c *ChannelStateDB) DeleteChannelOpeningState(outPoint []byte) error {
return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
bucket := tx.ReadWriteBucket(channelOpeningStateBucket)
if bucket == nil {
return ErrChannelNotFound
}
return bucket.Delete(outPoint)
}, func() {})
}
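// Example usage (an illustrative sketch, not part of the original source): the
// three channel opening state methods form a simple roundtrip keyed by the
// serialized funding outpoint. Both outPointBytes and stateBytes are assumed
// to have been serialized by the caller.
//
//	csdb := db.ChannelStateDB()
//	if err := csdb.SaveChannelOpeningState(outPointBytes, stateBytes); err != nil {
//		return err
//	}
//	state, err := csdb.GetChannelOpeningState(outPointBytes)
//	if err != nil {
//		return err
//	}
//	// Resume the funding flow using state, then clean up.
//	_ = state
//	return csdb.DeleteChannelOpeningState(outPointBytes)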
// syncVersions is used for safe db version synchronization. It applies
// migration functions to the current database and recovers the previous
// state of the db if at least one error/panic occurred during migration.
func (d *DB) syncVersions(versions []mandatoryVersion) error {
meta, err := d.FetchMeta()
if err != nil {
if err == ErrMetaNotFound {
meta = &Meta{}
} else {
return err
}
}
latestVersion := getLatestDBVersion(versions)
log.Infof("Checking for schema update: latest_version=%v, "+
"db_version=%v", latestVersion, meta.DbVersionNumber)
switch {
// If the database reports a higher version than we are aware of, the
// user is probably trying to revert to a prior version of lnd. We fail
// here to prevent reversions and unintended corruption.
case meta.DbVersionNumber > latestVersion:
log.Errorf("Refusing to revert from db_version=%d to "+
"lower version=%d", meta.DbVersionNumber,
latestVersion)
return ErrDBReversion
// If the current database version matches the latest version number,
// then we don't need to perform any migrations.
case meta.DbVersionNumber == latestVersion:
return nil
}
log.Infof("Performing database schema migration")
// Otherwise, we fetch the migrations which need to be applied, and
// execute them serially within a single database transaction to ensure
// the migration is atomic.
migrations, migrationVersions := getMigrationsToApply(
versions, meta.DbVersionNumber,
)
return kvdb.Update(d, func(tx kvdb.RwTx) error {
for i, migration := range migrations {
if migration == nil {
continue
}
log.Infof("Applying migration #%v",
migrationVersions[i])
if err := migration(tx); err != nil {
log.Infof("Unable to apply migration #%v",
migrationVersions[i])
return err
}
}
meta.DbVersionNumber = latestVersion
err := putMeta(meta, tx)
if err != nil {
return err
}
// In dry-run mode, return an error to prevent the transaction
// from committing.
if d.dryRun {
return ErrDryRunMigrationOK
}
return nil
}, func() {})
}
// applyOptionalVersions takes a config to determine whether the optional
// migrations will be applied.
//
// NOTE: only the prune_revocation_log optional migration is supported atm.
func (d *DB) applyOptionalVersions(cfg OptionalMiragtionConfig) error {
// TODO(yy): need to design the db to support dry run for optional
// migrations.
if d.dryRun {
log.Info("Skipped optional migrations as dry run mode is not " +
"supported yet")
return nil
}
om, err := d.fetchOptionalMeta()
if err != nil {
if err == ErrMetaNotFound {
om = &OptionalMeta{
Versions: make(map[uint64]string),
}
} else {
return err
}
}
log.Infof("Checking for optional update: prune_revocation_log=%v, "+
"db_version=%s", cfg.PruneRevocationLog, om)
// Exit early if the optional migration is not specified.
if !cfg.PruneRevocationLog {
return nil
}
// Exit early if the optional migration has already been applied.
if _, ok := om.Versions[0]; ok {
return nil
}
// Get the optional version.
version := optionalVersions[0]
log.Infof("Performing database optional migration: %s", version.name)
migrationCfg := &MigrationConfigImpl{
migration30.MigrateRevLogConfigImpl{
NoAmountData: d.noRevLogAmtData,
},
}
// Migrate the data.
if err := version.migration(d, migrationCfg); err != nil {
log.Errorf("Unable to apply optional migration: %s, error: %v",
version.name, err)
return err
}
// Update the optional meta. Notice that unlike the mandatory db
// migrations, where we perform the migration and update the meta in a
// single db transaction, we use separate transactions here. Even if the
// following update fails, we should be fine as we would simply re-run
// the optional migration, which is a noop, during the next startup.
om.Versions[0] = version.name
if err := d.putOptionalMeta(om); err != nil {
log.Errorf("Unable to update optional meta: %v", err)
return err
}
return nil
}
// ChannelStateDB returns the sub database that is concerned with the channel
// state.
func (d *DB) ChannelStateDB() *ChannelStateDB {
return d.channelStateDB
}
// LatestDBVersion returns the number of the latest database version currently
// known to the channel DB.
func LatestDBVersion() uint32 {
return getLatestDBVersion(dbVersions)
}
func getLatestDBVersion(versions []mandatoryVersion) uint32 {
return versions[len(versions)-1].number
}
// getMigrationsToApply retrieves the migration function that should be
// applied to the database.
func getMigrationsToApply(versions []mandatoryVersion,
version uint32) ([]migration, []uint32) {
migrations := make([]migration, 0, len(versions))
migrationVersions := make([]uint32, 0, len(versions))
for _, v := range versions {
if v.number > version {
migrations = append(migrations, v.migration)
migrationVersions = append(migrationVersions, v.number)
}
}
return migrations, migrationVersions
}
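// For example, if versions holds the mandatory versions numbered 1, 2 and 3
// and the database currently reports version 1, the function returns the
// migration functions for versions 2 and 3 (in that order) together with
// []uint32{2, 3}.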
// fetchHistoricalChanBucket returns the channel bucket for a given outpoint
// from the historical channel bucket. If the bucket does not exist,
// ErrNoHistoricalBucket is returned.
func fetchHistoricalChanBucket(tx kvdb.RTx,
outPoint *wire.OutPoint) (kvdb.RBucket, error) {
// First fetch the top level bucket which stores all data related to
// historically stored channels.
historicalChanBucket := tx.ReadBucket(historicalChannelBucket)
if historicalChanBucket == nil {
return nil, ErrNoHistoricalBucket
}
// With the bucket for the node and chain fetched, we can now go down
// another level, for the channel itself.
var chanPointBuf bytes.Buffer
if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil {
return nil, err
}
chanBucket := historicalChanBucket.NestedReadBucket(
chanPointBuf.Bytes(),
)
if chanBucket == nil {
return nil, ErrChannelNotFound
}
return chanBucket, nil
}
// FetchHistoricalChannel fetches open channel data from the historical channel
// bucket.
func (c *ChannelStateDB) FetchHistoricalChannel(outPoint *wire.OutPoint) (
*OpenChannel, error) {
var channel *OpenChannel
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchHistoricalChanBucket(tx, outPoint)
if err != nil {
return err
}
channel, err = fetchOpenChannel(chanBucket, outPoint)
if err != nil {
return err
}
channel.Db = c
return nil
}, func() {
channel = nil
})
if err != nil {
return nil, err
}
return channel, nil
}
func fetchFinalHtlcsBucket(tx kvdb.RTx,
chanID lnwire.ShortChannelID) (kvdb.RBucket, error) {
finalHtlcsBucket := tx.ReadBucket(finalHtlcsBucket)
if finalHtlcsBucket == nil {
return nil, ErrFinalHtlcsBucketNotFound
}
var chanIDBytes [8]byte
byteOrder.PutUint64(chanIDBytes[:], chanID.ToUint64())
chanBucket := finalHtlcsBucket.NestedReadBucket(chanIDBytes[:])
if chanBucket == nil {
return nil, ErrFinalChannelBucketNotFound
}
return chanBucket, nil
}
var ErrHtlcUnknown = errors.New("htlc unknown")
// LookupFinalHtlc retrieves a final htlc resolution from the database. If the
// htlc has no final resolution yet, ErrHtlcUnknown is returned.
func (c *ChannelStateDB) LookupFinalHtlc(chanID lnwire.ShortChannelID,
htlcIndex uint64) (*FinalHtlcInfo, error) {
var idBytes [8]byte
byteOrder.PutUint64(idBytes[:], htlcIndex)
var settledByte byte
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
finalHtlcsBucket, err := fetchFinalHtlcsBucket(
tx, chanID,
)
switch {
case errors.Is(err, ErrFinalHtlcsBucketNotFound):
fallthrough
case errors.Is(err, ErrFinalChannelBucketNotFound):
return ErrHtlcUnknown
case err != nil:
return fmt.Errorf("cannot fetch final htlcs bucket: %w",
err)
}
value := finalHtlcsBucket.Get(idBytes[:])
if value == nil {
return ErrHtlcUnknown
}
if len(value) != 1 {
return errors.New("unexpected final htlc value length")
}
settledByte = value[0]
return nil
}, func() {
settledByte = 0
})
if err != nil {
return nil, err
}
info := FinalHtlcInfo{
Settled: settledByte&byte(FinalHtlcSettledBit) != 0,
Offchain: settledByte&byte(FinalHtlcOffchainBit) != 0,
}
return &info, nil
}
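// Example usage (an illustrative sketch, not part of the original source):
// treating ErrHtlcUnknown as "no final resolution yet" rather than a hard
// failure.
//
//	info, err := db.ChannelStateDB().LookupFinalHtlc(chanID, htlcIndex)
//	switch {
//	case errors.Is(err, ErrHtlcUnknown):
//		// The htlc has not been resolved on-disk yet.
//	case err != nil:
//		return err
//	default:
//		log.Debugf("final htlc: settled=%v offchain=%v",
//			info.Settled, info.Offchain)
//	}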
// PutOnchainFinalHtlcOutcome stores the final on-chain outcome of an htlc in
// the database.
func (c *ChannelStateDB) PutOnchainFinalHtlcOutcome(
chanID lnwire.ShortChannelID, htlcID uint64, settled bool) error {
// Skip if the user did not opt in to storing final resolutions.
if !c.parent.storeFinalHtlcResolutions {
return nil
}
return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
finalHtlcsBucket, err := fetchFinalHtlcsBucketRw(tx, chanID)
if err != nil {
return err
}
return putFinalHtlc(
finalHtlcsBucket, htlcID,
FinalHtlcInfo{
Settled: settled,
Offchain: false,
},
)
}, func() {})
}
// MakeTestInvoiceDB is used to create a test invoice database for testing
// purposes. It simply calls into MakeTestDB so the same modifiers can be used.
func MakeTestInvoiceDB(t *testing.T, modifiers ...OptionModifier) (
invoices.InvoiceDB, error) {
return MakeTestDB(t, modifiers...)
}
// MakeTestDB creates a new instance of the ChannelDB for testing purposes.
// Cleanup of the created temporary directory and backend is registered via
// t.Cleanup and runs automatically after the test completes.
func MakeTestDB(t *testing.T, modifiers ...OptionModifier) (*DB, error) {
// First, create a temporary directory to be used for the duration of
// this test.
tempDirName := t.TempDir()
// Next, create channeldb for the first time.
backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cdb")
if err != nil {
backendCleanup()
return nil, err
}
cdb, err := CreateWithBackend(backend, modifiers...)
if err != nil {
backendCleanup()
return nil, err
}
t.Cleanup(func() {
cdb.Close()
backendCleanup()
})
return cdb, nil
}
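// Example usage (an illustrative sketch, not part of the original source): a
// unit test creates a throwaway database and relies on the registered
// t.Cleanup to tear it down; require is assumed to be testify's require
// package.
//
//	func TestSomething(t *testing.T) {
//		cdb, err := MakeTestDB(t)
//		require.NoError(t, err)
//
//		// Exercise cdb; cleanup runs when the test finishes.
//		_ = cdb
//	}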
package channeldb
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
var (
// duplicatePaymentsBucket is the name of a optional sub-bucket within
// the payment hash bucket, that is used to hold duplicate payments to a
// payment hash. This is needed to support information from earlier
// versions of lnd, where it was possible to pay to a payment hash more
// than once.
duplicatePaymentsBucket = []byte("payment-duplicate-bucket")
// duplicatePaymentSettleInfoKey is a key used in the payment's
// sub-bucket to store the settle info of the payment.
duplicatePaymentSettleInfoKey = []byte("payment-settle-info")
// duplicatePaymentAttemptInfoKey is a key used in the payment's
// sub-bucket to store the info about the latest attempt that was done
// for the payment in question.
duplicatePaymentAttemptInfoKey = []byte("payment-attempt-info")
// duplicatePaymentCreationInfoKey is a key used in the payment's
// sub-bucket to store the creation info of the payment.
duplicatePaymentCreationInfoKey = []byte("payment-creation-info")
// duplicatePaymentFailInfoKey is a key used in the payment's sub-bucket
// to store information about the reason a payment failed.
duplicatePaymentFailInfoKey = []byte("payment-fail-info")
// duplicatePaymentSequenceKey is a key used in the payment's sub-bucket
// to store the sequence number of the payment.
duplicatePaymentSequenceKey = []byte("payment-sequence-key")
)
// duplicateHTLCAttemptInfo contains static information about a specific HTLC
// attempt for a payment. This information is used by the router to handle any
// errors coming back after an attempt is made, and to query the switch about
// the status of the attempt.
type duplicateHTLCAttemptInfo struct {
// attemptID is the unique ID used for this attempt.
attemptID uint64
// sessionKey is the ephemeral key used for this attempt.
sessionKey [btcec.PrivKeyBytesLen]byte
// route is the route attempted to send the HTLC.
route route.Route
}
// fetchDuplicatePaymentStatus fetches the payment status of the payment. If
// the payment isn't found, it will return error `ErrPaymentNotInitiated`.
func fetchDuplicatePaymentStatus(bucket kvdb.RBucket) (PaymentStatus, error) {
if bucket.Get(duplicatePaymentSettleInfoKey) != nil {
return StatusSucceeded, nil
}
if bucket.Get(duplicatePaymentFailInfoKey) != nil {
return StatusFailed, nil
}
if bucket.Get(duplicatePaymentCreationInfoKey) != nil {
return StatusInFlight, nil
}
return 0, ErrPaymentNotInitiated
}
func deserializeDuplicateHTLCAttemptInfo(r io.Reader) (
*duplicateHTLCAttemptInfo, error) {
a := &duplicateHTLCAttemptInfo{}
err := ReadElements(r, &a.attemptID, &a.sessionKey)
if err != nil {
return nil, err
}
a.route, err = DeserializeRoute(r)
if err != nil {
return nil, err
}
return a, nil
}
func deserializeDuplicatePaymentCreationInfo(r io.Reader) (
*PaymentCreationInfo, error) {
var scratch [8]byte
c := &PaymentCreationInfo{}
if _, err := io.ReadFull(r, c.PaymentIdentifier[:]); err != nil {
return nil, err
}
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return nil, err
}
c.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return nil, err
}
c.CreationTime = time.Unix(int64(byteOrder.Uint64(scratch[:])), 0)
if _, err := io.ReadFull(r, scratch[:4]); err != nil {
return nil, err
}
reqLen := byteOrder.Uint32(scratch[:4])
payReq := make([]byte, reqLen)
if reqLen > 0 {
if _, err := io.ReadFull(r, payReq); err != nil {
return nil, err
}
}
c.PaymentRequest = payReq
return c, nil
}
func fetchDuplicatePayment(bucket kvdb.RBucket) (*MPPayment, error) {
seqBytes := bucket.Get(duplicatePaymentSequenceKey)
if seqBytes == nil {
return nil, fmt.Errorf("sequence number not found")
}
sequenceNum := binary.BigEndian.Uint64(seqBytes)
// Get the payment status.
paymentStatus, err := fetchDuplicatePaymentStatus(bucket)
if err != nil {
return nil, err
}
// Get the PaymentCreationInfo.
b := bucket.Get(duplicatePaymentCreationInfoKey)
if b == nil {
return nil, fmt.Errorf("creation info not found")
}
r := bytes.NewReader(b)
creationInfo, err := deserializeDuplicatePaymentCreationInfo(r)
if err != nil {
return nil, err
}
// Get failure reason if available.
var failureReason *FailureReason
b = bucket.Get(duplicatePaymentFailInfoKey)
if b != nil {
reason := FailureReason(b[0])
failureReason = &reason
}
payment := &MPPayment{
SequenceNum: sequenceNum,
Info: creationInfo,
FailureReason: failureReason,
Status: paymentStatus,
}
// Get the HTLCAttemptInfo. It can be absent.
b = bucket.Get(duplicatePaymentAttemptInfoKey)
if b != nil {
r = bytes.NewReader(b)
attempt, err := deserializeDuplicateHTLCAttemptInfo(r)
if err != nil {
return nil, err
}
htlc := HTLCAttempt{
HTLCAttemptInfo: HTLCAttemptInfo{
AttemptID: attempt.attemptID,
Route: attempt.route,
sessionKey: attempt.sessionKey,
},
}
// Get the payment preimage. This is only found for
// successful payments.
b = bucket.Get(duplicatePaymentSettleInfoKey)
if b != nil {
var preimg lntypes.Preimage
copy(preimg[:], b)
htlc.Settle = &HTLCSettleInfo{
Preimage: preimg,
SettleTime: time.Time{},
}
} else {
// Otherwise the payment must have failed.
htlc.Failure = &HTLCFailInfo{
FailTime: time.Time{},
}
}
payment.HTLCs = []HTLCAttempt{htlc}
}
return payment, nil
}
func fetchDuplicatePayments(paymentHashBucket kvdb.RBucket) ([]*MPPayment,
error) {
var payments []*MPPayment
// For older versions of lnd, duplicate payments to a payment hash were
// possible. These will be found in a sub-bucket indexed by their
// sequence number if available.
dup := paymentHashBucket.NestedReadBucket(duplicatePaymentsBucket)
if dup == nil {
return nil, nil
}
err := dup.ForEach(func(k, v []byte) error {
subBucket := dup.NestedReadBucket(k)
if subBucket == nil {
// We expect exactly one sub-bucket for each duplicate
// payment.
return fmt.Errorf("non-bucket element " +
"in duplicate bucket")
}
p, err := fetchDuplicatePayment(subBucket)
if err != nil {
return err
}
payments = append(payments, p)
return nil
})
if err != nil {
return nil, err
}
return payments, nil
}
package channeldb
import (
"bytes"
"io"
"sort"
"time"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// forwardingLogBucket is the bucket that we'll use to store the
// forwarding log. The forwarding log contains a time series database
// of the forwarding history of a lightning daemon. Each key within the
// bucket is a timestamp (in nanoseconds since the unix epoch), and
// the value is the set of serialized forwarding events for that
// timestamp.
forwardingLogBucket = []byte("circuit-fwd-log")
)
const (
// forwardingEventSize is the size of a forwarding event. The breakdown
// is as follows:
//
// * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in
// || 8 byte value out
//
// From the value in and value out, callers can easily compute the
// total fee extracted from a forwarding event.
forwardingEventSize = 32
// MaxResponseEvents is the max number of forwarding events that will
// be returned by a single query response. This size was selected to
// safely remain under gRPC's 4MiB message size response limit. As each
// full forwarding event (including the timestamp) is 40 bytes, we can
// safely return 50k entries in a single response.
MaxResponseEvents = 50000
)
// ForwardingLog returns an instance of the ForwardingLog object backed by the
// target database instance.
func (d *DB) ForwardingLog() *ForwardingLog {
return &ForwardingLog{
db: d,
}
}
// ForwardingLog is a time series database that logs the fulfillment of payment
// circuits by a lightning network daemon. The log contains a series of
// forwarding events which map a timestamp to a forwarding event. A forwarding
// event describes which channels were used to create+settle a circuit, and the
// amount involved. Subtracting the outgoing amount from the incoming amount
// reveals the fee charged for the forwarding service.
type ForwardingLog struct {
db *DB
}
// ForwardingEvent is an event in the forwarding log's time series. Each
// forwarding event logs the creation and tear-down of a payment circuit. A
// circuit is created once an incoming HTLC has been fully forwarded, and
// destroyed once the payment has been settled.
type ForwardingEvent struct {
// Timestamp is the settlement time of this payment circuit.
Timestamp time.Time
// IncomingChanID is the incoming channel ID of the payment circuit.
IncomingChanID lnwire.ShortChannelID
// OutgoingChanID is the outgoing channel ID of the payment circuit.
OutgoingChanID lnwire.ShortChannelID
// AmtIn is the amount of the incoming HTLC. Subtracting this from the
// outgoing amount gives the total fees of this payment circuit.
AmtIn lnwire.MilliSatoshi
// AmtOut is the amount of the outgoing HTLC. Subtracting the incoming
// amount from this gives the total fees for this payment circuit.
AmtOut lnwire.MilliSatoshi
}
// encodeForwardingEvent writes out the target forwarding event to the passed
// io.Writer, using the expected DB format. Note that the timestamp isn't
// serialized as this will be the key value within the bucket.
func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
return WriteElements(
w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut,
)
}
// decodeForwardingEvent attempts to decode the raw bytes of a serialized
// forwarding event into the target ForwardingEvent. Note that the timestamp
// won't be decoded, as the caller is expected to set this due to the bucket
// structure of the forwarding log.
func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error {
return ReadElements(
r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut,
)
}
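// Example (an illustrative sketch, not part of the original source): the two
// helpers form a symmetric roundtrip, with the timestamp carried out of band
// as the bucket key. Here event is assumed to be a populated ForwardingEvent.
//
//	var buf bytes.Buffer
//	if err := encodeForwardingEvent(&buf, &event); err != nil {
//		return err
//	}
//	var decoded ForwardingEvent
//	if err := decodeForwardingEvent(&buf, &decoded); err != nil {
//		return err
//	}
//	// Callers restore the timestamp from the bucket key themselves.
//	decoded.Timestamp = event.Timestamp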
// AddForwardingEvents adds a series of forwarding events to the database.
// Before inserting, the set of events will be sorted according to their
// timestamp. This ensures that all writes to disk are sequential.
func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error {
// Before we create the database transaction, we'll ensure that the set
// of forwarding events are properly sorted according to their
// timestamp and that no duplicate timestamps exist to avoid collisions
// in the key we are going to store the events under.
makeUniqueTimestamps(events)
var timestamp [8]byte
return kvdb.Batch(f.db.Backend, func(tx kvdb.RwTx) error {
// First, we'll fetch the bucket that stores our time series
// log.
logBucket, err := tx.CreateTopLevelBucket(
forwardingLogBucket,
)
if err != nil {
return err
}
// With the bucket obtained, we can now begin to write out the
// series of events.
for _, event := range events {
err := storeEvent(logBucket, event, timestamp[:])
if err != nil {
return err
}
}
return nil
})
}
// storeEvent tries to store a forwarding event into the given bucket by trying
// to avoid collisions. If a key for the event timestamp already exists in the
// database, the timestamp is incremented in nanosecond intervals until a "free"
// slot is found.
func storeEvent(bucket walletdb.ReadWriteBucket, event ForwardingEvent,
timestampScratchSpace []byte) error {
// First, we'll serialize this timestamp into our
// timestamp buffer.
byteOrder.PutUint64(
timestampScratchSpace, uint64(event.Timestamp.UnixNano()),
)
// Next we'll loop until we find a "free" slot in the bucket to store
// the event under. This should almost never happen unless we're running
// on a system that has a very bad system clock that doesn't properly
// resolve to nanosecond scale. We try up to 100 times (a maximum shift
// of 0.1 microseconds, which is acceptable for most use cases). If we
// don't find a free slot, we just give up and let
// the collision happen. Something must be wrong with the data in that
// case, even on a very fast machine forwarding payments _will_ take a
// few microseconds at least so we should find a nanosecond slot
// somewhere.
const maxTries = 100
tries := 0
for tries < maxTries {
val := bucket.Get(timestampScratchSpace)
if val == nil {
break
}
// Collision, try the next nanosecond timestamp.
nextNano := event.Timestamp.UnixNano() + 1
event.Timestamp = time.Unix(0, nextNano)
byteOrder.PutUint64(timestampScratchSpace, uint64(nextNano))
tries++
}
// With the key encoded, we'll then encode the event
// into our buffer, then write it out to disk.
var eventBytes [forwardingEventSize]byte
eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize])
err := encodeForwardingEvent(eventBuf, &event)
if err != nil {
return err
}
return bucket.Put(timestampScratchSpace, eventBuf.Bytes())
}
// ForwardingEventQuery represents a query to the forwarding log payment
// circuit time series database. The query allows a caller to retrieve all
// records for a particular time slice, offset in that time slice, limiting the
// total number of responses returned.
type ForwardingEventQuery struct {
// StartTime is the start time of the time slice.
StartTime time.Time
// EndTime is the end time of the time slice.
EndTime time.Time
// IndexOffset is the offset within the time slice to start at. This
// can be used to start the response at a particular record.
IndexOffset uint32
// NumMaxEvents is the max number of events to return.
NumMaxEvents uint32
}
// ForwardingLogTimeSlice is the response to a forwarding query. It includes
// the original query, the set events that match the query, and an integer
// which represents the offset index of the last item in the set of returned
// events. This integer allows callers to resume their query using this offset
// in the event that the query's response exceeds the max number of returnable
// events.
type ForwardingLogTimeSlice struct {
ForwardingEventQuery
// ForwardingEvents is the set of events in our time series that answer
// the query embedded above.
ForwardingEvents []ForwardingEvent
// LastIndexOffset is the index of the last element in the set of
// returned ForwardingEvents above. Callers can use this to resume
// their query in the event that the time slice has too many events to
// fit into a single response.
LastIndexOffset uint32
}
// Query allows a caller to query the forwarding event time series for a
// particular time slice. The caller can control the precise time as well as
// the number of events to be returned.
//
// TODO(roasbeef): rename?
func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) {
var resp ForwardingLogTimeSlice
// If the user provided an index offset, then we know how many records
// we'll need to skip. We'll also keep track of the record offset
// as that's part of the final return value.
recordsToSkip := q.IndexOffset
recordOffset := q.IndexOffset
err := kvdb.View(f.db, func(tx kvdb.RTx) error {
// If the bucket wasn't found, then there aren't any events to
// be returned.
logBucket := tx.ReadBucket(forwardingLogBucket)
if logBucket == nil {
return ErrNoForwardingEvents
}
// We'll be using a cursor to seek into the database, so we'll
// populate byte slices that represent the start of the key
// space we're interested in, and the end.
var startTime, endTime [8]byte
byteOrder.PutUint64(startTime[:], uint64(q.StartTime.UnixNano()))
byteOrder.PutUint64(endTime[:], uint64(q.EndTime.UnixNano()))
// If we know that a set of log events exists, then we'll begin
// our seek through the log in order to satisfy the query.
// We'll continue until either we reach the end of the range,
// or reach our max number of events.
logCursor := logBucket.ReadCursor()
timestamp, events := logCursor.Seek(startTime[:])
for ; timestamp != nil && bytes.Compare(timestamp, endTime[:]) <= 0; timestamp, events = logCursor.Next() {
// If our current return payload exceeds the max number
// of events, then we'll exit now.
if uint32(len(resp.ForwardingEvents)) >= q.NumMaxEvents {
return nil
}
// If we're not yet past the user defined offset, then
// we'll continue to seek forward.
if recordsToSkip > 0 {
recordsToSkip--
continue
}
currentTime := time.Unix(
0, int64(byteOrder.Uint64(timestamp)),
)
// At this point, we've skipped enough records to start
// to collate our query. For each record, we'll
// increment the final record offset so the querier can
// utilize pagination to seek further.
readBuf := bytes.NewReader(events)
for readBuf.Len() != 0 {
var event ForwardingEvent
err := decodeForwardingEvent(readBuf, &event)
if err != nil {
return err
}
event.Timestamp = currentTime
resp.ForwardingEvents = append(resp.ForwardingEvents, event)
recordOffset++
}
}
return nil
}, func() {
resp = ForwardingLogTimeSlice{
ForwardingEventQuery: q,
}
})
if err != nil && err != ErrNoForwardingEvents {
return ForwardingLogTimeSlice{}, err
}
resp.LastIndexOffset = recordOffset
return resp, nil
}
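// Example usage (an illustrative sketch, not part of the original source):
// paginating through a time slice by feeding LastIndexOffset back into the
// next query's IndexOffset.
//
//	query := ForwardingEventQuery{
//		StartTime:    start,
//		EndTime:      end,
//		NumMaxEvents: MaxResponseEvents,
//	}
//	for {
//		slice, err := fwdLog.Query(query)
//		if err != nil {
//			return err
//		}
//		// Process slice.ForwardingEvents here.
//		if uint32(len(slice.ForwardingEvents)) < query.NumMaxEvents {
//			break
//		}
//		query.IndexOffset = slice.LastIndexOffset
//	}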
// makeUniqueTimestamps takes a slice of forwarding events, sorts it by the
// event timestamps and then makes sure there are no duplicates in the
// timestamps. If duplicates are found, some of the timestamps are increased on
// the nanosecond scale until only unique values remain. This is a fix to
// address the problem that in some environments (looking at you, Windows) the
// system clock has such a bad resolution that two serial invocations of
// time.Now() might return the same timestamp, even if some time has elapsed
// between the calls.
func makeUniqueTimestamps(events []ForwardingEvent) {
sort.Slice(events, func(i, j int) bool {
return events[i].Timestamp.Before(events[j].Timestamp)
})
// Now that we know the events are sorted by timestamp, we can go
// through the list and fix all duplicates until only unique values
// remain.
for outer := 0; outer < len(events)-1; outer++ {
current := events[outer].Timestamp.UnixNano()
next := events[outer+1].Timestamp.UnixNano()
// We initially sorted the slice. So if the current is now
// greater or equal to the next one, it's either because it's a
// duplicate or because we increased the current in the last
// iteration.
if current >= next {
next = current + 1
events[outer+1].Timestamp = time.Unix(0, next)
}
}
}
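// For example, three events all carrying the identical timestamp T leave this
// function with timestamps T, T+1ns and T+2ns: the second event is bumped past
// the first, which in turn forces the third to be bumped on the following
// iteration.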
package channeldb
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
)
// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding
// package has potentially been mangled.
var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted")
// FwdState is an enum used to describe the lifecycle of a FwdPkg.
type FwdState byte
const (
// FwdStateLockedIn is the starting state for all forwarding packages.
// Packages in this state have not yet committed to the exact set of
// Adds to forward to the switch.
FwdStateLockedIn FwdState = iota
// FwdStateProcessed marks the state in which all Adds have been
// locally processed and the forwarding decision to the switch has been
// persisted.
FwdStateProcessed
// FwdStateCompleted signals that all Adds have been acked, and that all
// settles and fails have been delivered to their sources. Packages in
// this state can be removed permanently.
FwdStateCompleted
)
var (
// fwdPackagesKey is the root-level bucket that all forwarding packages
// are written. This bucket is further subdivided based on the short
// channel ID of each channel.
//
// Bucket hierarchy:
//
// fwdPackagesKey(root-bucket)
// |
// |-- <shortChannelID>
// | |
// | |-- <height>
// | | |-- ackFilterKey: <encoded bytes of PkgFilter>
// | | |-- settleFailFilterKey: <encoded bytes of PkgFilter>
// | | |-- fwdFilterKey: <encoded bytes of PkgFilter>
// | | |
// | | |-- addBucketKey
// | | | |-- <index of LogUpdate>: <encoded bytes of LogUpdate>
// | | | |-- <index of LogUpdate>: <encoded bytes of LogUpdate>
// | | | ...
// | | |
// | | |-- failSettleBucketKey
// | | |-- <index of LogUpdate>: <encoded bytes of LogUpdate>
// | | |-- <index of LogUpdate>: <encoded bytes of LogUpdate>
// | | ...
// | |
// | |-- <height>
// | | |
// | ... ...
// |
// |
// |-- <shortChannelID>
// | |
// | ...
// ...
//
fwdPackagesKey = []byte("fwd-packages")
// addBucketKey is the bucket to which all Add log updates are written.
addBucketKey = []byte("add-updates")
// failSettleBucketKey is the bucket to which all Settle/Fail log
// updates are written.
failSettleBucketKey = []byte("fail-settle-updates")
// fwdFilterKey is a key used to write the set of Adds that passed
// validation and are to be forwarded to the switch.
// NOTE: The presence of this key within a forwarding package indicates
// that the package has reached FwdStateProcessed.
fwdFilterKey = []byte("fwd-filter-key")
// ackFilterKey is a key used to access the PkgFilter indicating which
// Adds have received a Settle/Fail. This response may come from a
// number of sources, including: exitHop settle/fails, switch failures,
// chain arbiter interjections, as well as settle/fails from the
// next hop in the route.
ackFilterKey = []byte("ack-filter-key")
// settleFailFilterKey is a key used to access the PkgFilter indicating
// which Settles/Fails in the package have been received and processed
// by the link that originally received the Add.
settleFailFilterKey = []byte("settle-fail-filter-key")
)
// PkgFilter is used to compactly represent a particular subset of the Adds in a
// forwarding package. Each filter is represented as a simple, statically-sized
// bitvector, where the elements are intended to be the indices of the Adds as
// they are written in the FwdPkg.
type PkgFilter struct {
count uint16
filter []byte
}
// NewPkgFilter initializes an empty PkgFilter supporting `count` elements.
func NewPkgFilter(count uint16) *PkgFilter {
// We add 7 so that the integer division rounds up to the number of
// bytes required to hold `count` bits.
filterLen := (count + 7) / 8
return &PkgFilter{
count: count,
filter: make([]byte, filterLen),
}
}
// Count returns the number of elements represented by this PkgFilter.
func (f *PkgFilter) Count() uint16 {
return f.count
}
// Set marks the `i`-th element as included by this filter.
// NOTE: It is assumed that i is always less than count.
func (f *PkgFilter) Set(i uint16) {
byt := i / 8
bit := i % 8
// Set the i-th bit in the filter.
// TODO(conner): ignore if > count to prevent panic?
f.filter[byt] |= byte(1 << (7 - bit))
}
// Contains queries the filter for membership of index `i`.
// NOTE: It is assumed that i is always less than count.
func (f *PkgFilter) Contains(i uint16) bool {
byt := i / 8
bit := i % 8
// Read the i-th bit in the filter.
// TODO(conner): ignore if > count to prevent panic?
return f.filter[byt]&(1<<(7-bit)) != 0
}
// Equal checks two PkgFilters for equality.
func (f *PkgFilter) Equal(f2 *PkgFilter) bool {
if f == f2 {
return true
}
if f.count != f2.count {
return false
}
return bytes.Equal(f.filter, f2.filter)
}
// IsFull returns true if every element in the filter has been Set, and false
// otherwise.
func (f *PkgFilter) IsFull() bool {
// Batch validate bytes that are fully used.
for i := uint16(0); i < f.count/8; i++ {
if f.filter[i] != 0xFF {
return false
}
}
// If the count is not a multiple of 8, check that the filter contains
// all remaining bits.
rem := f.count % 8
for idx := f.count - rem; idx < f.count; idx++ {
if !f.Contains(idx) {
return false
}
}
return true
}
// Size returns number of bytes produced when the PkgFilter is serialized.
func (f *PkgFilter) Size() uint16 {
// 2 bytes for uint16 `count`, then round up number of bytes required to
// represent `count` bits.
return 2 + (f.count+7)/8
}
// Encode writes the filter to the provided io.Writer.
func (f *PkgFilter) Encode(w io.Writer) error {
if err := binary.Write(w, binary.BigEndian, f.count); err != nil {
return err
}
_, err := w.Write(f.filter)
return err
}
// Decode reads the filter from the provided io.Reader.
func (f *PkgFilter) Decode(r io.Reader) error {
if err := binary.Read(r, binary.BigEndian, &f.count); err != nil {
return err
}
f.filter = make([]byte, f.Size()-2)
_, err := io.ReadFull(r, f.filter)
return err
}
// String returns a human-readable string.
func (f *PkgFilter) String() string {
return fmt.Sprintf("count=%v, filter=%v", f.count, f.filter)
}
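// Example (an illustrative sketch, not part of the original source): a filter
// over three Adds only reports IsFull once every index has been set.
//
//	filter := NewPkgFilter(3)
//	filter.Set(0)
//	filter.Set(2)
//	full := filter.IsFull() // false, index 1 is still unset
//	filter.Set(1)
//	full = filter.IsFull() // true
//	_ = full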
// FwdPkg records all adds, settles, and fails that were locked in as a result
// of the remote peer sending us a revocation. Each package is identified by
// the short chanid and remote commitment height corresponding to the revocation
// that locked in the HTLCs. For everything except a locally initiated payment,
// settles and fails in a forwarding package must have a corresponding Add in
// another package, and can be removed individually once the source link has
// received the fail/settle.
//
// Adds cannot be removed, as we need to present the same batch of Adds to
// properly handle replay protection. Instead, we use a PkgFilter to mark that
// we have finished processing a particular Add. A FwdPkg should only be deleted
// after the AckFilter is full and all settles and fails have been persistently
// removed.
type FwdPkg struct {
// Source identifies the channel that wrote this forwarding package.
Source lnwire.ShortChannelID
// Height is the height of the remote commitment chain that locked in
// this forwarding package.
Height uint64
// State signals the persistent condition of the package and directs how
// to reprocess the package in the event of failures.
State FwdState
// Adds contains all add messages which need to be processed and
// forwarded to the switch. Adds does not change over the life of a
// forwarding package.
Adds []LogUpdate
// FwdFilter is a filter containing the indices of all Adds that were
// forwarded to the switch.
FwdFilter *PkgFilter
// AckFilter is a filter containing the indices of all Adds for which
// the source has received a settle or fail and is reflected in the next
// commitment txn. A package should not be removed until IsFull()
// returns true.
AckFilter *PkgFilter
// SettleFails contains all settle and fail messages that should be
// forwarded to the switch.
SettleFails []LogUpdate
// SettleFailFilter is a filter containing the indices of all Settle or
// Fails originating in this package that have been received and locked
// into the incoming link's commitment state.
SettleFailFilter *PkgFilter
}
// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This
// should be used to create a package at the time we receive a revocation.
func NewFwdPkg(source lnwire.ShortChannelID, height uint64,
addUpdates, settleFailUpdates []LogUpdate) *FwdPkg {
nAddUpdates := uint16(len(addUpdates))
nSettleFailUpdates := uint16(len(settleFailUpdates))
return &FwdPkg{
Source: source,
Height: height,
State: FwdStateLockedIn,
Adds: addUpdates,
FwdFilter: NewPkgFilter(nAddUpdates),
AckFilter: NewPkgFilter(nAddUpdates),
SettleFails: settleFailUpdates,
SettleFailFilter: NewPkgFilter(nSettleFailUpdates),
}
}
// SourceRef is a convenience method that returns an AddRef to this forwarding
// package for the index in the argument. It is the caller's responsibility
// to ensure that the index is in bounds.
func (f *FwdPkg) SourceRef(i uint16) AddRef {
return AddRef{
Height: f.Height,
Index: i,
}
}
// DestRef is a convenience method that returns a SettleFailRef to this
// forwarding package for the index in the argument. It is the caller's
// responsibility to ensure that the index is in bounds.
func (f *FwdPkg) DestRef(i uint16) SettleFailRef {
return SettleFailRef{
Source: f.Source,
Height: f.Height,
Index: i,
}
}
// ID returns a unique identifier for this package, used to ensure that sphinx
// replay processing of this batch is idempotent.
func (f *FwdPkg) ID() []byte {
var id = make([]byte, 16)
byteOrder.PutUint64(id[:8], f.Source.ToUint64())
byteOrder.PutUint64(id[8:], f.Height)
return id
}
// String returns a human-readable description of the forwarding package.
func (f *FwdPkg) String() string {
return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)",
f, f.Source, f.Height, len(f.Adds), len(f.SettleFails))
}
// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID
// is assumed to be that of the packager.
type AddRef struct {
// Height is the remote commitment height that locked in the Add.
Height uint64
// Index is the index of the Add within the fwd pkg's Adds.
//
// NOTE: This index is static over the lifetime of a forwarding package.
Index uint16
}
// Encode serializes the AddRef to the given io.Writer.
func (a *AddRef) Encode(w io.Writer) error {
if err := binary.Write(w, binary.BigEndian, a.Height); err != nil {
return err
}
return binary.Write(w, binary.BigEndian, a.Index)
}
// Decode deserializes the AddRef from the given io.Reader.
func (a *AddRef) Decode(r io.Reader) error {
if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil {
return err
}
return binary.Read(r, binary.BigEndian, &a.Index)
}
// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A
// channel does not remove its own Settle/Fail htlcs, so the source is provided
// to locate a db bucket belonging to another channel.
type SettleFailRef struct {
// Source identifies the outgoing link that locked in the settle or
// fail. This is then used by the *incoming* link to find the settle
// fail in another link's forwarding packages.
Source lnwire.ShortChannelID
// Height is the remote commitment height that locked in this
// Settle/Fail.
Height uint64
// Index is the index of the Settle/Fail within the fwd pkg's
// SettleFails.
//
// NOTE: This index is static over the lifetime of a forwarding package.
Index uint16
}
// SettleFailAcker is a generic interface providing the ability to acknowledge
// settle/fail HTLCs stored in forwarding packages.
type SettleFailAcker interface {
// AckSettleFails atomically updates the settle-fail filters in *other*
// channels' forwarding packages.
AckSettleFails(tx kvdb.RwTx, settleFailRefs ...SettleFailRef) error
}
// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages
// of any active channel.
type GlobalFwdPkgReader interface {
// LoadChannelFwdPkgs loads all known forwarding packages for the given
// channel.
LoadChannelFwdPkgs(tx kvdb.RTx,
source lnwire.ShortChannelID) ([]*FwdPkg, error)
}
// FwdOperator defines the interfaces for managing forwarding packages that are
// external to a particular channel. This interface is used by the switch to
// read forwarding packages from arbitrary channels, and acknowledge settles and
// fails for locally-sourced payments.
type FwdOperator interface {
// GlobalFwdPkgReader provides read access to all known forwarding
// packages
GlobalFwdPkgReader
// SettleFailAcker grants the ability to acknowledge settles or fails
// residing in arbitrary forwarding packages.
SettleFailAcker
}
// SwitchPackager is a concrete implementation of the FwdOperator interface.
// A SwitchPackager offers the ability to read any forwarding package, and ack
// arbitrary settle and fail HTLCs.
type SwitchPackager struct{}
// NewSwitchPackager instantiates a new SwitchPackager.
func NewSwitchPackager() *SwitchPackager {
return &SwitchPackager{}
}
// AckSettleFails atomically updates the settle-fail filters in *other*
// channels' forwarding packages, to mark that the switch has received a settle
// or fail residing in the forwarding package of a link.
func (*SwitchPackager) AckSettleFails(tx kvdb.RwTx,
settleFailRefs ...SettleFailRef) error {
return ackSettleFails(tx, settleFailRefs)
}
// LoadChannelFwdPkgs loads all forwarding packages for a particular channel.
func (*SwitchPackager) LoadChannelFwdPkgs(tx kvdb.RTx,
source lnwire.ShortChannelID) ([]*FwdPkg, error) {
return loadChannelFwdPkgs(tx, source)
}
// FwdPackager supports all operations required to modify fwd packages, such as
// creation, updates, reading, and removal. The interfaces are broken down in
// this way to support future delegation of the subinterfaces.
type FwdPackager interface {
// AddFwdPkg serializes and writes a FwdPkg for this channel at the
// remote commitment height included in the forwarding package.
AddFwdPkg(tx kvdb.RwTx, fwdPkg *FwdPkg) error
// SetFwdFilter looks up the forwarding package at the remote `height`
// and sets the `fwdFilter`, marking the Adds for which:
// 1) We are not the exit node
// 2) Passed all validation
// 3) Should be forwarded to the switch immediately after a failure
SetFwdFilter(tx kvdb.RwTx, height uint64, fwdFilter *PkgFilter) error
// AckAddHtlcs atomically updates the add filters in this channel's
// forwarding packages to mark the resolution of an Add that was
// received from the remote party.
AckAddHtlcs(tx kvdb.RwTx, addRefs ...AddRef) error
// SettleFailAcker allows a link to acknowledge settle/fail HTLCs
// belonging to other channels.
SettleFailAcker
// LoadFwdPkgs loads all known forwarding packages owned by this
// channel.
LoadFwdPkgs(tx kvdb.RTx) ([]*FwdPkg, error)
// RemovePkg deletes a forwarding package owned by this channel at
// the provided remote `height`.
RemovePkg(tx kvdb.RwTx, height uint64) error
// Wipe deletes all the forwarding packages owned by this channel.
Wipe(tx kvdb.RwTx) error
}
// ChannelPackager is used by a channel to manage the lifecycle of its forwarding
// packages. The packager is tied to a particular source channel ID, allowing it
// to create and edit its own packages. Each packager also has the ability to
// remove fail/settle htlcs that correspond to an add contained in one of
// source's packages.
type ChannelPackager struct {
source lnwire.ShortChannelID
}
// NewChannelPackager creates a new packager for a single channel.
func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager {
return &ChannelPackager{
source: source,
}
}
// AddFwdPkg writes a newly locked in forwarding package to disk.
func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *FwdPkg) error { // nolint: dupl
fwdPkgBkt, err := tx.CreateTopLevelBucket(fwdPackagesKey)
if err != nil {
return err
}
source := makeLogKey(fwdPkg.Source.ToUint64())
sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:])
if err != nil {
return err
}
heightKey := makeLogKey(fwdPkg.Height)
heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:])
if err != nil {
return err
}
// Write ADD updates we received at this commit height.
addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey)
if err != nil {
return err
}
// Write SETTLE/FAIL updates we received at this commit height.
failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey)
if err != nil {
return err
}
for i := range fwdPkg.Adds {
err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i])
if err != nil {
return err
}
}
// Persist the initialized pkg filter, which will be used to determine
// when we can remove this forwarding package from disk.
var ackFilterBuf bytes.Buffer
if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil {
return err
}
if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil {
return err
}
for i := range fwdPkg.SettleFails {
err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i])
if err != nil {
return err
}
}
var settleFailFilterBuf bytes.Buffer
err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf)
if err != nil {
return err
}
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
}
// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key.
func putLogUpdate(bkt kvdb.RwBucket, idx uint16, htlc *LogUpdate) error {
var b bytes.Buffer
if err := serializeLogUpdate(&b, htlc); err != nil {
return err
}
return bkt.Put(uint16Key(idx), b.Bytes())
}
// LoadFwdPkgs scans the forwarding log for any packages that haven't been
// processed, and returns their deserialized log updates in a map indexed by the
// remote commitment height at which the updates were locked in.
func (p *ChannelPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*FwdPkg, error) {
return loadChannelFwdPkgs(tx, p.source)
}
// loadChannelFwdPkgs loads all forwarding packages owned by `source`.
func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*FwdPkg, error) {
fwdPkgBkt := tx.ReadBucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return nil, nil
}
sourceKey := makeLogKey(source.ToUint64())
sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:])
if sourceBkt == nil {
return nil, nil
}
var heights []uint64
if err := sourceBkt.ForEach(func(k, _ []byte) error {
if len(k) != 8 {
return ErrCorruptedFwdPkg
}
heights = append(heights, byteOrder.Uint64(k))
return nil
}); err != nil {
return nil, err
}
// Load the forwarding package for each retrieved height.
fwdPkgs := make([]*FwdPkg, 0, len(heights))
for _, height := range heights {
fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height)
if err != nil {
return nil, err
}
fwdPkgs = append(fwdPkgs, fwdPkg)
}
return fwdPkgs, nil
}
// loadFwdPkg reads the packager's fwd pkg at a given height, and determines the
// appropriate FwdState.
func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID,
height uint64) (*FwdPkg, error) {
sourceKey := makeLogKey(source.ToUint64())
sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:])
if sourceBkt == nil {
return nil, ErrCorruptedFwdPkg
}
heightKey := makeLogKey(height)
heightBkt := sourceBkt.NestedReadBucket(heightKey[:])
if heightBkt == nil {
return nil, ErrCorruptedFwdPkg
}
// Load ADDs from disk.
addBkt := heightBkt.NestedReadBucket(addBucketKey)
if addBkt == nil {
return nil, ErrCorruptedFwdPkg
}
adds, err := loadHtlcs(addBkt)
if err != nil {
return nil, err
}
// Load ack filter from disk.
ackFilterBytes := heightBkt.Get(ackFilterKey)
if ackFilterBytes == nil {
return nil, ErrCorruptedFwdPkg
}
ackFilterReader := bytes.NewReader(ackFilterBytes)
ackFilter := &PkgFilter{}
if err := ackFilter.Decode(ackFilterReader); err != nil {
return nil, err
}
// Load SETTLE/FAILs from disk.
failSettleBkt := heightBkt.NestedReadBucket(failSettleBucketKey)
if failSettleBkt == nil {
return nil, ErrCorruptedFwdPkg
}
failSettles, err := loadHtlcs(failSettleBkt)
if err != nil {
return nil, err
}
// Load settle fail filter from disk.
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
if settleFailFilterBytes == nil {
return nil, ErrCorruptedFwdPkg
}
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
settleFailFilter := &PkgFilter{}
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
return nil, err
}
// Initialize the fwding package, which always starts in
// FwdStateLockedIn. We can determine what state the package was left in
// by examining constraints on the information loaded from disk.
fwdPkg := &FwdPkg{
Source: source,
State: FwdStateLockedIn,
Height: height,
Adds: adds,
AckFilter: ackFilter,
SettleFails: failSettles,
SettleFailFilter: settleFailFilter,
}
// Check to see if we have written the fwd filter, i.e. the set of Adds
// exported to the switch, to disk. If we haven't, processing of this
// package was never started, or failed during the last attempt.
fwdFilterBytes := heightBkt.Get(fwdFilterKey)
if fwdFilterBytes == nil {
nAdds := uint16(len(adds))
fwdPkg.FwdFilter = NewPkgFilter(nAdds)
return fwdPkg, nil
}
fwdFilterReader := bytes.NewReader(fwdFilterBytes)
fwdPkg.FwdFilter = &PkgFilter{}
if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil {
return nil, err
}
// Otherwise, a complete round of processing was completed, and we
// advance the package to FwdStateProcessed.
fwdPkg.State = FwdStateProcessed
// If every add, settle, and fail has been fully acknowledged, we can
// safely set the package's state to FwdStateCompleted, signalling that
// it can be garbage collected.
if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() {
fwdPkg.State = FwdStateCompleted
}
return fwdPkg, nil
}
// loadHtlcs retrieves all serialized htlcs in a bucket, returning
// them in order of the indexes they were written under.
func loadHtlcs(bkt kvdb.RBucket) ([]LogUpdate, error) {
var htlcs []LogUpdate
if err := bkt.ForEach(func(_, v []byte) error {
htlc, err := deserializeLogUpdate(bytes.NewReader(v))
if err != nil {
return err
}
htlcs = append(htlcs, *htlc)
return nil
}); err != nil {
return nil, err
}
return htlcs, nil
}
// SetFwdFilter writes the set of indexes corresponding to Adds at the
// `height` that are to be forwarded to the switch. Calling this method causes
// the forwarding package at `height` to be in FwdStateProcessed. We write this
// forwarding decision so that we always arrive at the same behavior for HTLCs
// leaving this channel. After a restart, we skip validation of these Adds,
// since they are assumed to have already been validated, and make the switch or
// outgoing link responsible for handling replays.
func (p *ChannelPackager) SetFwdFilter(tx kvdb.RwTx, height uint64,
fwdFilter *PkgFilter) error {
fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return ErrCorruptedFwdPkg
}
source := makeLogKey(p.source.ToUint64())
sourceBkt := fwdPkgBkt.NestedReadWriteBucket(source[:])
if sourceBkt == nil {
return ErrCorruptedFwdPkg
}
heightKey := makeLogKey(height)
heightBkt := sourceBkt.NestedReadWriteBucket(heightKey[:])
if heightBkt == nil {
return ErrCorruptedFwdPkg
}
// If the fwd filter has already been written, we return early to avoid
// modifying the persistent state.
forwardedAddsBytes := heightBkt.Get(fwdFilterKey)
if forwardedAddsBytes != nil {
return nil
}
// Otherwise we serialize and write the provided fwd filter.
var b bytes.Buffer
if err := fwdFilter.Encode(&b); err != nil {
return err
}
return heightBkt.Put(fwdFilterKey, b.Bytes())
}
// AckAddHtlcs accepts a list of references to add htlcs, and updates the
// AckFilter of those forwarding packages to indicate that a settle or fail
// has been received in response to the add.
func (p *ChannelPackager) AckAddHtlcs(tx kvdb.RwTx, addRefs ...AddRef) error {
if len(addRefs) == 0 {
return nil
}
fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return ErrCorruptedFwdPkg
}
sourceKey := makeLogKey(p.source.ToUint64())
sourceBkt := fwdPkgBkt.NestedReadWriteBucket(sourceKey[:])
if sourceBkt == nil {
return ErrCorruptedFwdPkg
}
// Organize the forward references such that we just get a single slice
// of indexes for each unique height.
heightDiffs := make(map[uint64][]uint16)
for _, addRef := range addRefs {
heightDiffs[addRef.Height] = append(
heightDiffs[addRef.Height],
addRef.Index,
)
}
// Load each height bucket once and remove all acked htlcs at that
// height.
for height, indexes := range heightDiffs {
err := ackAddHtlcsAtHeight(sourceBkt, height, indexes)
if err != nil {
return err
}
}
return nil
}
// ackAddHtlcsAtHeight updates the AckFilter of a single forwarding package
// with a list of indexes, writing the resulting filter back in its place.
func ackAddHtlcsAtHeight(sourceBkt kvdb.RwBucket, height uint64,
indexes []uint16) error {
heightKey := makeLogKey(height)
heightBkt := sourceBkt.NestedReadWriteBucket(heightKey[:])
if heightBkt == nil {
// If the height bucket isn't found, this could be because the
// forwarding package was already removed. We'll return nil to
// signal that the operation is successful, as there is nothing
// to ack.
return nil
}
// Load ack filter from disk.
ackFilterBytes := heightBkt.Get(ackFilterKey)
if ackFilterBytes == nil {
return ErrCorruptedFwdPkg
}
ackFilter := &PkgFilter{}
ackFilterReader := bytes.NewReader(ackFilterBytes)
if err := ackFilter.Decode(ackFilterReader); err != nil {
return err
}
// Update the ack filter for this height.
for _, index := range indexes {
ackFilter.Set(index)
}
// Write the resulting filter to disk.
var ackFilterBuf bytes.Buffer
if err := ackFilter.Encode(&ackFilterBuf); err != nil {
return err
}
return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes())
}
// AckSettleFails persistently acknowledges settles or fails from a remote
// forwarding package. This should only be called after the source of the Add
// has locked in the settle/fail, or it becomes otherwise safe to forgo
// retransmitting the settle/fail after a restart.
func (p *ChannelPackager) AckSettleFails(tx kvdb.RwTx,
settleFailRefs ...SettleFailRef) error {
return ackSettleFails(tx, settleFailRefs)
}
// ackSettleFails persistently acknowledges a batch of settle fail references.
func ackSettleFails(tx kvdb.RwTx, settleFailRefs []SettleFailRef) error {
if len(settleFailRefs) == 0 {
return nil
}
fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return ErrCorruptedFwdPkg
}
// Organize the forward references such that we just get a single slice
// of indexes for each unique destination-height pair.
destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16)
for _, settleFailRef := range settleFailRefs {
destHeights, ok := destHeightDiffs[settleFailRef.Source]
if !ok {
destHeights = make(map[uint64][]uint16)
destHeightDiffs[settleFailRef.Source] = destHeights
}
destHeights[settleFailRef.Height] = append(
destHeights[settleFailRef.Height],
settleFailRef.Index,
)
}
// With the references organized by destination and height, we now load
// each remote bucket, and update the settle fail filter for any
// settle/fail htlcs.
for dest, destHeights := range destHeightDiffs {
destKey := makeLogKey(dest.ToUint64())
destBkt := fwdPkgBkt.NestedReadWriteBucket(destKey[:])
if destBkt == nil {
// If the destination bucket is not found, this is
// likely the result of the destination channel being
// closed and having its forwarding packages wiped. We
// won't treat this as an error, because the response
// will no longer be retransmitted internally.
continue
}
for height, indexes := range destHeights {
err := ackSettleFailsAtHeight(destBkt, height, indexes)
if err != nil {
return err
}
}
}
return nil
}
// ackSettleFailsAtHeight, given a destination bucket, acks the provided
// indexes at a particular height by updating the settle fail filter.
func ackSettleFailsAtHeight(destBkt kvdb.RwBucket, height uint64,
indexes []uint16) error {
heightKey := makeLogKey(height)
heightBkt := destBkt.NestedReadWriteBucket(heightKey[:])
if heightBkt == nil {
// If the height bucket isn't found, this could be because the
// forwarding package was already removed. We'll return nil to
// signal that the operation is successful, as there is nothing to ack.
return nil
}
// Load the settle fail filter from disk.
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
if settleFailFilterBytes == nil {
return ErrCorruptedFwdPkg
}
settleFailFilter := &PkgFilter{}
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
return err
}
// Update the settle fail filter for this height.
for _, index := range indexes {
settleFailFilter.Set(index)
}
// Write the resulting filter to disk.
var settleFailFilterBuf bytes.Buffer
if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil {
return err
}
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
}
// RemovePkg deletes the forwarding package at the given height from the
// packager's source bucket.
func (p *ChannelPackager) RemovePkg(tx kvdb.RwTx, height uint64) error {
fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return nil
}
sourceBytes := makeLogKey(p.source.ToUint64())
sourceBkt := fwdPkgBkt.NestedReadWriteBucket(sourceBytes[:])
if sourceBkt == nil {
return ErrCorruptedFwdPkg
}
heightKey := makeLogKey(height)
return sourceBkt.DeleteNestedBucket(heightKey[:])
}
// Wipe deletes all the channel's forwarding packages, if any.
func (p *ChannelPackager) Wipe(tx kvdb.RwTx) error {
// If the root bucket doesn't exist, there's no need to delete.
fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return nil
}
sourceBytes := makeLogKey(p.source.ToUint64())
// If the nested bucket doesn't exist, there's no need to delete.
if fwdPkgBkt.NestedReadWriteBucket(sourceBytes[:]) == nil {
return nil
}
return fwdPkgBkt.DeleteNestedBucket(sourceBytes[:])
}
// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice.
func uint16Key(i uint16) []byte {
key := make([]byte, 2)
byteOrder.PutUint16(key, i)
return key
}
// Compile-time constraint to ensure that ChannelPackager implements the public
// FwdPackager interface.
var _ FwdPackager = (*ChannelPackager)(nil)
// Compile-time constraint to ensure that SwitchPackager implements the public
// FwdOperator interface.
var _ FwdOperator = (*SwitchPackager)(nil)
package channeldb
import (
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// initialChannelForwardingPolicyBucket is the database bucket used to
// store the forwarding policy for each permanent channel that is
// currently in the process of being opened.
initialChannelForwardingPolicyBucket = []byte(
"initialChannelFwdingPolicy",
)
)
// SaveInitialForwardingPolicy saves the serialized forwarding policy for the
// provided permanent channel id to the initialChannelForwardingPolicyBucket.
func (c *ChannelStateDB) SaveInitialForwardingPolicy(chanID lnwire.ChannelID,
forwardingPolicy *models.ForwardingPolicy) error {
chanIDCopy := make([]byte, 32)
copy(chanIDCopy, chanID[:])
scratch := make([]byte, 36)
byteOrder.PutUint64(scratch[:8], uint64(forwardingPolicy.MinHTLCOut))
byteOrder.PutUint64(scratch[8:16], uint64(forwardingPolicy.MaxHTLC))
byteOrder.PutUint64(scratch[16:24], uint64(forwardingPolicy.BaseFee))
byteOrder.PutUint64(scratch[24:32], uint64(forwardingPolicy.FeeRate))
byteOrder.PutUint32(scratch[32:], forwardingPolicy.TimeLockDelta)
return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(
initialChannelForwardingPolicyBucket,
)
if err != nil {
return err
}
return bucket.Put(chanIDCopy, scratch)
}, func() {})
}
// GetInitialForwardingPolicy fetches the serialized forwarding policy for the
// provided channel id from the database, or returns ErrChannelNotFound if
// a forwarding policy for this channel id is not found.
func (c *ChannelStateDB) GetInitialForwardingPolicy(
chanID lnwire.ChannelID) (*models.ForwardingPolicy, error) {
chanIDCopy := make([]byte, 32)
copy(chanIDCopy, chanID[:])
var forwardingPolicy *models.ForwardingPolicy
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(initialChannelForwardingPolicyBucket)
if bucket == nil {
// If the bucket does not exist, it means we
// never added any channel forwarding policies
// to the db, so return ErrChannelNotFound.
return ErrChannelNotFound
}
stateBytes := bucket.Get(chanIDCopy)
if stateBytes == nil {
return ErrChannelNotFound
}
forwardingPolicy = &models.ForwardingPolicy{
MinHTLCOut: lnwire.MilliSatoshi(
byteOrder.Uint64(stateBytes[:8]),
),
MaxHTLC: lnwire.MilliSatoshi(
byteOrder.Uint64(stateBytes[8:16]),
),
BaseFee: lnwire.MilliSatoshi(
byteOrder.Uint64(stateBytes[16:24]),
),
FeeRate: lnwire.MilliSatoshi(
byteOrder.Uint64(stateBytes[24:32]),
),
TimeLockDelta: byteOrder.Uint32(stateBytes[32:36]),
}
return nil
}, func() {
forwardingPolicy = nil
})
return forwardingPolicy, err
}
// DeleteInitialForwardingPolicy removes the forwarding policy for a given
// channel from the database.
func (c *ChannelStateDB) DeleteInitialForwardingPolicy(
chanID lnwire.ChannelID) error {
chanIDCopy := make([]byte, 32)
copy(chanIDCopy, chanID[:])
return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
bucket := tx.ReadWriteBucket(
initialChannelForwardingPolicyBucket,
)
if bucket == nil {
return ErrChannelNotFound
}
return bucket.Delete(chanIDCopy)
}, func() {})
}
package channeldb
import (
"bytes"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// spendHintBucket is the name of the bucket which houses the height
// hint for outpoints. Each height hint represents the earliest height
// at which its corresponding outpoint could have been spent within.
spendHintBucket = []byte("spend-hints")
// confirmHintBucket is the name of the bucket which houses the height
// hints for transactions. Each height hint represents the earliest
// height at which its corresponding transaction could have been
// confirmed within.
confirmHintBucket = []byte("confirm-hints")
)
// CacheConfig contains the HeightHintCache configuration.
type CacheConfig struct {
// QueryDisable prevents reliance on the Height Hint Cache. This is
// necessary to recover from an edge case when the height recorded in
// the cache is higher than the actual height of a spend, causing a
// channel to become "stuck" in a pending close state.
QueryDisable bool
}
// HeightHintCache is an implementation of the SpendHintCache and
// ConfirmHintCache interfaces backed by a channeldb DB instance where the hints
// will be stored.
type HeightHintCache struct {
cfg CacheConfig
db kvdb.Backend
}
// Compile-time checks to ensure HeightHintCache satisfies the SpendHintCache
// and ConfirmHintCache interfaces.
var _ chainntnfs.SpendHintCache = (*HeightHintCache)(nil)
var _ chainntnfs.ConfirmHintCache = (*HeightHintCache)(nil)
// NewHeightHintCache returns a new height hint cache backed by a database.
func NewHeightHintCache(cfg CacheConfig, db kvdb.Backend) (*HeightHintCache,
error) {
cache := &HeightHintCache{cfg, db}
if err := cache.initBuckets(); err != nil {
return nil, err
}
return cache, nil
}
// initBuckets ensures that the primary buckets used by the height hint
// cache are initialized so that we can assume their existence after startup.
func (c *HeightHintCache) initBuckets() error {
return kvdb.Batch(c.db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(spendHintBucket)
if err != nil {
return err
}
_, err = tx.CreateTopLevelBucket(confirmHintBucket)
return err
})
}
// CommitSpendHint commits a spend hint for the outpoints to the cache.
func (c *HeightHintCache) CommitSpendHint(height uint32,
spendRequests ...chainntnfs.SpendRequest) error {
if len(spendRequests) == 0 {
return nil
}
log.Tracef("Updating spend hint to height %d for %v", height,
spendRequests)
return kvdb.Batch(c.db, func(tx kvdb.RwTx) error {
spendHints := tx.ReadWriteBucket(spendHintBucket)
if spendHints == nil {
return chainntnfs.ErrCorruptedHeightHintCache
}
var hint bytes.Buffer
if err := WriteElement(&hint, height); err != nil {
return err
}
for _, spendRequest := range spendRequests {
spendRequest := spendRequest
spendHintKey, err := spendHintKey(&spendRequest)
if err != nil {
return err
}
err = spendHints.Put(spendHintKey, hint.Bytes())
if err != nil {
return err
}
}
return nil
})
}
// QuerySpendHint returns the latest spend hint for an outpoint.
// ErrSpendHintNotFound is returned if a spend hint does not exist within the
// cache for the outpoint.
func (c *HeightHintCache) QuerySpendHint(
spendRequest chainntnfs.SpendRequest) (uint32, error) {
var hint uint32
if c.cfg.QueryDisable {
log.Debugf("Ignoring spend height hint for %v (height hint "+
"cache query disabled)", spendRequest)
return 0, nil
}
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
spendHints := tx.ReadBucket(spendHintBucket)
if spendHints == nil {
return chainntnfs.ErrCorruptedHeightHintCache
}
spendHintKey, err := spendHintKey(&spendRequest)
if err != nil {
return err
}
spendHint := spendHints.Get(spendHintKey)
if spendHint == nil {
return chainntnfs.ErrSpendHintNotFound
}
return ReadElement(bytes.NewReader(spendHint), &hint)
}, func() {
hint = 0
})
if err != nil {
return 0, err
}
return hint, nil
}
// PurgeSpendHint removes the spend hint for the outpoints from the cache.
func (c *HeightHintCache) PurgeSpendHint(
spendRequests ...chainntnfs.SpendRequest) error {
if len(spendRequests) == 0 {
return nil
}
log.Tracef("Removing spend hints for %v", spendRequests)
return kvdb.Batch(c.db, func(tx kvdb.RwTx) error {
spendHints := tx.ReadWriteBucket(spendHintBucket)
if spendHints == nil {
return chainntnfs.ErrCorruptedHeightHintCache
}
for _, spendRequest := range spendRequests {
spendRequest := spendRequest
spendHintKey, err := spendHintKey(&spendRequest)
if err != nil {
return err
}
if err := spendHints.Delete(spendHintKey); err != nil {
return err
}
}
return nil
})
}
// CommitConfirmHint commits a confirm hint for the transactions to the cache.
func (c *HeightHintCache) CommitConfirmHint(height uint32,
confRequests ...chainntnfs.ConfRequest) error {
if len(confRequests) == 0 {
return nil
}
log.Tracef("Updating confirm hints to height %d for %v", height,
confRequests)
return kvdb.Batch(c.db, func(tx kvdb.RwTx) error {
confirmHints := tx.ReadWriteBucket(confirmHintBucket)
if confirmHints == nil {
return chainntnfs.ErrCorruptedHeightHintCache
}
var hint bytes.Buffer
if err := WriteElement(&hint, height); err != nil {
return err
}
for _, confRequest := range confRequests {
confRequest := confRequest
confHintKey, err := confHintKey(&confRequest)
if err != nil {
return err
}
err = confirmHints.Put(confHintKey, hint.Bytes())
if err != nil {
return err
}
}
return nil
})
}
// QueryConfirmHint returns the latest confirm hint for a transaction hash.
// ErrConfirmHintNotFound is returned if a confirm hint does not exist within
// the cache for the transaction hash.
func (c *HeightHintCache) QueryConfirmHint(
confRequest chainntnfs.ConfRequest) (uint32, error) {
var hint uint32
if c.cfg.QueryDisable {
log.Debugf("Ignoring confirmation height hint for %v (height "+
"hint cache query disabled)", confRequest)
return 0, nil
}
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
confirmHints := tx.ReadBucket(confirmHintBucket)
if confirmHints == nil {
return chainntnfs.ErrCorruptedHeightHintCache
}
confHintKey, err := confHintKey(&confRequest)
if err != nil {
return err
}
confirmHint := confirmHints.Get(confHintKey)
if confirmHint == nil {
return chainntnfs.ErrConfirmHintNotFound
}
return ReadElement(bytes.NewReader(confirmHint), &hint)
}, func() {
hint = 0
})
if err != nil {
return 0, err
}
return hint, nil
}
// PurgeConfirmHint removes the confirm hint for the transactions from the
// cache.
func (c *HeightHintCache) PurgeConfirmHint(
confRequests ...chainntnfs.ConfRequest) error {
if len(confRequests) == 0 {
return nil
}
log.Tracef("Removing confirm hints for %v", confRequests)
return kvdb.Batch(c.db, func(tx kvdb.RwTx) error {
confirmHints := tx.ReadWriteBucket(confirmHintBucket)
if confirmHints == nil {
return chainntnfs.ErrCorruptedHeightHintCache
}
for _, confRequest := range confRequests {
confRequest := confRequest
confHintKey, err := confHintKey(&confRequest)
if err != nil {
return err
}
if err := confirmHints.Delete(confHintKey); err != nil {
return err
}
}
return nil
})
}
// confHintKey returns the key that will be used to index the confirmation
// request's hint within the height hint cache.
func confHintKey(r *chainntnfs.ConfRequest) ([]byte, error) {
if r.TxID == chainntnfs.ZeroHash {
return r.PkScript.Script(), nil
}
var txid bytes.Buffer
if err := WriteElement(&txid, r.TxID); err != nil {
return nil, err
}
return txid.Bytes(), nil
}
// spendHintKey returns the key that will be used to index the spend request's
// hint within the height hint cache.
func spendHintKey(r *chainntnfs.SpendRequest) ([]byte, error) {
if r.OutPoint == chainntnfs.ZeroOutPoint {
return r.PkScript.Script(), nil
}
var outpoint bytes.Buffer
err := WriteElement(&outpoint, r.OutPoint)
if err != nil {
return nil, err
}
return outpoint.Bytes(), nil
}
package channeldb
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"maps"
"time"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
invpkg "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// invoiceBucket is the name of the bucket within the database that
// stores all data related to invoices no matter their final state.
// Within the invoice bucket, each invoice is keyed by its invoice ID
// which is a monotonically increasing uint32.
invoiceBucket = []byte("invoices")
// invoiceIndexBucket is the name of the sub-bucket within the
// invoiceBucket which indexes all invoices by their payment hash. The
// payment hash is the sha256 of the invoice's payment preimage. This
// index is used to detect duplicates, and also to provide a fast path
// for looking up incoming HTLCs to determine if we're able to settle
// them fully.
//
// maps: payHash => invoiceKey
invoiceIndexBucket = []byte("paymenthashes")
// payAddrIndexBucket is the name of the top-level bucket that maps
// payment addresses to their invoice number. This can be used
// to efficiently query or update non-legacy invoices. Note that legacy
// invoices will not be included in this index since they all have the
// same, all-zero payment address, however all newly generated invoices
// will end up in this index.
//
// maps: payAddr => invoiceKey
payAddrIndexBucket = []byte("pay-addr-index")
// setIDIndexBucket is the name of the top-level bucket that maps set
// ids to their invoice number. This can be used to efficiently query or
// update AMP invoices. Note that legacy or MPP invoices will not be
// included in this index, since their HTLCs do not have a set id.
//
// maps: setID => invoiceKey
setIDIndexBucket = []byte("set-id-index")
// numInvoicesKey is the name of the key which houses the auto-incrementing
// invoice ID which is essentially used as a primary key. With each
// invoice inserted, the primary key is incremented by one. This key is
// stored within the invoiceIndexBucket. Within the invoiceBucket
// invoices are uniquely identified by the invoice ID.
numInvoicesKey = []byte("nik")
// addIndexBucket is an index bucket that we'll use to create a
// monotonically increasing set of add indexes. Each time we add a new
// invoice, this sequence number will be incremented and then populated
// within the new invoice.
//
// In addition to this sequence number, we map:
//
// addIndexNo => invoiceKey
addIndexBucket = []byte("invoice-add-index")
// settleIndexBucket is an index bucket that we'll use to create a
// monotonically increasing integer for tracking a "settle index". Each
// time an invoice is settled, this sequence number will be incremented
// and populated within the newly settled invoice.
//
// In addition to this sequence number, we map:
//
// settleIndexNo => invoiceKey
settleIndexBucket = []byte("invoice-settle-index")
// invoiceBucketTombstone is a special key that indicates the invoice
// bucket has been permanently closed. Its purpose is to prevent the
// invoice bucket from being reopened in the future. A key use case for
// the tombstone is to ensure users cannot switch back to the KV invoice
// database after migrating to the native SQL database.
invoiceBucketTombstone = []byte("invoice-tombstone")
)
const (
// A set of tlv type definitions used to serialize invoice htlcs to the
// database.
//
// NOTE: A migration should be added whenever this list changes. This
// protects against the database being rolled back to an older
// format where the surrounding logic might assume a different set of
// fields are known.
chanIDType tlv.Type = 1
htlcIDType tlv.Type = 3
amtType tlv.Type = 5
acceptHeightType tlv.Type = 7
acceptTimeType tlv.Type = 9
resolveTimeType tlv.Type = 11
expiryHeightType tlv.Type = 13
htlcStateType tlv.Type = 15
mppTotalAmtType tlv.Type = 17
htlcAMPType tlv.Type = 19
htlcHashType tlv.Type = 21
htlcPreimageType tlv.Type = 23
// A set of tlv type definitions used to serialize invoice bodies.
//
// NOTE: A migration should be added whenever this list changes. This
// protects against the database being rolled back to an older
// format where the surrounding logic might assume a different set of
// fields are known.
memoType tlv.Type = 0
payReqType tlv.Type = 1
createTimeType tlv.Type = 2
settleTimeType tlv.Type = 3
addIndexType tlv.Type = 4
settleIndexType tlv.Type = 5
preimageType tlv.Type = 6
valueType tlv.Type = 7
cltvDeltaType tlv.Type = 8
expiryType tlv.Type = 9
paymentAddrType tlv.Type = 10
featuresType tlv.Type = 11
invStateType tlv.Type = 12
amtPaidType tlv.Type = 13
hodlInvoiceType tlv.Type = 14
invoiceAmpStateType tlv.Type = 15
// A set of tlv type definitions used to serialize the invoice AMP
// state along-side the main invoice body.
ampStateSetIDType tlv.Type = 0
ampStateHtlcStateType tlv.Type = 1
ampStateSettleIndexType tlv.Type = 2
ampStateSettleDateType tlv.Type = 3
ampStateCircuitKeysType tlv.Type = 4
ampStateAmtPaidType tlv.Type = 5
// invoiceProgressLogInterval is the interval we use to limit the
// logging output of invoice processing.
invoiceProgressLogInterval = 30 * time.Second
)
// AddInvoice inserts the targeted invoice into the database. If the invoice has
// *any* payment hashes which already exist within the database, then the
// insertion will be aborted and rejected due to the strict policy banning any
// duplicate payment hashes. A side effect of this function is that it sets
// AddIndex on newInvoice.
func (d *DB) AddInvoice(_ context.Context, newInvoice *invpkg.Invoice,
paymentHash lntypes.Hash) (uint64, error) {
if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
return 0, err
}
var invoiceAddIndex uint64
err := kvdb.Update(d, func(tx kvdb.RwTx) error {
invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
if err != nil {
return err
}
invoiceIndex, err := invoices.CreateBucketIfNotExists(
invoiceIndexBucket,
)
if err != nil {
return err
}
addIndex, err := invoices.CreateBucketIfNotExists(
addIndexBucket,
)
if err != nil {
return err
}
// Ensure that an invoice with an identical payment hash doesn't
// already exist within the index.
if invoiceIndex.Get(paymentHash[:]) != nil {
return invpkg.ErrDuplicateInvoice
}
// Check that we aren't inserting an invoice with a duplicate
// payment address. The all-zeros payment address is
// special-cased to support legacy keysend invoices which don't
// assign one. This is safe since later we also will avoid
// indexing them and avoid collisions.
payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
paymentAddr := newInvoice.Terms.PaymentAddr[:]
if payAddrIndex.Get(paymentAddr) != nil {
return invpkg.ErrDuplicatePayAddr
}
}
// If the current running payment ID counter hasn't yet been
// created, then create it now.
var invoiceNum uint32
invoiceCounter := invoiceIndex.Get(numInvoicesKey)
if invoiceCounter == nil {
var scratch [4]byte
byteOrder.PutUint32(scratch[:], invoiceNum)
err := invoiceIndex.Put(numInvoicesKey, scratch[:])
if err != nil {
return err
}
} else {
invoiceNum = byteOrder.Uint32(invoiceCounter)
}
newIndex, err := putInvoice(
invoices, invoiceIndex, payAddrIndex, addIndex,
newInvoice, invoiceNum, paymentHash,
)
if err != nil {
return err
}
invoiceAddIndex = newIndex
return nil
}, func() {
invoiceAddIndex = 0
})
if err != nil {
return 0, err
}
return invoiceAddIndex, err
}
// InvoicesAddedSince can be used by callers to seek into the event time series
// of all the invoices added in the database. The specified sinceAddIndex
// should be the highest add index that the caller knows of. This method will
// return all invoices with an add index greater than the specified
// sinceAddIndex.
//
// NOTE: The index starts from 1. As a result, we enforce that specifying a
// value below the starting index value is a noop.
func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
[]invpkg.Invoice, error) {
var (
newInvoices []invpkg.Invoice
start = time.Now()
lastLogTime = time.Now()
processedCount int
)
// If an index of zero was specified, then in order to maintain
// backwards compat, we won't send out any new invoices.
if sinceAddIndex == 0 {
return newInvoices, nil
}
var startIndex [8]byte
byteOrder.PutUint64(startIndex[:], sinceAddIndex)
err := kvdb.View(d, func(tx kvdb.RTx) error {
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return nil
}
addIndex := invoices.NestedReadBucket(addIndexBucket)
if addIndex == nil {
return nil
}
// We'll now run through each entry in the add index starting
// at our starting index. We'll continue until we reach the
// very end of the current key space.
invoiceCursor := addIndex.ReadCursor()
// We'll seek to the starting index, then manually advance the
// cursor in order to skip the entry with the since add index.
invoiceCursor.Seek(startIndex[:])
addSeqNo, invoiceKey := invoiceCursor.Next()
for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(
invoiceKey, invoices, nil, false,
)
if err != nil {
return err
}
newInvoices = append(newInvoices, invoice)
processedCount++
if time.Since(lastLogTime) >=
invoiceProgressLogInterval {
log.Debugf("Processed %d invoices which "+
"were added since add index %v",
processedCount, sinceAddIndex)
lastLogTime = time.Now()
}
}
return nil
}, func() {
newInvoices = nil
})
if err != nil {
return nil, err
}
elapsed := time.Since(start)
log.Debugf("Completed scanning for invoices added since index %v: "+
"total_processed=%d, found_invoices=%d, elapsed=%v",
sinceAddIndex, processedCount, len(newInvoices),
elapsed.Round(time.Millisecond))
return newInvoices, nil
}
// LookupInvoice attempts to look up an invoice according to its 32 byte
// payment hash. If an invoice which can settle the HTLC identified by the
// passed payment hash isn't found, then an error is returned. Otherwise, the
// full invoice is returned. Before settling the incoming HTLC, the values
// SHOULD be checked to ensure the payer meets the agreed upon contractual
// terms of the payment.
func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
invpkg.Invoice, error) {
var invoice invpkg.Invoice
err := kvdb.View(d, func(tx kvdb.RTx) error {
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return invpkg.ErrNoInvoicesCreated
}
invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
if invoiceIndex == nil {
return invpkg.ErrNoInvoicesCreated
}
payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
setIDIndex := tx.ReadBucket(setIDIndexBucket)
// Retrieve the invoice number for this invoice using
// the provided invoice reference.
invoiceNum, err := fetchInvoiceNumByRef(
invoiceIndex, payAddrIndex, setIDIndex, ref,
)
if err != nil {
return err
}
var setID *invpkg.SetID
switch {
// If this is a payment address ref, and the blank modifier was
// specified, then we'll use the zero set ID to indicate that
// we don't want any HTLCs returned.
case ref.PayAddr() != nil &&
ref.Modifier() == invpkg.HtlcSetBlankModifier:
var zeroSetID invpkg.SetID
setID = &zeroSetID
// If this is a set ID ref, and the htlc set only modifier was
// specified, then we'll pass through the specified setID so
// that only its HTLCs will be returned.
case ref.SetID() != nil &&
ref.Modifier() == invpkg.HtlcSetOnlyModifier:
setID = (*invpkg.SetID)(ref.SetID())
}
// An invoice was found, retrieve the remainder of the invoice
// body.
i, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{setID}, true,
)
if err != nil {
return err
}
invoice = i
return nil
}, func() {})
if err != nil {
return invoice, err
}
return invoice, nil
}
// fetchInvoiceNumByRef retrieves the invoice number for the provided invoice
// reference. The payment address will be treated as the primary key, falling
// back to the payment hash if nothing is found for the payment address. An
// error is returned if the invoice is not found.
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex, setIDIndex kvdb.RBucket,
ref invpkg.InvoiceRef) ([]byte, error) {
// If the set id is present, we only consult the set id index for this
// invoice. This type of query is only used to facilitate user-facing
// requests to lookup, settle or cancel an AMP invoice.
setID := ref.SetID()
if setID != nil {
invoiceNumBySetID := setIDIndex.Get(setID[:])
if invoiceNumBySetID == nil {
return nil, invpkg.ErrInvoiceNotFound
}
return invoiceNumBySetID, nil
}
payHash := ref.PayHash()
payAddr := ref.PayAddr()
getInvoiceNumByHash := func() []byte {
if payHash != nil {
return invoiceIndex.Get(payHash[:])
}
return nil
}
getInvoiceNumByAddr := func() []byte {
if payAddr != nil {
// Only allow lookups for payment address if it is not a
// blank payment address, which is a special-cased value
// for legacy keysend invoices.
if *payAddr != invpkg.BlankPayAddr {
return payAddrIndex.Get(payAddr[:])
}
}
return nil
}
invoiceNumByHash := getInvoiceNumByHash()
invoiceNumByAddr := getInvoiceNumByAddr()
switch {
// If payment address and payment hash both reference an existing
// invoice, ensure they reference the _same_ invoice.
case invoiceNumByAddr != nil && invoiceNumByHash != nil:
if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
return nil, invpkg.ErrInvRefEquivocation
}
return invoiceNumByAddr, nil
// Return invoices by payment addr only.
//
// NOTE: We constrain this lookup to only apply if the invoice ref does
// not contain a payment hash. Legacy and MPP payments depend on the
// payment hash index to enforce that the HTLCs payment hash matches the
// payment hash for the invoice, without this check we would
// inadvertently assume the invoice contains the correct preimage for
// the HTLC, which we only enforce via the lookup by the invoice index.
case invoiceNumByAddr != nil && payHash == nil:
return invoiceNumByAddr, nil
// If we were only able to reference the invoice by hash, return the
// corresponding invoice number. This can happen when no payment address
// was provided, or if it didn't match anything in our records.
case invoiceNumByHash != nil:
return invoiceNumByHash, nil
// Otherwise we don't know of the target invoice.
default:
return nil, invpkg.ErrInvoiceNotFound
}
}
// FetchPendingInvoices returns all invoices that have not yet been settled or
// canceled. The returned map is keyed by the payment hash of each respective
// invoice.
func (d *DB) FetchPendingInvoices(_ context.Context) (
map[lntypes.Hash]invpkg.Invoice, error) {
result := make(map[lntypes.Hash]invpkg.Invoice)
err := kvdb.View(d, func(tx kvdb.RTx) error {
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return nil
}
invoiceIndex := invoices.NestedReadBucket(invoiceIndexBucket)
if invoiceIndex == nil {
// Mask the error if there's no invoice
// index as that simply means there are no
// invoices added yet to the DB. In this case
// we simply return an empty list.
return nil
}
return invoiceIndex.ForEach(func(k, v []byte) error {
// Skip the special numInvoicesKey as that does not
// point to a valid invoice.
if bytes.Equal(k, numInvoicesKey) {
return nil
}
// Skip sub-buckets.
if v == nil {
return nil
}
invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil {
return err
}
if invoice.IsPending() {
var paymentHash lntypes.Hash
copy(paymentHash[:], k)
result[paymentHash] = invoice
}
return nil
})
}, func() {
result = make(map[lntypes.Hash]invpkg.Invoice)
})
if err != nil {
return nil, err
}
return result, nil
}
// QueryInvoices allows a caller to query the invoice database for invoices
// within the specified add index range.
func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
invpkg.InvoiceSlice, error) {
var resp invpkg.InvoiceSlice
err := kvdb.View(d, func(tx kvdb.RTx) error {
// If the bucket wasn't found, then there aren't any invoices
// within the database yet, so we can simply exit.
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return invpkg.ErrNoInvoicesCreated
}
// Get the add index bucket which we will use to iterate through
// our indexed invoices.
invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket)
if invoiceAddIndex == nil {
return invpkg.ErrNoInvoicesCreated
}
// Create a paginator which reads from our add index bucket with
// the parameters provided by the invoice query.
paginator := newPaginator(
invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset,
q.NumMaxInvoices,
)
// accumulateInvoices looks up an invoice based on the index we
// are given, adds it to our set of invoices if it has the right
// characteristics for our query, and returns whether it was
// added.
accumulateInvoices := func(_, indexValue []byte) (bool, error) {
invoice, err := fetchInvoice(
indexValue, invoices, nil, false,
)
if err != nil {
return false, err
}
// Skip any settled or canceled invoices if the caller
// is only interested in pending ones.
if q.PendingOnly && !invoice.IsPending() {
return false, nil
}
// Get the creation time in Unix seconds, this always
// rounds down the nanoseconds to full seconds.
createTime := invoice.CreationDate.Unix()
// Skip any invoices that were created before the
// specified time.
if createTime < q.CreationDateStart {
return false, nil
}
// Skip any invoices that were created after the
// specified time.
if q.CreationDateEnd != 0 &&
createTime > q.CreationDateEnd {
return false, nil
}
// At this point, we've exhausted the offset, so we'll
// begin collecting invoices found within the range.
resp.Invoices = append(resp.Invoices, invoice)
return true, nil
}
// Query our paginator using accumulateInvoices to build up a
// set of invoices.
if err := paginator.query(accumulateInvoices); err != nil {
return err
}
// If we iterated through the add index in reverse order, then
// we'll need to reverse the slice of invoices to return them in
// forward order.
if q.Reversed {
numInvoices := len(resp.Invoices)
for i := 0; i < numInvoices/2; i++ {
reverse := numInvoices - i - 1
resp.Invoices[i], resp.Invoices[reverse] =
resp.Invoices[reverse], resp.Invoices[i]
}
}
return nil
}, func() {
resp = invpkg.InvoiceSlice{
InvoiceQuery: q,
}
})
if err != nil && !errors.Is(err, invpkg.ErrNoInvoicesCreated) {
return resp, err
}
// Finally, record the indexes of the first and last invoices returned
// so that the caller can resume from this point later on.
if len(resp.Invoices) > 0 {
resp.FirstIndexOffset = resp.Invoices[0].AddIndex
lastIdx := len(resp.Invoices) - 1
resp.LastIndexOffset = resp.Invoices[lastIdx].AddIndex
}
return resp, nil
}
// UpdateInvoice attempts to update an invoice corresponding to the passed
// payment hash. If an invoice matching the passed payment hash doesn't exist
// within the database, then the action will fail with a "not found" error.
//
// The update is performed inside the same database transaction that fetches the
// invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback. When updating an invoice, the update itself happens
// in-memory on a copy of the invoice. Once it is written successfully to the
// database, the in-memory copy is returned to the caller.
func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
setIDHint *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
*invpkg.Invoice, error) {
var updatedInvoice *invpkg.Invoice
err := kvdb.Update(d, func(tx kvdb.RwTx) error {
invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
if err != nil {
return err
}
invoiceIndex, err := invoices.CreateBucketIfNotExists(
invoiceIndexBucket,
)
if err != nil {
return err
}
settleIndex, err := invoices.CreateBucketIfNotExists(
settleIndexBucket,
)
if err != nil {
return err
}
payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
setIDIndex := tx.ReadWriteBucket(setIDIndexBucket)
// Retrieve the invoice number for this invoice using the
// provided invoice reference.
invoiceNum, err := fetchInvoiceNumByRef(
invoiceIndex, payAddrIndex, setIDIndex, ref,
)
if err != nil {
return err
}
// setIDHint can also be nil here, which means all the HTLCs
// for AMP invoices are fetched. If the blank setID is passed
// in, then no HTLCs are fetched for the AMP invoice. If a
// specific setID is passed in, then only the HTLCs for that
// setID are fetched for a particular sub-AMP invoice.
invoice, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{setIDHint}, false,
)
if err != nil {
return err
}
now := d.clock.Now()
updater := &kvInvoiceUpdater{
db: d,
invoicesBucket: invoices,
settleIndexBucket: settleIndex,
setIDIndexBucket: setIDIndex,
updateTime: now,
invoiceNum: invoiceNum,
invoice: &invoice,
updatedAmpHtlcs: make(ampHTLCsMap),
settledSetIDs: make(map[invpkg.SetID]struct{}),
}
payHash := ref.PayHash()
updatedInvoice, err = invpkg.UpdateInvoice(
payHash, updater.invoice, now, callback, updater,
)
if err != nil {
return err
}
// If this is an AMP update, then limit the returned AMP state
// to only the requested set ID.
if setIDHint != nil {
filterInvoiceAMPState(updatedInvoice, setIDHint)
}
return nil
}, func() {
updatedInvoice = nil
})
return updatedInvoice, err
}
// filterInvoiceAMPState filters the AMP state of the invoice to only include
// state for the specified set IDs.
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
filteredAMPState := make(invpkg.AMPInvoiceState)
for _, setID := range setIDs {
if setID == nil {
return
}
ampState, ok := invoice.AMPState[*setID]
if ok {
filteredAMPState[*setID] = ampState
}
}
invoice.AMPState = filteredAMPState
}
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
// kvInvoiceUpdater is an implementation of the InvoiceUpdater interface that
// is used with the kv implementation of the invoice database. Note that this
// updater is not concurrency safe and synchronization is expected to be
// handled
// on the DB level.
type kvInvoiceUpdater struct {
db *DB
invoicesBucket kvdb.RwBucket
settleIndexBucket kvdb.RwBucket
setIDIndexBucket kvdb.RwBucket
// updateTime is the timestamp for the update.
updateTime time.Time
// invoiceNum is a legacy key similar to the add index that is used
// only in the kv implementation.
invoiceNum []byte
// invoice is the invoice that we're updating. As a side effect of the
// update this invoice will be mutated.
invoice *invpkg.Invoice
// updatedAmpHtlcs holds the set of AMP HTLCs that were added or
// cancelled as part of this update.
updatedAmpHtlcs ampHTLCsMap
// settledSetIDs holds the set IDs that are settled with this update.
settledSetIDs map[invpkg.SetID]struct{}
}
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
func (k *kvInvoiceUpdater) AddHtlc(_ models.CircuitKey,
_ *invpkg.InvoiceHTLC) error {
return nil
}
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
func (k *kvInvoiceUpdater) ResolveHtlc(_ models.CircuitKey, _ invpkg.HtlcState,
_ time.Time) error {
return nil
}
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
func (k *kvInvoiceUpdater) AddAmpHtlcPreimage(_ [32]byte, _ models.CircuitKey,
_ lntypes.Preimage) error {
return nil
}
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
func (k *kvInvoiceUpdater) UpdateInvoiceState(_ invpkg.ContractState,
_ *lntypes.Preimage) error {
return nil
}
// NOTE: this method does nothing in the k/v implementation of InvoiceUpdater.
func (k *kvInvoiceUpdater) UpdateInvoiceAmtPaid(_ lnwire.MilliSatoshi) error {
return nil
}
// UpdateAmpState updates the state of the AMP invoice identified by the setID.
func (k *kvInvoiceUpdater) UpdateAmpState(setID [32]byte,
state invpkg.InvoiceStateAMP, circuitKey models.CircuitKey) error {
if _, ok := k.updatedAmpHtlcs[setID]; !ok {
switch state.State {
case invpkg.HtlcStateAccepted:
// If we're just now creating the HTLCs for this set
// then we'll also pull in the existing HTLCs that are
// part of this set, so we can write them all to disk
// together (same value)
k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
&setID, invpkg.HtlcStateAccepted,
)
case invpkg.HtlcStateCanceled:
// Only HTLCs in the accepted state can be canceled, but
// we also want to merge those with HTLCs that have
// already been canceled, since the HTLCs of a set can be
// canceled one by one.
k.updatedAmpHtlcs[setID] = k.invoice.HTLCSet(
&setID, invpkg.HtlcStateAccepted,
)
cancelledHtlcs := k.invoice.HTLCSet(
&setID, invpkg.HtlcStateCanceled,
)
maps.Copy(k.updatedAmpHtlcs[setID], cancelledHtlcs)
case invpkg.HtlcStateSettled:
k.updatedAmpHtlcs[setID] = make(
map[models.CircuitKey]*invpkg.InvoiceHTLC,
)
}
}
if state.State == invpkg.HtlcStateSettled {
// Add the set ID to the set that was settled in this invoice
// update. We'll use this later to update the settle index.
k.settledSetIDs[setID] = struct{}{}
}
k.updatedAmpHtlcs[setID][circuitKey] = k.invoice.Htlcs[circuitKey]
return nil
}
// Finalize finalizes the update before it is written to the database.
func (k *kvInvoiceUpdater) Finalize(updateType invpkg.UpdateType) error {
switch updateType {
case invpkg.AddHTLCsUpdate:
return k.storeAddHtlcsUpdate()
case invpkg.CancelHTLCsUpdate:
return k.storeCancelHtlcsUpdate()
case invpkg.SettleHodlInvoiceUpdate:
return k.storeSettleHodlInvoiceUpdate()
case invpkg.CancelInvoiceUpdate:
// Persist all changes which were made when cancelling the
// invoice. All HTLCs which were accepted are now canceled, so
// we persist this state.
return k.storeCancelHtlcsUpdate()
}
return fmt.Errorf("unknown update type: %v", updateType)
}
// storeCancelHtlcsUpdate updates the invoice in the database after cancelling a
// set of HTLCs.
func (k *kvInvoiceUpdater) storeCancelHtlcsUpdate() error {
err := k.serializeAndStoreInvoice()
if err != nil {
return err
}
// If this is an AMP invoice, then we'll actually store the rest
// of the HTLCs in-line with the invoice, using the invoice ID
// as a prefix, and the AMP key as a suffix: invoiceNum ||
// setID.
if k.invoice.IsAMP() {
return k.updateAMPInvoices()
}
return nil
}
// storeAddHtlcsUpdate updates the invoice in the database after adding a set of
// HTLCs.
func (k *kvInvoiceUpdater) storeAddHtlcsUpdate() error {
invoiceIsAMP := k.invoice.IsAMP()
for htlcSetID := range k.updatedAmpHtlcs {
// Check if this SetID already exists.
setIDInvNum := k.setIDIndexBucket.Get(htlcSetID[:])
if setIDInvNum == nil {
err := k.setIDIndexBucket.Put(
htlcSetID[:], k.invoiceNum,
)
if err != nil {
return err
}
} else if !bytes.Equal(setIDInvNum, k.invoiceNum) {
return invpkg.ErrDuplicateSetID{
SetID: htlcSetID,
}
}
}
// If this is a non-AMP invoice, then the state can eventually go to
// ContractSettled, so we pass in nil value as part of
// setSettleMetaFields.
if !invoiceIsAMP && k.invoice.State == invpkg.ContractSettled {
err := k.setSettleMetaFields(nil)
if err != nil {
return err
}
}
// As we don't update the settle index above for AMP invoices, we'll do
// it here for each sub-AMP invoice that was settled.
for settledSetID := range k.settledSetIDs {
settledSetID := settledSetID
err := k.setSettleMetaFields(&settledSetID)
if err != nil {
return err
}
}
err := k.serializeAndStoreInvoice()
if err != nil {
return err
}
// If this is an AMP invoice, then we'll actually store the rest of the
// HTLCs in-line with the invoice, using the invoice ID as a prefix,
// and the AMP key as a suffix: invoiceNum || setID.
if invoiceIsAMP {
return k.updateAMPInvoices()
}
return nil
}
// storeSettleHodlInvoiceUpdate updates the invoice in the database after
// settling a hodl invoice.
func (k *kvInvoiceUpdater) storeSettleHodlInvoiceUpdate() error {
err := k.setSettleMetaFields(nil)
if err != nil {
return err
}
return k.serializeAndStoreInvoice()
}
// setSettleMetaFields updates the metadata associated with settlement of an
// invoice. If a non-nil setID is passed in, then the value will be appended
// to the invoice number as well, in order to allow us to detect repeated
// payments to the same AMP invoice "across time".
func (k *kvInvoiceUpdater) setSettleMetaFields(setID *invpkg.SetID) error {
// Now that we know the invoice hasn't already been settled, we'll
// update the settle index so we can place this settle event in the
// proper location within our time series.
nextSettleSeqNo, err := k.settleIndexBucket.NextSequence()
if err != nil {
return err
}
// Make a new byte array on the stack that can potentially store the
// 4-byte invoice number along with the 32-byte set ID. We capture
// valueLen, the number of bytes copied, so that we only store the 4-byte
// invoice number if this is a non-AMP invoice.
var indexKey [invoiceSetIDKeyLen]byte
valueLen := copy(indexKey[:], k.invoiceNum)
if setID != nil {
valueLen += copy(indexKey[valueLen:], setID[:])
}
var seqNoBytes [8]byte
byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
err = k.settleIndexBucket.Put(seqNoBytes[:], indexKey[:valueLen])
if err != nil {
return err
}
// If the setID is nil, then this means that this is a non-AMP settle,
// so we'll update the invoice settle index directly.
if setID == nil {
k.invoice.SettleDate = k.updateTime
k.invoice.SettleIndex = nextSettleSeqNo
} else {
// If the set ID isn't blank, we'll update the AMP state map
// which tracks when each of the setIDs associated with a given
// AMP invoice are settled.
ampState := k.invoice.AMPState[*setID]
ampState.SettleDate = k.updateTime
ampState.SettleIndex = nextSettleSeqNo
k.invoice.AMPState[*setID] = ampState
}
return nil
}
// updateAMPInvoices updates the set of AMP invoices in-place. For AMP,
// rather than continually writing the invoices to the end of the invoice
// value, we instead write the invoices into a new key prefix that follows
// the main invoice number. This ensures that we don't need to continually
// decode a potentially massive HTLC set, and also allows us to quickly find
// the HTLCs associated with a particular HTLC set.
func (k *kvInvoiceUpdater) updateAMPInvoices() error {
for setID, htlcSet := range k.updatedAmpHtlcs {
// First write out the set of HTLCs including all the relevant
// TLV values.
var b bytes.Buffer
if err := serializeHtlcs(&b, htlcSet); err != nil {
return err
}
// Next store each HTLC in-line, using a prefix based off the
// invoice number.
invoiceSetIDKey := makeInvoiceSetIDKey(k.invoiceNum, setID[:])
err := k.invoicesBucket.Put(invoiceSetIDKey[:], b.Bytes())
if err != nil {
return err
}
}
return nil
}
// serializeAndStoreInvoice is a helper function used to store invoices.
func (k *kvInvoiceUpdater) serializeAndStoreInvoice() error {
var buf bytes.Buffer
if err := serializeInvoice(&buf, k.invoice); err != nil {
return err
}
return k.invoicesBucket.Put(k.invoiceNum, buf.Bytes())
}
// InvoicesSettledSince can be used by callers to catch up any settled invoices
// they missed within the settled invoice time series. We'll return all known
// settled invoice that have a settle index higher than the passed
// sinceSettleIndex.
//
// NOTE: The index starts from 1. As a result, we enforce that specifying a
// value below the starting index value is a noop.
func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
[]invpkg.Invoice, error) {
var (
settledInvoices []invpkg.Invoice
start = time.Now()
lastLogTime = time.Now()
processedCount int
)
// If an index of zero was specified, then in order to maintain
// backwards compat, we won't send out any new invoices.
if sinceSettleIndex == 0 {
return settledInvoices, nil
}
var startIndex [8]byte
byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
err := kvdb.View(d, func(tx kvdb.RTx) error {
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return nil
}
settleIndex := invoices.NestedReadBucket(settleIndexBucket)
if settleIndex == nil {
return nil
}
// We'll now run through each entry in the settle index starting
// at our starting index. We'll continue until we reach the
// very end of the current key space.
invoiceCursor := settleIndex.ReadCursor()
// We'll seek to the starting index, then manually advance the
// cursor in order to skip the entry with the since settle index.
invoiceCursor.Seek(startIndex[:])
seqNo, indexValue := invoiceCursor.Next()
for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, indexValue = invoiceCursor.Next() {
// Depending on the length of the index value, this may
// or may not be an AMP invoice, so we'll extract the
// invoice value into two components: the invoice num,
// and the setID (may not be there).
var (
invoiceKey [4]byte
setID *invpkg.SetID
)
valueLen := copy(invoiceKey[:], indexValue)
if len(indexValue) == invoiceSetIDKeyLen {
setID = new(invpkg.SetID)
copy(setID[:], indexValue[valueLen:])
}
// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(
invoiceKey[:], invoices, []*invpkg.SetID{setID},
true,
)
if err != nil {
return err
}
settledInvoices = append(settledInvoices, invoice)
processedCount++
if time.Since(lastLogTime) >=
invoiceProgressLogInterval {
log.Debugf("Processed %d settled invoices "+
"which have a settle index greater "+
"than %v", processedCount,
sinceSettleIndex)
lastLogTime = time.Now()
}
}
return nil
}, func() {
settledInvoices = nil
})
if err != nil {
return nil, err
}
elapsed := time.Since(start)
log.Debugf("Completed scanning for settled invoices starting at "+
"index %v: total_processed=%d, found_invoices=%d, elapsed=%v",
sinceSettleIndex, processedCount, len(settledInvoices),
elapsed.Round(time.Millisecond))
return settledInvoices, nil
}
func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
i *invpkg.Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
uint64, error) {
// Create the invoice key which is just the big-endian representation
// of the invoice number.
var invoiceKey [4]byte
byteOrder.PutUint32(invoiceKey[:], invoiceNum)
// Increment the num invoice counter index so the next invoice bears
// the proper ID.
var scratch [4]byte
invoiceCounter := invoiceNum + 1
byteOrder.PutUint32(scratch[:], invoiceCounter)
if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
return 0, err
}
// Add the payment hash to the invoice index. This will let us quickly
// identify if we can settle an incoming payment, and may also allow a
// single invoice to receive multiple payment installments.
err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
if err != nil {
return 0, err
}
// Add the invoice to the payment address index, but only if the invoice
// has a non-zero payment address. The all-zero payment address is still
// in use by legacy keysend, so we special-case here to avoid
// collisions.
if i.Terms.PaymentAddr != invpkg.BlankPayAddr {
err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
if err != nil {
return 0, err
}
}
// Next, we'll obtain the next add invoice index (sequence
// number), so we can properly place this invoice within this
// event stream.
nextAddSeqNo, err := addIndex.NextSequence()
if err != nil {
return 0, err
}
// With the next sequence obtained, we'll update the event series in
// the add index bucket to map this current add counter to the index of
// this new invoice.
var seqNoBytes [8]byte
byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
return 0, err
}
i.AddIndex = nextAddSeqNo
// Finally, serialize the invoice itself to be written to the disk.
var buf bytes.Buffer
if err := serializeInvoice(&buf, i); err != nil {
return 0, err
}
if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
return 0, err
}
return nextAddSeqNo, nil
}
// ampRecordSize returns the number of bytes this TLV record will occupy
// when encoded.
func ampRecordSize(a *invpkg.AMPInvoiceState) func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
// We know that encoding works since the tests pass in the build this
// file is checked into, so we'll simplify things and encode the state
// ourselves, then report the total number of bytes used.
if err := ampStateEncoder(&b, a, &buf); err != nil {
// This should never error out, but we log it just in case it
// does.
log.Errorf("encoding the amp invoice state failed: %v", err)
}
return func() uint64 {
return uint64(len(b.Bytes()))
}
}
// serializeInvoice serializes an invoice to a writer.
//
// Note: this function is in use for a migration. Before making changes that
// would modify the on disk format, make a copy of the original code and store
// it with the migration.
func serializeInvoice(w io.Writer, i *invpkg.Invoice) error {
creationDateBytes, err := i.CreationDate.MarshalBinary()
if err != nil {
return err
}
settleDateBytes, err := i.SettleDate.MarshalBinary()
if err != nil {
return err
}
var fb bytes.Buffer
err = i.Terms.Features.EncodeBase256(&fb)
if err != nil {
return err
}
featureBytes := fb.Bytes()
preimage := [32]byte(invpkg.UnknownPreimage)
if i.Terms.PaymentPreimage != nil {
preimage = *i.Terms.PaymentPreimage
if preimage == invpkg.UnknownPreimage {
return errors.New("cannot use all-zeroes preimage")
}
}
value := uint64(i.Terms.Value)
cltvDelta := uint32(i.Terms.FinalCltvDelta)
expiry := uint64(i.Terms.Expiry)
amtPaid := uint64(i.AmtPaid)
state := uint8(i.State)
var hodlInvoice uint8
if i.HodlInvoice {
hodlInvoice = 1
}
tlvStream, err := tlv.NewStream(
// Memo and payreq.
tlv.MakePrimitiveRecord(memoType, &i.Memo),
tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
// Add/settle metadata.
tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
// Terms.
tlv.MakePrimitiveRecord(preimageType, &preimage),
tlv.MakePrimitiveRecord(valueType, &value),
tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
tlv.MakePrimitiveRecord(expiryType, &expiry),
tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
tlv.MakePrimitiveRecord(featuresType, &featureBytes),
// Invoice state.
tlv.MakePrimitiveRecord(invStateType, &state),
tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
// Invoice AMP state.
tlv.MakeDynamicRecord(
invoiceAmpStateType, &i.AMPState,
ampRecordSize(&i.AMPState),
ampStateEncoder, ampStateDecoder,
),
)
if err != nil {
return err
}
var b bytes.Buffer
if err = tlvStream.Encode(&b); err != nil {
return err
}
err = binary.Write(w, byteOrder, uint64(b.Len()))
if err != nil {
return err
}
if _, err = w.Write(b.Bytes()); err != nil {
return err
}
// Only if this is a _non_ AMP invoice do we serialize the HTLCs
// in-line with the rest of the invoice.
if i.IsAMP() {
return nil
}
return serializeHtlcs(w, i.Htlcs)
}
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
// a writer.
func serializeHtlcs(w io.Writer,
htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) error {
for key, htlc := range htlcs {
// Encode the htlc in a tlv stream.
chanID := key.ChanID.ToUint64()
amt := uint64(htlc.Amt)
mppTotalAmt := uint64(htlc.MppTotalAmt)
acceptTime := putNanoTime(htlc.AcceptTime)
resolveTime := putNanoTime(htlc.ResolveTime)
state := uint8(htlc.State)
var records []tlv.Record
records = append(records,
tlv.MakePrimitiveRecord(chanIDType, &chanID),
tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
tlv.MakePrimitiveRecord(amtType, &amt),
tlv.MakePrimitiveRecord(
acceptHeightType, &htlc.AcceptHeight,
),
tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(htlcStateType, &state),
tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
)
if htlc.AMP != nil {
setIDRecord := tlv.MakeDynamicRecord(
htlcAMPType, &htlc.AMP.Record,
htlc.AMP.Record.PayloadSize,
record.AMPEncoder, record.AMPDecoder,
)
records = append(records, setIDRecord)
hash32 := [32]byte(htlc.AMP.Hash)
hashRecord := tlv.MakePrimitiveRecord(
htlcHashType, &hash32,
)
records = append(records, hashRecord)
if htlc.AMP.Preimage != nil {
preimage32 := [32]byte(*htlc.AMP.Preimage)
preimageRecord := tlv.MakePrimitiveRecord(
htlcPreimageType, &preimage32,
)
records = append(records, preimageRecord)
}
}
// Convert the custom records to tlv.Record types that are ready
// for serialization.
customRecords := tlv.MapToRecords(htlc.CustomRecords)
// Append the custom records. Their ids are in the experimental
// range and sorted, so there is no need to sort again.
records = append(records, customRecords...)
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
var b bytes.Buffer
if err := tlvStream.Encode(&b); err != nil {
return err
}
// Write the length of the tlv stream followed by the stream
// bytes.
err = binary.Write(w, byteOrder, uint64(b.Len()))
if err != nil {
return err
}
if _, err := w.Write(b.Bytes()); err != nil {
return err
}
}
return nil
}
// putNanoTime returns the unix nano time for the passed timestamp. A zero-value
// timestamp will be mapped to 0, since calling UnixNano in that case is
// undefined.
func putNanoTime(t time.Time) uint64 {
if t.IsZero() {
return 0
}
return uint64(t.UnixNano())
}
// getNanoTime returns a timestamp for the given number of nano seconds. If zero
// is provided, an zero-value time stamp is returned.
func getNanoTime(ns uint64) time.Time {
if ns == 0 {
return time.Time{}
}
return time.Unix(0, int64(ns))
}
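// exampleNanoTimeRoundTrip is an illustrative sketch (not part of the
// original code): putNanoTime and getNanoTime round-trip a timestamp at
// nanosecond precision, while mapping the zero value to 0 so that "unset"
// survives (de)serialization.
func exampleNanoTimeRoundTrip() {
	now := time.Now()

	// A real timestamp survives the round trip.
	same := getNanoTime(putNanoTime(now))
	_ = same.Equal(now) // true

	// The zero value maps to 0 and back to the zero value.
	unset := getNanoTime(putNanoTime(time.Time{}))
	_ = unset.IsZero() // true
}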
// fetchFilteredAmpInvoices retrieves only a select set of AMP invoices
// identified by the setID value.
func fetchFilteredAmpInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
error) {
htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
for _, setID := range setIDs {
invoiceSetIDKey := makeInvoiceSetIDKey(invoiceNum, setID[:])
htlcSetBytes := invoiceBucket.Get(invoiceSetIDKey[:])
if htlcSetBytes == nil {
// A set ID was passed in, but we don't have this
// stored yet, meaning that the setID is being added
// for the first time.
return htlcs, invpkg.ErrInvoiceNotFound
}
htlcSetReader := bytes.NewReader(htlcSetBytes)
htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
if err != nil {
return nil, err
}
maps.Copy(htlcs, htlcsBySetID)
}
return htlcs, nil
}
// forEachAMPInvoice is a helper function that attempts to iterate over each of
// the HTLC sets (based on their set ID) for the given AMP invoice identified
// by its invoiceNum. The callback closure is called for each key within the
// prefix range.
func forEachAMPInvoice(invoiceBucket kvdb.RBucket, invoiceNum []byte,
callback func(key, htlcSet []byte) error) error {
invoiceCursor := invoiceBucket.ReadCursor()
// Seek to the first key that includes the invoice data itself.
invoiceCursor.Seek(invoiceNum)
// Advance to the very first key _after_ the invoice data, as this is
// where we'll encounter our first HTLC (if any are present).
cursorKey, htlcSet := invoiceCursor.Next()
// If at this point the cursor key doesn't match the invoice num
// prefix, then we know that this invoice doesn't have any set ID HTLCs
// associated with it.
if !bytes.HasPrefix(cursorKey, invoiceNum) {
return nil
}
// Otherwise continue to iterate until we no longer match the prefix,
// executing the call back at each step.
for ; cursorKey != nil && bytes.HasPrefix(cursorKey, invoiceNum); cursorKey, htlcSet = invoiceCursor.Next() {
err := callback(cursorKey, htlcSet)
if err != nil {
return err
}
}
return nil
}
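// exampleCountHTLCSets is an illustrative sketch (not part of the original
// code): since forEachAMPInvoice visits every HTLC set stored under the
// invoiceNum prefix, counting the stored set IDs only requires incrementing
// a counter per visit.
func exampleCountHTLCSets(invoiceBucket kvdb.RBucket,
	invoiceNum []byte) (int, error) {

	var numSets int
	err := forEachAMPInvoice(invoiceBucket, invoiceNum,
		func(key, htlcSet []byte) error {
			// Each visited key has the form invoiceNum || setID.
			numSets++
			return nil
		},
	)

	return numSets, err
}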
// fetchAmpSubInvoices attempts to use the invoiceNum as a prefix within the
// AMP bucket to find all the individual HTLCs (by setID) associated with a
// given invoice. If a list of set IDs are specified, then only HTLCs
// associated with that setID will be retrieved.
func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
setIDs ...*invpkg.SetID) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
error) {
// If a set of setIDs was specified, then we can skip the cursor and
// just read out exactly what we need.
if len(setIDs) != 0 && setIDs[0] != nil {
return fetchFilteredAmpInvoices(
invoiceBucket, invoiceNum, setIDs...,
)
}
// Otherwise, iterate over all the htlc sets that are prefixed beside
// this invoice in the main invoice bucket.
htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
err := forEachAMPInvoice(invoiceBucket, invoiceNum,
func(key, htlcSet []byte) error {
htlcSetReader := bytes.NewReader(htlcSet)
htlcsBySetID, err := deserializeHtlcs(htlcSetReader)
if err != nil {
return err
}
maps.Copy(htlcs, htlcsBySetID)
return nil
},
)
if err != nil {
return nil, err
}
return htlcs, nil
}
// fetchInvoice attempts to read out the relevant state for the invoice as
// specified by the invoice number. If the setID fields are set, then only the
// HTLC information pertaining to those set IDs is returned.
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
invoiceBytes := invoices.Get(invoiceNum)
if invoiceBytes == nil {
return invpkg.Invoice{}, invpkg.ErrInvoiceNotFound
}
invoiceReader := bytes.NewReader(invoiceBytes)
invoice, err := deserializeInvoice(invoiceReader)
if err != nil {
return invpkg.Invoice{}, err
}
// If this is an AMP invoice we'll also attempt to read out the set of
// HTLCs that were paid to prior set IDs, if needed.
if !invoice.IsAMP() {
return invoice, nil
}
if shouldFetchAMPHTLCs(invoice, setIDs) {
invoice.Htlcs, err = fetchAmpSubInvoices(
invoices, invoiceNum, setIDs...,
)
// TODO(positiveblue): we should fail when we are not able to
// fetch all the HTLCs for an AMP invoice. Multiple tests in
// the invoice and channeldb package break if we return this
// error. We need to update them when we migrate this logic to
// the sql implementation.
if err != nil {
log.Errorf("unable to fetch amp htlcs for inv "+
"%v and setIDs %v: %w", invoiceNum, setIDs, err)
}
if filterAMPState {
filterInvoiceAMPState(&invoice, setIDs...)
}
}
return invoice, nil
}
// shouldFetchAMPHTLCs returns true if we need to fetch the set of HTLCs that
// were paid to the relevant set IDs.
func shouldFetchAMPHTLCs(invoice invpkg.Invoice, setIDs []*invpkg.SetID) bool {
// For AMP invoices that already have HTLCs populated (created before
// recurring invoices were supported), we don't need to read from the
// prefix keyed section of the bucket.
if len(invoice.Htlcs) != 0 {
return false
}
// If the "zero" setID was specified, then this means that no HTLC data
// should be returned alongside of it.
if len(setIDs) != 0 && setIDs[0] != nil &&
*setIDs[0] == invpkg.BlankPayAddr {
return false
}
return true
}
// fetchInvoiceStateAMP retrieves the state of all the relevant sub-invoices
// for an AMP invoice. This method only decodes the relevant state rather
// than the entire invoice.
func fetchInvoiceStateAMP(invoiceNum []byte,
invoices kvdb.RBucket) (invpkg.AMPInvoiceState, error) {
// Fetch the raw invoice bytes.
invoiceBytes := invoices.Get(invoiceNum)
if invoiceBytes == nil {
return nil, invpkg.ErrInvoiceNotFound
}
r := bytes.NewReader(invoiceBytes)
var bodyLen int64
err := binary.Read(r, byteOrder, &bodyLen)
if err != nil {
return nil, err
}
// Next, we'll make a new TLV stream that only attempts to decode the
// bytes we actually need.
ampState := make(invpkg.AMPInvoiceState)
tlvStream, err := tlv.NewStream(
// Invoice AMP state.
tlv.MakeDynamicRecord(
invoiceAmpStateType, &ampState, nil,
ampStateEncoder, ampStateDecoder,
),
)
if err != nil {
return nil, err
}
invoiceReader := io.LimitReader(r, bodyLen)
if err = tlvStream.Decode(invoiceReader); err != nil {
return nil, err
}
return ampState, nil
}
func deserializeInvoice(r io.Reader) (invpkg.Invoice, error) {
var (
preimageBytes [32]byte
value uint64
cltvDelta uint32
expiry uint64
amtPaid uint64
state uint8
hodlInvoice uint8
creationDateBytes []byte
settleDateBytes []byte
featureBytes []byte
)
var i invpkg.Invoice
i.AMPState = make(invpkg.AMPInvoiceState)
tlvStream, err := tlv.NewStream(
// Memo and payreq.
tlv.MakePrimitiveRecord(memoType, &i.Memo),
tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
// Add/settle metadata.
tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
// Terms.
tlv.MakePrimitiveRecord(preimageType, &preimageBytes),
tlv.MakePrimitiveRecord(valueType, &value),
tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
tlv.MakePrimitiveRecord(expiryType, &expiry),
tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
tlv.MakePrimitiveRecord(featuresType, &featureBytes),
// Invoice state.
tlv.MakePrimitiveRecord(invStateType, &state),
tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
// Invoice AMP state.
tlv.MakeDynamicRecord(
invoiceAmpStateType, &i.AMPState, nil,
ampStateEncoder, ampStateDecoder,
),
)
if err != nil {
return i, err
}
var bodyLen int64
err = binary.Read(r, byteOrder, &bodyLen)
if err != nil {
return i, err
}
lr := io.LimitReader(r, bodyLen)
if err = tlvStream.Decode(lr); err != nil {
return i, err
}
preimage := lntypes.Preimage(preimageBytes)
if preimage != invpkg.UnknownPreimage {
i.Terms.PaymentPreimage = &preimage
}
i.Terms.Value = lnwire.MilliSatoshi(value)
i.Terms.FinalCltvDelta = int32(cltvDelta)
i.Terms.Expiry = time.Duration(expiry)
i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
i.State = invpkg.ContractState(state)
if hodlInvoice != 0 {
i.HodlInvoice = true
}
err = i.CreationDate.UnmarshalBinary(creationDateBytes)
if err != nil {
return i, err
}
err = i.SettleDate.UnmarshalBinary(settleDateBytes)
if err != nil {
return i, err
}
rawFeatures := lnwire.NewRawFeatureVector()
err = rawFeatures.DecodeBase256(
bytes.NewReader(featureBytes), len(featureBytes),
)
if err != nil {
return i, err
}
i.Terms.Features = lnwire.NewFeatureVector(
rawFeatures, lnwire.Features,
)
i.Htlcs, err = deserializeHtlcs(r)
return i, err
}
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
// We encode the set of circuit keys as a varint length prefix
// followed by a series of fixed size uint64 integers.
numKeys := uint64(len(*v))
if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
return err
}
for key := range *v {
scidInt := key.ChanID.ToUint64()
if err := tlv.EUint64(w, &scidInt, buf); err != nil {
return err
}
if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
return err
}
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
}
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*map[models.CircuitKey]struct{}); ok {
// First, we'll read out the varint that encodes the number of
// circuit keys encoded.
numKeys, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Now that we know how many keys to expect, iterate reading
// each one until we're done.
for i := uint64(0); i < numKeys; i++ {
var (
key models.CircuitKey
scid uint64
)
if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
return err
}
key.ChanID = lnwire.NewShortChanIDFromInt(scid)
err := tlv.DUint64(r, &key.HtlcID, buf, 8)
if err != nil {
return err
}
(*v)[key] = struct{}{}
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l)
}
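// exampleCircuitKeysRoundTrip is an illustrative sketch (not part of the
// original code): encodeCircuitKeys and decodeCircuitKeys are symmetric, so
// a set of circuit keys survives a plain buffer round trip. The channel and
// HTLC IDs below are hypothetical.
func exampleCircuitKeysRoundTrip() error {
	keys := map[models.CircuitKey]struct{}{
		{
			ChanID: lnwire.NewShortChanIDFromInt(42),
			HtlcID: 7,
		}: {},
	}

	var (
		b   bytes.Buffer
		buf [8]byte
	)

	// Write the varint count followed by the scid/HTLC-ID pairs.
	if err := encodeCircuitKeys(&b, &keys, &buf); err != nil {
		return err
	}

	// Read the same encoding back into a fresh map.
	decoded := make(map[models.CircuitKey]struct{})

	return decodeCircuitKeys(&b, &decoded, &buf, uint64(b.Len()))
}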
// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record.
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*invpkg.AMPInvoiceState); ok {
// We'll encode the AMP state as a series of KV pairs on the
// wire with a length prefix.
numRecords := uint64(len(*v))
// First, we'll write out the number of records as a var int.
if err := tlv.WriteVarInt(w, numRecords, buf); err != nil {
return err
}
// With that written out, we'll now encode the entries
// themselves as a sub-TLV record, which includes its _own_
// inner length prefix.
for setID, ampState := range *v {
setID := [32]byte(setID)
ampState := ampState
htlcState := uint8(ampState.State)
settleDate := ampState.SettleDate
settleDateBytes, err := settleDate.MarshalBinary()
if err != nil {
return err
}
amtPaid := uint64(ampState.AmtPaid)
var ampStateTlvBytes bytes.Buffer
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(
ampStateSetIDType, &setID,
),
tlv.MakePrimitiveRecord(
ampStateHtlcStateType, &htlcState,
),
tlv.MakePrimitiveRecord(
ampStateSettleIndexType,
&ampState.SettleIndex,
),
tlv.MakePrimitiveRecord(
ampStateSettleDateType,
&settleDateBytes,
),
tlv.MakeDynamicRecord(
ampStateCircuitKeysType,
&ampState.InvoiceKeys,
func() uint64 {
// The record encodes a varint
// for the number of keys, then
// 8 bytes for the scid and 8
// bytes for the HTLC index of
// each key.
keys := ampState.InvoiceKeys
numKeys := uint64(len(keys))
size := tlv.VarIntSize(numKeys)
dataSize := (numKeys * 16)
return size + dataSize
},
encodeCircuitKeys, decodeCircuitKeys,
),
tlv.MakePrimitiveRecord(
ampStateAmtPaidType, &amtPaid,
),
)
if err != nil {
return err
}
err = tlvStream.Encode(&ampStateTlvBytes)
if err != nil {
return err
}
// We encode the record with a varint length followed by
// the _raw_ TLV bytes.
tlvLen := uint64(len(ampStateTlvBytes.Bytes()))
if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
return err
}
_, err = w.Write(ampStateTlvBytes.Bytes())
if err != nil {
return err
}
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
}
// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record.
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*invpkg.AMPInvoiceState); ok {
// First, we'll decode the varint that encodes how many set IDs
// are encoded within the greater map.
numRecords, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Now that we know how many records we'll need to read, we can
// iterate and read them all out in series.
for i := uint64(0); i < numRecords; i++ {
// Read out the varint that encodes the size of this
// inner TLV record.
stateRecordSize, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Using this information, we'll create a new limited
// reader that'll return an EOF once the end has been
// reached so the stream stops consuming bytes.
innerTlvReader := io.LimitedReader{
R: r,
N: int64(stateRecordSize),
}
var (
setID [32]byte
htlcState uint8
settleIndex uint64
settleDateBytes []byte
invoiceKeys = make(
map[models.CircuitKey]struct{},
)
amtPaid uint64
)
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(
ampStateSetIDType, &setID,
),
tlv.MakePrimitiveRecord(
ampStateHtlcStateType, &htlcState,
),
tlv.MakePrimitiveRecord(
ampStateSettleIndexType, &settleIndex,
),
tlv.MakePrimitiveRecord(
ampStateSettleDateType,
&settleDateBytes,
),
tlv.MakeDynamicRecord(
ampStateCircuitKeysType,
&invoiceKeys, nil,
encodeCircuitKeys, decodeCircuitKeys,
),
tlv.MakePrimitiveRecord(
ampStateAmtPaidType, &amtPaid,
),
)
if err != nil {
return err
}
err = tlvStream.Decode(&innerTlvReader)
if err != nil {
return err
}
var settleDate time.Time
err = settleDate.UnmarshalBinary(settleDateBytes)
if err != nil {
return err
}
(*v)[setID] = invpkg.InvoiceStateAMP{
State: invpkg.HtlcState(htlcState),
SettleIndex: settleIndex,
SettleDate: settleDate,
InvoiceKeys: invoiceKeys,
AmtPaid: lnwire.MilliSatoshi(amtPaid),
}
}
return nil
}
return tlv.NewTypeForDecodingErr(
val, "channeldb.AMPInvoiceState", l, l,
)
}
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
// as a map.
func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
error) {
htlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
for {
// Read the length of the tlv stream for this htlc.
var streamLen int64
if err := binary.Read(r, byteOrder, &streamLen); err != nil {
if err == io.EOF {
break
}
return nil, err
}
// Limit the reader so that it stops at the end of this htlc's
// stream.
htlcReader := io.LimitReader(r, streamLen)
// Decode the contents into the htlc fields.
var (
htlc invpkg.InvoiceHTLC
key models.CircuitKey
chanID uint64
state uint8
acceptTime, resolveTime uint64
amt, mppTotalAmt uint64
amp = &record.AMP{}
hash32 = &[32]byte{}
preimage32 = &[32]byte{}
)
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(chanIDType, &chanID),
tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
tlv.MakePrimitiveRecord(amtType, &amt),
tlv.MakePrimitiveRecord(
acceptHeightType, &htlc.AcceptHeight,
),
tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(htlcStateType, &state),
tlv.MakePrimitiveRecord(mppTotalAmtType, &mppTotalAmt),
tlv.MakeDynamicRecord(
htlcAMPType, amp, amp.PayloadSize,
record.AMPEncoder, record.AMPDecoder,
),
tlv.MakePrimitiveRecord(htlcHashType, hash32),
tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
)
if err != nil {
return nil, err
}
parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
if err != nil {
return nil, err
}
if _, ok := parsedTypes[htlcAMPType]; !ok {
amp = nil
}
var preimage *lntypes.Preimage
if _, ok := parsedTypes[htlcPreimageType]; ok {
pimg := lntypes.Preimage(*preimage32)
preimage = &pimg
}
var hash *lntypes.Hash
if _, ok := parsedTypes[htlcHashType]; ok {
h := lntypes.Hash(*hash32)
hash = &h
}
key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
htlc.AcceptTime = getNanoTime(acceptTime)
htlc.ResolveTime = getNanoTime(resolveTime)
htlc.State = invpkg.HtlcState(state)
htlc.Amt = lnwire.MilliSatoshi(amt)
htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
if amp != nil && hash != nil {
htlc.AMP = &invpkg.InvoiceHtlcAMPData{
Record: *amp,
Hash: *hash,
Preimage: preimage,
}
}
// Reconstruct the custom records fields from the parsed types
// map returned by the tlv parser.
htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
htlcs[key] = &htlc
}
return htlcs, nil
}
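// exampleHtlcsRoundTrip is an illustrative sketch (not part of the original
// code): serializeHtlcs and deserializeHtlcs are symmetric, so an HTLC map
// survives a buffer round trip.
func exampleHtlcsRoundTrip(
	htlcs map[models.CircuitKey]*invpkg.InvoiceHTLC) (
	map[models.CircuitKey]*invpkg.InvoiceHTLC, error) {

	var b bytes.Buffer

	// Each HTLC is written as a length-prefixed TLV stream.
	if err := serializeHtlcs(&b, htlcs); err != nil {
		return nil, err
	}

	// Reading stops cleanly at EOF once all streams are consumed.
	return deserializeHtlcs(&b)
}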
// invoiceSetIDKeyLen is the length of the key that's used to store the
// individual HTLCs prefixed by their set ID alongside the main invoice
// within the invoice bucket. We use 4 bytes for the invoice number, and 32
// bytes for the set ID.
const invoiceSetIDKeyLen = 4 + 32
// makeInvoiceSetIDKey returns the prefix key, based on the set ID and
// invoice number, under which the HTLCs for this setID will be stored.
func makeInvoiceSetIDKey(invoiceNum, setID []byte) [invoiceSetIDKeyLen]byte {
// Construct the prefix key we need to obtain the invoice information:
// invoiceNum || setID.
var invoiceSetIDKey [invoiceSetIDKeyLen]byte
copy(invoiceSetIDKey[:], invoiceNum)
copy(invoiceSetIDKey[len(invoiceNum):], setID)
return invoiceSetIDKey
}
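// exampleSetIDKey is an illustrative sketch (not part of the original
// code): for a hypothetical 4-byte invoice number and a 32-byte set ID, the
// resulting key is simply their concatenation, invoiceNum || setID.
func exampleSetIDKey() [invoiceSetIDKeyLen]byte {
	invoiceNum := []byte{0x00, 0x00, 0x00, 0x01}

	var setID [32]byte
	setID[0] = 0xff

	// key[0:4] == invoiceNum and key[4:36] == setID.
	return makeInvoiceSetIDKey(invoiceNum, setID[:])
}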
// delAMPInvoices attempts to delete all the "sub" invoices associated with
// a greater AMP invoice. We do this by deleting the set of keys that share
// the invoice number as a prefix.
func delAMPInvoices(invoiceNum []byte, invoiceBucket kvdb.RwBucket) error {
// Since it isn't safe to delete using an active cursor, we'll use the
// cursor simply to collect the set of keys we need to delete, _then_
// delete them in another pass.
var keysToDel [][]byte
err := forEachAMPInvoice(
invoiceBucket, invoiceNum,
func(cursorKey, v []byte) error {
keysToDel = append(keysToDel, cursorKey)
return nil
},
)
if err != nil {
return err
}
// In this next phase, we'll then delete all the relevant invoices.
for _, keyToDel := range keysToDel {
if err := invoiceBucket.Delete(keyToDel); err != nil {
return err
}
}
return nil
}
// delAMPSettleIndex removes all the entries in the settle index associated
// with a given AMP invoice.
func delAMPSettleIndex(invoiceNum []byte, invoices,
settleIndex kvdb.RwBucket) error {
// First, we need to grab the AMP invoice state to see if there's
// anything that we even need to delete.
ampState, err := fetchInvoiceStateAMP(invoiceNum, invoices)
if err != nil {
return err
}
// If there's no AMP state at all (non-AMP invoice), then we can return
// early.
if len(ampState) == 0 {
return nil
}
// Otherwise, we'll need to iterate and delete each settle index within
// the set of returned entries.
var settleIndexKey [8]byte
for _, subState := range ampState {
byteOrder.PutUint64(
settleIndexKey[:], subState.SettleIndex,
)
if err := settleIndex.Delete(settleIndexKey[:]); err != nil {
return err
}
}
return nil
}
// DeleteCanceledInvoices deletes all canceled invoices from the database.
func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
invoices := tx.ReadWriteBucket(invoiceBucket)
if invoices == nil {
return nil
}
invoiceIndex := invoices.NestedReadWriteBucket(
invoiceIndexBucket,
)
if invoiceIndex == nil {
return nil
}
invoiceAddIndex := invoices.NestedReadWriteBucket(
addIndexBucket,
)
if invoiceAddIndex == nil {
return nil
}
payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
return invoiceIndex.ForEach(func(k, v []byte) error {
// Skip the special numInvoicesKey as that does not
// point to a valid invoice.
if bytes.Equal(k, numInvoicesKey) {
return nil
}
// Skip sub-buckets.
if v == nil {
return nil
}
invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil {
return err
}
if invoice.State != invpkg.ContractCanceled {
return nil
}
// Delete the payment hash from the invoice index.
err = invoiceIndex.Delete(k)
if err != nil {
return err
}
// Delete payment address index reference if there's a
// valid payment address.
if invoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
// To ensure consistency check that the already
// fetched invoice key matches the one in the
// payment address index.
key := payAddrIndex.Get(
invoice.Terms.PaymentAddr[:],
)
if bytes.Equal(key, k) {
// Delete from the payment address
// index.
if err := payAddrIndex.Delete(
invoice.Terms.PaymentAddr[:],
); err != nil {
return err
}
}
}
// Remove from the add index.
var addIndexKey [8]byte
byteOrder.PutUint64(addIndexKey[:], invoice.AddIndex)
err = invoiceAddIndex.Delete(addIndexKey[:])
if err != nil {
return err
}
// Note that we don't need to delete the invoice from
// the settle index as it is not added until the
// invoice is settled.
// Now remove all sub invoices.
err = delAMPInvoices(k, invoices)
if err != nil {
return err
}
// Finally remove the serialized invoice from the
// invoice bucket.
return invoices.Delete(k)
})
}, func() {})
}
// DeleteInvoice attempts to delete the passed invoices from the database in
// one transaction. The passed delete references hold all keys required to
// delete the invoices without also needing to deserialize them.
func (d *DB) DeleteInvoice(_ context.Context,
invoicesToDelete []invpkg.InvoiceDeleteRef) error {
err := kvdb.Update(d, func(tx kvdb.RwTx) error {
invoices := tx.ReadWriteBucket(invoiceBucket)
if invoices == nil {
return invpkg.ErrNoInvoicesCreated
}
invoiceIndex := invoices.NestedReadWriteBucket(
invoiceIndexBucket,
)
if invoiceIndex == nil {
return invpkg.ErrNoInvoicesCreated
}
invoiceAddIndex := invoices.NestedReadWriteBucket(
addIndexBucket,
)
if invoiceAddIndex == nil {
return invpkg.ErrNoInvoicesCreated
}
// settleIndex can be nil, as the bucket is created lazily
// when the first invoice is settled.
settleIndex := invoices.NestedReadWriteBucket(settleIndexBucket)
payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
for _, ref := range invoicesToDelete {
// Fetch the invoice key for using it to check for
// consistency and also to delete from the invoice
// index.
invoiceKey := invoiceIndex.Get(ref.PayHash[:])
if invoiceKey == nil {
return invpkg.ErrInvoiceNotFound
}
err := invoiceIndex.Delete(ref.PayHash[:])
if err != nil {
return err
}
// Delete payment address index reference if there's a
// valid payment address passed.
if ref.PayAddr != nil {
// To ensure consistency check that the already
// fetched invoice key matches the one in the
// payment address index.
key := payAddrIndex.Get(ref.PayAddr[:])
if bytes.Equal(key, invoiceKey) {
// Delete from the payment address
// index. Note that since the payment
// address index has been introduced
// with an empty migration it may be
// possible that the index doesn't have
// an entry for this invoice.
// ref: https://github.com/lightningnetwork/lnd/pull/4285/commits/cbf71b5452fa1d3036a43309e490787c5f7f08dc#r426368127
if err := payAddrIndex.Delete(
ref.PayAddr[:],
); err != nil {
return err
}
}
}
var addIndexKey [8]byte
byteOrder.PutUint64(addIndexKey[:], ref.AddIndex)
// To ensure consistency check that the key stored in
// the add index also matches the previously fetched
// invoice key.
key := invoiceAddIndex.Get(addIndexKey[:])
if !bytes.Equal(key, invoiceKey) {
return fmt.Errorf("unknown invoice " +
"in add index")
}
// Remove from the add index.
err = invoiceAddIndex.Delete(addIndexKey[:])
if err != nil {
return err
}
// Remove from the settle index if available and
// if the invoice is settled.
if settleIndex != nil && ref.SettleIndex > 0 {
var settleIndexKey [8]byte
byteOrder.PutUint64(
settleIndexKey[:], ref.SettleIndex,
)
// To ensure consistency check that the already
// fetched invoice key matches the one in the
// settle index
key := settleIndex.Get(settleIndexKey[:])
if !bytes.Equal(key, invoiceKey) {
return fmt.Errorf("unknown invoice " +
"in settle index")
}
err = settleIndex.Delete(settleIndexKey[:])
if err != nil {
return err
}
}
// In addition to deleting the main invoice state, if
// this is an AMP invoice, then we'll also need to
// delete the per-setID HTLC sets stored under the
// invoice key prefix. For non-AMP invoices, this'll
// be a noop.
err = delAMPSettleIndex(
invoiceKey, invoices, settleIndex,
)
if err != nil {
return err
}
err = delAMPInvoices(invoiceKey, invoices)
if err != nil {
return err
}
// Finally remove the serialized invoice from the
// invoice bucket.
err = invoices.Delete(invoiceKey)
if err != nil {
return err
}
}
return nil
}, func() {})
return err
}
// SetInvoiceBucketTombstone sets the tombstone key in the invoice bucket to
// mark the bucket as permanently closed. This prevents it from being reopened
// in the future.
func (d *DB) SetInvoiceBucketTombstone() error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
// Access the top-level invoice bucket.
invoices := tx.ReadWriteBucket(invoiceBucket)
if invoices == nil {
return fmt.Errorf("invoice bucket does not exist")
}
// Add the tombstone key to the invoice bucket.
err := invoices.Put(invoiceBucketTombstone, []byte("1"))
if err != nil {
return fmt.Errorf("failed to set tombstone: %w", err)
}
return nil
}, func() {})
}
// GetInvoiceBucketTombstone checks if the tombstone key exists in the invoice
// bucket. It returns true if the tombstone is present and false otherwise.
func (d *DB) GetInvoiceBucketTombstone() (bool, error) {
var tombstoneExists bool
err := kvdb.View(d, func(tx kvdb.RTx) error {
// Access the top-level invoice bucket.
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return fmt.Errorf("invoice bucket does not exist")
}
// Check if the tombstone key exists.
tombstone := invoices.Get(invoiceBucketTombstone)
tombstoneExists = tombstone != nil
return nil
}, func() {})
if err != nil {
return false, err
}
return tombstoneExists, nil
}
package channeldb
import (
"io"
)
// deserializeCloseChannelSummaryV6 reads the v6 database format for
// ChannelCloseSummary.
//
// NOTE: deprecated, only for migration.
func deserializeCloseChannelSummaryV6(r io.Reader) (*ChannelCloseSummary, error) {
c := &ChannelCloseSummary{}
err := ReadElements(r,
&c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID,
&c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance,
&c.TimeLockedBalance, &c.CloseType, &c.IsPending,
)
if err != nil {
return nil, err
}
// We'll now check to see if the channel close summary was encoded with
// any of the additional optional fields.
err = ReadElements(r, &c.RemoteCurrentRevocation)
switch {
case err == io.EOF:
return c, nil
// If we got a non-eof error, then we know there's an actual issue.
// Otherwise, it may have been the case that this summary didn't have
// the set of optional fields.
case err != nil:
return nil, err
}
if err := readChanConfig(r, &c.LocalChanConfig); err != nil {
return nil, err
}
// Finally, we'll attempt to read the next unrevoked commitment point
// for the remote party. If we closed the channel before receiving a
// channel_ready message, then this can be nil. As a result, we'll use
// the same technique to read the field, only if there's still data
// left in the buffer.
err = ReadElements(r, &c.RemoteNextRevocation)
if err != nil && err != io.EOF {
// If we got a non-eof error, then we know there's an actual
// issue. Otherwise, it may have been the case that this
// summary didn't have the set of optional fields.
return nil, err
}
return c, nil
}
package channeldb
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
mig "github.com/lightningnetwork/lnd/channeldb/migration"
"github.com/lightningnetwork/lnd/channeldb/migration12"
"github.com/lightningnetwork/lnd/channeldb/migration13"
"github.com/lightningnetwork/lnd/channeldb/migration16"
"github.com/lightningnetwork/lnd/channeldb/migration24"
"github.com/lightningnetwork/lnd/channeldb/migration30"
"github.com/lightningnetwork/lnd/channeldb/migration31"
"github.com/lightningnetwork/lnd/channeldb/migration32"
"github.com/lightningnetwork/lnd/channeldb/migration33"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
)
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
func init() {
UseLogger(build.NewSubLogger("CHDB", nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
mig.UseLogger(logger)
migration_01_to_11.UseLogger(logger)
migration12.UseLogger(logger)
migration13.UseLogger(logger)
migration16.UseLogger(logger)
migration24.UseLogger(logger)
migration30.UseLogger(logger)
migration31.UseLogger(logger)
migration32.UseLogger(logger)
migration33.UseLogger(logger)
kvdb.UseLogger(logger)
}
package channeldb
import (
"bytes"
"errors"
"fmt"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// metaBucket stores all the meta information concerning the state of
// the database.
metaBucket = []byte("metadata")
// dbVersionKey is a boltdb key and it's used for storing/retrieving
// current database version.
dbVersionKey = []byte("dbp")
// optionalVersionKey is a boltdb key and it's used for storing/retrieving
// a list of optional migrations that have been applied.
optionalVersionKey = []byte("ovk")
// TombstoneKey is the key under which we add a tag in the source DB
// after we've successfully and completely migrated it to the target/
// destination DB.
TombstoneKey = []byte("data-migration-tombstone")
// ErrMarkerNotPresent is the error that is returned if the queried
// marker is not present in the given database.
ErrMarkerNotPresent = errors.New("marker not present")
)
// Meta structure holds the database meta information.
type Meta struct {
// DbVersionNumber is the current schema version of the database.
DbVersionNumber uint32
}
// FetchMeta fetches the metadata from boltdb and returns filled meta structure.
func (d *DB) FetchMeta() (*Meta, error) {
var meta *Meta
err := kvdb.View(d, func(tx kvdb.RTx) error {
return FetchMeta(meta, tx)
}, func() {
meta = &Meta{}
})
if err != nil {
return nil, err
}
return meta, nil
}
// FetchMeta is a helper function used in order to allow callers to re-use a
// database transaction.
func FetchMeta(meta *Meta, tx kvdb.RTx) error {
metaBucket := tx.ReadBucket(metaBucket)
if metaBucket == nil {
return ErrMetaNotFound
}
data := metaBucket.Get(dbVersionKey)
if data == nil {
meta.DbVersionNumber = getLatestDBVersion(dbVersions)
} else {
meta.DbVersionNumber = byteOrder.Uint32(data)
}
return nil
}
// PutMeta writes the passed instance of the database metadata struct to
// disk.
func (d *DB) PutMeta(meta *Meta) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
return putMeta(meta, tx)
}, func() {})
}
// putMeta is an internal helper function used in order to allow callers to
// re-use a database transaction. See the publicly exported PutMeta method for
// more information.
func putMeta(meta *Meta, tx kvdb.RwTx) error {
metaBucket, err := tx.CreateTopLevelBucket(metaBucket)
if err != nil {
return err
}
return putDbVersion(metaBucket, meta)
}
func putDbVersion(metaBucket kvdb.RwBucket, meta *Meta) error {
scratch := make([]byte, 4)
byteOrder.PutUint32(scratch, meta.DbVersionNumber)
return metaBucket.Put(dbVersionKey, scratch)
}
// OptionalMeta structure holds the database optional migration information.
type OptionalMeta struct {
// Versions is a set that contains the versions that have been applied.
// When saved to disk, only the indexes are stored.
Versions map[uint64]string
}
func (om *OptionalMeta) String() string {
s := ""
for index, name := range om.Versions {
s += fmt.Sprintf("%d: %s", index, name)
}
if s == "" {
s = "empty"
}
return s
}
// fetchOptionalMeta reads the optional meta from the database.
func (d *DB) fetchOptionalMeta() (*OptionalMeta, error) {
om := &OptionalMeta{
Versions: make(map[uint64]string),
}
err := kvdb.View(d, func(tx kvdb.RTx) error {
metaBucket := tx.ReadBucket(metaBucket)
if metaBucket == nil {
return ErrMetaNotFound
}
vBytes := metaBucket.Get(optionalVersionKey)
// Exit early if nothing found.
if vBytes == nil {
return nil
}
// Read the versions' length.
r := bytes.NewReader(vBytes)
vLen, err := tlv.ReadVarInt(r, &[8]byte{})
if err != nil {
return err
}
// Read the version indexes.
for i := uint64(0); i < vLen; i++ {
version, err := tlv.ReadVarInt(r, &[8]byte{})
if err != nil {
return err
}
om.Versions[version] = optionalVersions[i].name
}
return nil
}, func() {})
if err != nil {
return nil, err
}
return om, nil
}
// putOptionalMeta writes an optional meta to the database.
func (d *DB) putOptionalMeta(om *OptionalMeta) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
metaBucket, err := tx.CreateTopLevelBucket(metaBucket)
if err != nil {
return err
}
var b bytes.Buffer
// Write the total length.
err = tlv.WriteVarInt(&b, uint64(len(om.Versions)), &[8]byte{})
if err != nil {
return err
}
// Write the version indexes.
for v := range om.Versions {
err := tlv.WriteVarInt(&b, v, &[8]byte{})
if err != nil {
return err
}
}
return metaBucket.Put(optionalVersionKey, b.Bytes())
}, func() {})
}
// CheckMarkerPresent returns the marker under the requested key or
// ErrMarkerNotPresent if either the root bucket or the marker key within
// that bucket does not exist.
func CheckMarkerPresent(tx kvdb.RTx, markerKey []byte) ([]byte, error) {
markerBucket := tx.ReadBucket(markerKey)
if markerBucket == nil {
return nil, ErrMarkerNotPresent
}
val := markerBucket.Get(markerKey)
// If we wrote the marker correctly, we created a bucket _and_ created a
// key with a non-empty value. So whether the key doesn't exist or its
// value is empty, to us it just means the marker isn't there.
if len(val) == 0 {
return nil, ErrMarkerNotPresent
}
return val, nil
}
// EnsureNoTombstone returns an error if there is a tombstone marker in the DB
// of the given transaction.
func EnsureNoTombstone(tx kvdb.RTx) error {
marker, err := CheckMarkerPresent(tx, TombstoneKey)
if err == ErrMarkerNotPresent {
// No marker present, so no tombstone. The DB is still alive.
return nil
}
if err != nil {
return err
}
// There was no error so there is a tombstone marker/tag. We cannot use
// this DB anymore.
return fmt.Errorf("refusing to use db, it was marked with a tombstone "+
"after successful data migration; tombstone reads: %s",
string(marker))
}
// AddMarker adds the marker with the given key into a top level bucket with the
// same name. So the structure will look like:
//
// marker-key (top level bucket)
// |-> marker-key:marker-value (key/value pair)
func AddMarker(tx kvdb.RwTx, markerKey, markerValue []byte) error {
if len(markerValue) == 0 {
return fmt.Errorf("marker value cannot be empty")
}
markerBucket, err := tx.CreateTopLevelBucket(markerKey)
if err != nil {
return err
}
return markerBucket.Put(markerKey, markerValue)
}
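// exampleTombstoneFlow is an illustrative sketch (not part of the original
// code; the backend and marker value are hypothetical): after a successful
// data migration, the source DB is tagged with the tombstone marker, after
// which EnsureNoTombstone refuses to use it.
func exampleTombstoneFlow(db kvdb.Backend) error {
	// Tag the source DB as fully migrated.
	err := kvdb.Update(db, func(tx kvdb.RwTx) error {
		return AddMarker(tx, TombstoneKey, []byte("migrated"))
	}, func() {})
	if err != nil {
		return err
	}

	// Any later attempt to use this DB should now fail.
	return kvdb.View(db, func(tx kvdb.RTx) error {
		return EnsureNoTombstone(tx)
	}, func() {})
}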
package migration
import (
"fmt"
"github.com/lightningnetwork/lnd/kvdb"
)
// CreateTLB creates a new top-level bucket with the passed bucket identifier.
func CreateTLB(bucket []byte) func(kvdb.RwTx) error {
return func(tx kvdb.RwTx) error {
log.Infof("Creating top-level bucket: \"%s\" ...", bucket)
if tx.ReadBucket(bucket) != nil {
return fmt.Errorf("top-level bucket \"%s\" "+
"already exists", bucket)
}
_, err := tx.CreateTopLevelBucket(bucket)
if err != nil {
return err
}
log.Infof("Created top-level bucket: \"%s\"", bucket)
return nil
}
}
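// exampleRunCreateTLB is an illustrative sketch (not part of the original
// code; the bucket name and backend are hypothetical): CreateTLB returns a
// migration function that can be run directly inside a read/write
// transaction.
func exampleRunCreateTLB(db kvdb.Backend) error {
	migrateFn := CreateTLB([]byte("example-bucket"))

	return kvdb.Update(db, migrateFn, func() {})
}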
package lnwire
import (
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
)
// AcceptChannel is the message Bob sends to Alice after she initiates the
// single funder channel workflow via an OpenChannel message. Once Alice
// receives Bob's response, then she has all the items necessary to construct
// the funding transaction, and both commitment transactions.
type AcceptChannel struct {
// PendingChannelID serves to uniquely identify the future channel
// created by the initiated single funder workflow.
PendingChannelID [32]byte
// DustLimit is the specific dust limit the sender of this message
// would like enforced on their version of the commitment transaction.
// Any output below this value will be "trimmed" from the commitment
// transaction, with the amount of the HTLC going to dust.
DustLimit btcutil.Amount
// MaxValueInFlight represents the maximum amount of coins that can be
// pending within the channel at any given time. If the amount of funds
// in limbo exceeds this amount, then the channel will be failed.
MaxValueInFlight MilliSatoshi
// ChannelReserve is the amount of BTC that the receiving party MUST
// maintain a balance above at all times. This is a safety mechanism to
// ensure that both sides always have skin in the game during the
// channel's lifetime.
ChannelReserve btcutil.Amount
// HtlcMinimum is the smallest HTLC that the sender of this message
// will accept.
HtlcMinimum MilliSatoshi
// MinAcceptDepth is the minimum depth that the initiator of the
// channel should wait for before considering the channel open.
// CsvDelay is the number of blocks to use for the relative time lock
// in the pay-to-self output of both commitment transactions.
CsvDelay uint16
// MaxAcceptedHTLCs is the total number of incoming HTLCs that the
// sender of this channel will accept.
//
// TODO(roasbeef): acks the initiator's, same with max in flight?
MaxAcceptedHTLCs uint16
// FundingKey is the key that should be used on behalf of the sender
// within the 2-of-2 multi-sig output that is contained within the
// funding transaction.
FundingKey *btcec.PublicKey
// RevocationPoint is the base revocation point for the sending party.
// Any commitment transaction belonging to the receiver of this message
// should use this key and their per-commitment point to derive the
// revocation key for the commitment transaction.
RevocationPoint *btcec.PublicKey
// PaymentPoint is the base payment point for the sending party. This
// key should be combined with the per commitment point for a
// particular commitment state in order to create the key that should
// be used in any output that pays directly to the sending party, and
// also within the HTLC covenant transactions.
PaymentPoint *btcec.PublicKey
// DelayedPaymentPoint is the delay point for the sending party. This
// key should be combined with the per commitment point to derive the
// keys that are used in outputs of the sender's commitment transaction
// where they claim funds.
DelayedPaymentPoint *btcec.PublicKey
// HtlcPoint is the base point used to derive the set of keys for this
// party that will be used within the HTLC public key scripts. This
// value is combined with the receiver's revocation base point in order
// to derive the keys that are used within HTLC scripts.
HtlcPoint *btcec.PublicKey
// FirstCommitmentPoint is the first commitment point for the sending
// party. This value should be combined with the receiver's revocation
// base point in order to derive the revocation keys that are placed
// within the commitment transaction of the sender.
FirstCommitmentPoint *btcec.PublicKey
// UpfrontShutdownScript is the script to which the channel funds should
// be paid when mutually closing the channel. This field is optional, and
// has a length prefix, so a zero will be written if it is not set
// and its length followed by the script will be written if it is set.
UpfrontShutdownScript DeliveryAddress
}
// A compile time check to ensure AcceptChannel implements the lnwire.Message
// interface.
var _ Message = (*AcceptChannel)(nil)
// Encode serializes the target AcceptChannel into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
//
// This is part of the lnwire.Message interface.
func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
a.PendingChannelID[:],
a.DustLimit,
a.MaxValueInFlight,
a.ChannelReserve,
a.HtlcMinimum,
a.MinAcceptDepth,
a.CsvDelay,
a.MaxAcceptedHTLCs,
a.FundingKey,
a.RevocationPoint,
a.PaymentPoint,
a.DelayedPaymentPoint,
a.HtlcPoint,
a.FirstCommitmentPoint,
a.UpfrontShutdownScript,
)
}
// Decode deserializes the serialized AcceptChannel stored in the passed
// io.Reader into the target AcceptChannel using the deserialization rules
// defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error {
// Read all the mandatory fields in the accept message.
err := ReadElements(r,
a.PendingChannelID[:],
&a.DustLimit,
&a.MaxValueInFlight,
&a.ChannelReserve,
&a.HtlcMinimum,
&a.MinAcceptDepth,
&a.CsvDelay,
&a.MaxAcceptedHTLCs,
&a.FundingKey,
&a.RevocationPoint,
&a.PaymentPoint,
&a.DelayedPaymentPoint,
&a.HtlcPoint,
&a.FirstCommitmentPoint,
)
if err != nil {
return err
}
// Check for the optional upfront shutdown script field. If it is not there,
// silence the EOF error.
err = ReadElement(r, &a.UpfrontShutdownScript)
if err != nil && err != io.EOF {
return err
}
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as an AcceptChannel on the wire.
//
// This is part of the lnwire.Message interface.
func (a *AcceptChannel) MsgType() MessageType {
return MsgAcceptChannel
}
// MaxPayloadLength returns the maximum allowed payload length for a
// AcceptChannel message.
//
// This is part of the lnwire.Message interface.
func (a *AcceptChannel) MaxPayloadLength(uint32) uint32 {
// 32 (pending channel ID) + 8*4 (dust limit, max value in flight,
// channel reserve, HTLC minimum) + 4 (min accept depth) + 2*2 (CSV
// delay, max accepted HTLCs) + 33*6 (six public keys) = 270.
var length uint32 = 270 // base length
// Upfront shutdown script max length.
length += 2 + deliveryAddressMaxSize
return length
}
package lnwire
import (
"io"
)
// AnnounceSignatures is a direct message between two endpoints of a
// channel and serves as an opt-in mechanism to allow the announcement of
// the channel to the rest of the network. It contains the necessary
// signatures by the sender to construct the channel announcement message.
type AnnounceSignatures struct {
// ChannelID is the unique description of the funding transaction.
// The channel ID is better for users and debugging, while the short
// channel ID is used for a quick test of the existence of a particular
// UTXO within the blockchain, because it encodes the block position.
ChannelID ChannelID
// ShortChannelID is the unique description of the funding
// transaction. It is constructed with the most significant 3 bytes
// as the block height, the next 3 bytes indicating the transaction
// index within the block, and the least significant two bytes
// indicating the output index which pays to the channel.
ShortChannelID ShortChannelID
// NodeSignature is the signature over the channel announcement
// message. With this signature we prove that we possess the node
// public key, creating the reference node_key -> bitcoin_key.
NodeSignature Sig
// BitcoinSignature is the signature over the node public key. With
// this signature we prove that we possess the bitcoin key, creating
// the reverse reference bitcoin_key -> node_key.
BitcoinSignature Sig
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
}
// A compile time check to ensure AnnounceSignatures implements the
// lnwire.Message interface.
var _ Message = (*AnnounceSignatures)(nil)
// Decode deserializes a serialized AnnounceSignatures stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
&a.ChannelID,
&a.ShortChannelID,
&a.NodeSignature,
&a.BitcoinSignature,
)
if err != nil {
return err
}
// Now that we've read out all the fields that we explicitly know of,
// we'll collect the remainder into the ExtraOpaqueData field. If there
// aren't any bytes, then we'll snip off the slice to avoid carrying
// around excess capacity.
a.ExtraOpaqueData, err = io.ReadAll(r)
if err != nil {
return err
}
if len(a.ExtraOpaqueData) == 0 {
a.ExtraOpaqueData = nil
}
return nil
}
// Encode serializes the target AnnounceSignatures into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
a.ChannelID,
a.ShortChannelID,
a.NodeSignature,
a.BitcoinSignature,
a.ExtraOpaqueData,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures) MsgType() MessageType {
return MsgAnnounceSignatures
}
// MaxPayloadLength returns the maximum allowed payload size for this message
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures) MaxPayloadLength(pver uint32) uint32 {
return 65533
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// ChannelAnnouncement is a message used to announce the existence of a
// channel between two peers in the overlay network. It is propagated by
// the discovery service over the broadcast handler.
type ChannelAnnouncement struct {
// These signatures are used by nodes in order to create cross
// references between a node's channel and the node itself. Requiring
// both nodes to sign indicates they are both willing to route other
// payments via this node.
NodeSig1 Sig
NodeSig2 Sig
// These signatures are used by nodes in order to create cross
// references between a node's channel and the node itself. Requiring
// the bitcoin signatures proves they control the channel.
BitcoinSig1 Sig
BitcoinSig2 Sig
// Features is the feature vector that encodes the features supported
// by the target node. This field can be used to signal the type of the
// channel, or modifications to the fields that would normally follow
// this vector.
Features *RawFeatureVector
// ChainHash denotes the target chain that this channel was opened
// within. This value should be the genesis hash of the target chain.
ChainHash chainhash.Hash
// ShortChannelID is the unique description of the funding transaction,
// or where exactly it's located within the target blockchain.
ShortChannelID ShortChannelID
// The public keys of the two nodes that are operating the channel,
// such that NodeID1 is numerically lesser than NodeID2 (ascending
// numerical order).
NodeID1 [33]byte
NodeID2 [33]byte
// Public keys which correspond to the keys that were declared in the
// multisig funding transaction output.
BitcoinKey1 [33]byte
BitcoinKey2 [33]byte
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
}
// A compile time check to ensure ChannelAnnouncement implements the
// lnwire.Message interface.
var _ Message = (*ChannelAnnouncement)(nil)
// Decode deserializes a serialized ChannelAnnouncement stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
&a.NodeSig1,
&a.NodeSig2,
&a.BitcoinSig1,
&a.BitcoinSig2,
&a.Features,
a.ChainHash[:],
&a.ShortChannelID,
&a.NodeID1,
&a.NodeID2,
&a.BitcoinKey1,
&a.BitcoinKey2,
)
if err != nil {
return err
}
// Now that we've read out all the fields that we explicitly know of,
// we'll collect the remainder into the ExtraOpaqueData field. If there
// aren't any bytes, then we'll snip off the slice to avoid carrying
// around excess capacity.
a.ExtraOpaqueData, err = io.ReadAll(r)
if err != nil {
return err
}
if len(a.ExtraOpaqueData) == 0 {
a.ExtraOpaqueData = nil
}
return nil
}
// Encode serializes the target ChannelAnnouncement into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
a.NodeSig1,
a.NodeSig2,
a.BitcoinSig1,
a.BitcoinSig2,
a.Features,
a.ChainHash[:],
a.ShortChannelID,
a.NodeID1,
a.NodeID2,
a.BitcoinKey1,
a.BitcoinKey2,
a.ExtraOpaqueData,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement) MsgType() MessageType {
return MsgChannelAnnouncement
}
// MaxPayloadLength returns the maximum allowed payload size for this message
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement) MaxPayloadLength(pver uint32) uint32 {
return 65533
}
// DataToSign is used to retrieve part of the announcement message which should
// be signed.
func (a *ChannelAnnouncement) DataToSign() ([]byte, error) {
// We should not include the signatures itself.
var w bytes.Buffer
err := WriteElements(&w,
a.Features,
a.ChainHash[:],
a.ShortChannelID,
a.NodeID1,
a.NodeID2,
a.BitcoinKey1,
a.BitcoinKey2,
a.ExtraOpaqueData,
)
if err != nil {
return nil, err
}
return w.Bytes(), nil
}
package lnwire
import (
"encoding/binary"
"encoding/hex"
"math"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
)
const (
// MaxFundingTxOutputs is the maximum number of allowed outputs on a
// funding transaction within the protocol. This is due to the fact
// that we use 2-bytes to encode the index within the funding output
// during the funding workflow. Funding transaction with more outputs
// than this are considered invalid within the protocol.
MaxFundingTxOutputs = math.MaxUint16
)
// ChannelID is a series of 32-bytes that uniquely identifies all channels
// within the network. The ChannelID is computed using the outpoint of the
// funding transaction (the txid, and output index). Given a funding output the
// ChannelID can be calculated by XOR'ing the big-endian serialization of the
// txid and the big-endian serialization of the output index, truncated to
// 2 bytes.
type ChannelID [32]byte
// ConnectionWideID is an all-zero ChannelID, which is used to represent a
// message intended for all channels to a specific peer.
var ConnectionWideID = ChannelID{}
// String returns the string representation of the ChannelID. This is just the
// hex string encoding of the ChannelID itself.
func (c ChannelID) String() string {
return hex.EncodeToString(c[:])
}
// NewChanIDFromOutPoint converts a target OutPoint into a ChannelID that is
// usable within the network. In order to convert the OutPoint into a ChannelID,
// we XOR the lower 2-bytes of the txid within the OutPoint with the big-endian
// serialization of the Index of the OutPoint, truncated to 2-bytes.
func NewChanIDFromOutPoint(op *wire.OutPoint) ChannelID {
// First we'll copy the txid of the outpoint into our channel ID slice.
var cid ChannelID
copy(cid[:], op.Hash[:])
// With the txid copied over, we'll now XOR the lower 2-bytes of the
// partial channelID with big-endian serialization of output index.
xorTxid(&cid, uint16(op.Index))
return cid
}
// xorTxid performs the transformation needed to transform an OutPoint into
// a ChannelID. To do this, we expect the cid parameter to contain the txid
// unaltered and the outputIndex to be the output index of the OutPoint.
func xorTxid(cid *ChannelID, outputIndex uint16) {
var buf [2]byte
binary.BigEndian.PutUint16(buf[:], outputIndex)
cid[30] ^= buf[0]
cid[31] ^= buf[1]
}
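// exampleChanIDFromOutPoint is an illustrative sketch (not part of the
// original code, with a hypothetical funding outpoint): only the last two
// bytes of the txid are XOR'd with the big-endian output index, so a
// funding output at index 5 flips at most the two lowest bytes.
func exampleChanIDFromOutPoint() bool {
	op := wire.OutPoint{
		// A txid whose last byte is 0x01; all other bytes zero.
		Hash:  chainhash.Hash{31: 0x01},
		Index: 5,
	}

	cid := NewChanIDFromOutPoint(&op)

	// 0x01 ^ 0x05 == 0x04, and every other byte is unchanged.
	return cid[31] == 0x04 && cid.IsChanPoint(&op)
}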
// GenPossibleOutPoints generates all the possible outpoints given a channel
// ID. In order to generate these possible outpoints, we perform a
// brute-force search through the candidate output index space, performing a
// reverse mapping from channelID back to OutPoint.
func (c *ChannelID) GenPossibleOutPoints() [MaxFundingTxOutputs]wire.OutPoint {
var possiblePoints [MaxFundingTxOutputs]wire.OutPoint
for i := uint16(0); i < MaxFundingTxOutputs; i++ {
cidCopy := *c
xorTxid(&cidCopy, i)
possiblePoints[i] = wire.OutPoint{
Hash: chainhash.Hash(cidCopy),
Index: uint32(i),
}
}
return possiblePoints
}
// IsChanPoint returns true if the OutPoint passed corresponds to the target
// ChannelID.
func (c ChannelID) IsChanPoint(op *wire.OutPoint) bool {
candidateCid := NewChanIDFromOutPoint(op)
return candidateCid == c
}
package lnwire
import (
"io"
"github.com/btcsuite/btcd/btcec/v2"
)
// ChannelReestablish is a message sent between peers that have an existing
// open channel upon connection reestablishment. This message allows both sides
// to report their local state, and their current knowledge of the state of the
// remote commitment chain. If a deviation is detected and can be recovered
// from, then the necessary messages will be retransmitted. If the level of
// desynchronization is irreconcilable, then the channel will be force closed.
type ChannelReestablish struct {
// ChanID is the channel ID of the channel state we're attempting to
// synchronize with the remote party.
ChanID ChannelID
// NextLocalCommitHeight is the next local commitment height of the
// sending party. If the height of the sender's commitment chain from
// the receiver's PoV is one less than this number, then the sender
// should re-send the *exact* same proposed commitment.
//
// In other words, the receiver should re-send their last sent
// commitment iff:
//
// * NextLocalCommitHeight == remoteCommitChain.Height
//
// This covers the case of a lost commitment which was sent by the
// sender of this message, but never received by the receiver of this
// message.
NextLocalCommitHeight uint64
// RemoteCommitTailHeight is the height of the receiving party's
// unrevoked commitment from the PoV of the sender of this message. If
// the height of the receiver's commitment is *one more* than this
// value, then their prior RevokeAndAck message should be
// retransmitted.
//
// In other words, the receiver should re-send their last sent
// RevokeAndAck message iff:
//
// * localCommitChain.tail().Height == RemoteCommitTailHeight + 1
//
// This covers the case of a lost revocation, wherein the receiver of
// the message sent a revocation for a prior state, but the sender of
// the message never fully processed it.
RemoteCommitTailHeight uint64
// LastRemoteCommitSecret is the last commitment secret that the
// receiving node has sent to the sending party. This will be the
// secret of the last revoked commitment transaction. Including this
// provides proof that the sending node at least knows of this state,
// as they couldn't have produced the value if it hadn't been sent:
// the value can be authenticated by querying the shachain or the
// receiving party.
LastRemoteCommitSecret [32]byte
// LocalUnrevokedCommitPoint is the commitment point used in the
// current un-revoked commitment transaction of the sending party.
LocalUnrevokedCommitPoint *btcec.PublicKey
}
// A compile time check to ensure ChannelReestablish implements the
// lnwire.Message interface.
var _ Message = (*ChannelReestablish)(nil)
// Encode serializes the target ChannelReestablish into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error {
err := WriteElements(w,
a.ChanID,
a.NextLocalCommitHeight,
a.RemoteCommitTailHeight,
)
if err != nil {
return err
}
// If the commit point wasn't sent, then we won't write out any of the
// remaining fields as they're optional.
if a.LocalUnrevokedCommitPoint == nil {
return nil
}
// Otherwise, we'll write out the remaining elements.
return WriteElements(w, a.LastRemoteCommitSecret[:],
a.LocalUnrevokedCommitPoint)
}
// Decode deserializes a serialized ChannelReestablish stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
&a.ChanID,
&a.NextLocalCommitHeight,
&a.RemoteCommitTailHeight,
)
if err != nil {
return err
}
// This message has currently defined optional fields. As a result,
// we'll only proceed if there are still bytes remaining within the
// reader.
//
// We'll manually parse out the optional fields in order to be able to
// still utilize the io.Reader interface.
// We'll first attempt to read the optional commit secret, if we're at
// the EOF, then this means the field wasn't included so we can exit
// early.
var buf [32]byte
_, err = io.ReadFull(r, buf[:32])
if err == io.EOF {
return nil
} else if err != nil {
return err
}
// If the field is present, then we'll copy it over and proceed.
copy(a.LastRemoteCommitSecret[:], buf[:])
// We'll conclude by parsing out the commitment point. We don't need to
// guard against EOF here: if the sender included the commit secret,
// then they MUST also include the commit point.
return ReadElement(r, &a.LocalUnrevokedCommitPoint)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *ChannelReestablish) MsgType() MessageType {
return MsgChannelReestablish
}
// MaxPayloadLength returns the maximum allowed payload size for this message
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *ChannelReestablish) MaxPayloadLength(pver uint32) uint32 {
var length uint32
// ChanID - 32 bytes
length += 32
// NextLocalCommitHeight - 8 bytes
length += 8
// RemoteCommitTailHeight - 8 bytes
length += 8
// LastRemoteCommitSecret - 32 bytes
length += 32
// LocalUnrevokedCommitPoint - 33 bytes
length += 33
return length
}
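// As an illustrative sketch (not part of the original file), the optional
// fields only hit the wire when the commit point is populated; a decoder
// reading a message without them simply sees EOF and returns early:
//
//	var buf bytes.Buffer
//	msg := &ChannelReestablish{NextLocalCommitHeight: 1}
//	_ = msg.Encode(&buf, 0) // only the first three fields are written
//
//	var decoded ChannelReestablish
//	_ = decoded.Decode(&buf, 0) // EOF on the commit secret; returns nil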
package lnwire
import (
"bytes"
"fmt"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// ChanUpdateMsgFlags is a bitfield that signals whether optional fields are
// present in the ChannelUpdate.
type ChanUpdateMsgFlags uint8
const (
// ChanUpdateRequiredMaxHtlc is a bit that indicates whether the
// optional htlc_maximum_msat field is present in this ChannelUpdate.
ChanUpdateRequiredMaxHtlc ChanUpdateMsgFlags = 1 << iota
)
// String returns the bitfield flags as a string.
func (c ChanUpdateMsgFlags) String() string {
return fmt.Sprintf("%08b", c)
}
// HasMaxHtlc returns true if the htlc_maximum_msat option bit is set in the
// message flags.
func (c ChanUpdateMsgFlags) HasMaxHtlc() bool {
return c&ChanUpdateRequiredMaxHtlc != 0
}
// ChanUpdateChanFlags is a bitfield that signals various options concerning a
// particular channel edge. Each bit is to be examined in order to determine
// how the ChannelUpdate message is to be interpreted.
type ChanUpdateChanFlags uint8
const (
// ChanUpdateDirection indicates the direction of a channel update. This
// bit is set to 0 if Node1 (the node with the "smaller" Node ID) is
// updating the channel, and to 1 otherwise.
ChanUpdateDirection ChanUpdateChanFlags = 1 << iota
// ChanUpdateDisabled is a bit that indicates if the channel edge
// selected by the ChanUpdateDirection bit is to be treated as being
// disabled.
ChanUpdateDisabled
)
// IsDisabled determines whether the channel flags has the disabled bit set.
func (c ChanUpdateChanFlags) IsDisabled() bool {
return c&ChanUpdateDisabled == ChanUpdateDisabled
}
// String returns the bitfield flags as a string.
func (c ChanUpdateChanFlags) String() string {
return fmt.Sprintf("%08b", c)
}
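// As an illustrative sketch (not part of the original file), interpreting
// the channel flags of a received update, given a hypothetical flags value:
//
//	direction := flags & ChanUpdateDirection // 0 => node_1, 1 => node_2
//	disabled := flags.IsDisabled()           // second bit of the field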
// ChannelUpdate is a message used after a channel has been initially
// announced. Each side independently announces its fees, minimum expiry for
// HTLCs, and other parameters. This message is also used to redeclare the
// initially set channel parameters.
type ChannelUpdate struct {
// Signature is used to validate the announced data and prove the
// ownership of node id.
Signature Sig
// ChainHash denotes the target chain that this channel was opened
// within. This value should be the genesis hash of the target chain.
// Along with the short channel ID, this uniquely identifies the
// channel globally in a blockchain.
ChainHash chainhash.Hash
// ShortChannelID is the unique description of the funding transaction.
ShortChannelID ShortChannelID
// Timestamp allows ordering in the case of multiple announcements. We
// should ignore the message if its timestamp is not greater than that
// of the last-received update.
Timestamp uint32
// MessageFlags is a bitfield that describes whether optional fields
// are present in this update. Currently, the least-significant bit
// must be set to 1 if the optional field MaxHtlc is present.
MessageFlags ChanUpdateMsgFlags
// ChannelFlags is a bitfield that describes additional meta-data
// concerning how the update is to be interpreted. Currently, the
// least-significant bit must be set to 0 if the creating node
// corresponds to the first node in the previously sent channel
// announcement and 1 otherwise. If the second bit is set, then the
// channel is set to be disabled.
ChannelFlags ChanUpdateChanFlags
// TimeLockDelta is the minimum number of blocks this node requires to
// be added to the expiry of HTLCs. This is a security parameter
// determined by the node operator. This value represents the required
// gap between the time locks of the incoming and outgoing HTLC's set
// to this node.
TimeLockDelta uint16
// HtlcMinimumMsat is the minimum HTLC value which will be accepted.
HtlcMinimumMsat MilliSatoshi
// BaseFee is the base fee that must be used for incoming HTLC's to
// this particular channel. This value will be tacked onto the required
// fee for a payment independent of the size of the payment.
BaseFee uint32
// FeeRate is the fee rate that will be charged per millionth of a
// satoshi.
FeeRate uint32
// HtlcMaximumMsat is the maximum HTLC value which will be accepted.
HtlcMaximumMsat MilliSatoshi
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
}
// A compile time check to ensure ChannelUpdate implements the lnwire.Message
// interface.
var _ Message = (*ChannelUpdate)(nil)
// Decode deserializes a serialized ChannelUpdate stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
&a.Signature,
a.ChainHash[:],
&a.ShortChannelID,
&a.Timestamp,
&a.MessageFlags,
&a.ChannelFlags,
&a.TimeLockDelta,
&a.HtlcMinimumMsat,
&a.BaseFee,
&a.FeeRate,
)
if err != nil {
return err
}
// Now check whether the max HTLC field is present and read it if so.
if a.MessageFlags.HasMaxHtlc() {
if err := ReadElements(r, &a.HtlcMaximumMsat); err != nil {
return err
}
}
// Now that we've read out all the fields that we explicitly know of,
// we'll collect the remainder into the ExtraOpaqueData field. If there
// aren't any bytes, then we'll snip off the slice to avoid carrying
// around excess capacity.
a.ExtraOpaqueData, err = io.ReadAll(r)
if err != nil {
return err
}
if len(a.ExtraOpaqueData) == 0 {
a.ExtraOpaqueData = nil
}
return nil
}
// Encode serializes the target ChannelUpdate into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *ChannelUpdate) Encode(w io.Writer, pver uint32) error {
err := WriteElements(w,
a.Signature,
a.ChainHash[:],
a.ShortChannelID,
a.Timestamp,
a.MessageFlags,
a.ChannelFlags,
a.TimeLockDelta,
a.HtlcMinimumMsat,
a.BaseFee,
a.FeeRate,
)
if err != nil {
return err
}
// Now append optional fields if they are set. Currently, the only
// optional field is max HTLC.
if a.MessageFlags.HasMaxHtlc() {
if err := WriteElements(w, a.HtlcMaximumMsat); err != nil {
return err
}
}
// Finally, append any extra opaque data.
return WriteElements(w, a.ExtraOpaqueData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *ChannelUpdate) MsgType() MessageType {
return MsgChannelUpdate
}
// MaxPayloadLength returns the maximum allowed payload size for this message
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *ChannelUpdate) MaxPayloadLength(pver uint32) uint32 {
return 65533
}
// DataToSign is used to retrieve part of the announcement message which should
// be signed.
func (a *ChannelUpdate) DataToSign() ([]byte, error) {
// We should not include the signatures itself.
var w bytes.Buffer
err := WriteElements(&w,
a.ChainHash[:],
a.ShortChannelID,
a.Timestamp,
a.MessageFlags,
a.ChannelFlags,
a.TimeLockDelta,
a.HtlcMinimumMsat,
a.BaseFee,
a.FeeRate,
)
if err != nil {
return nil, err
}
// Now append optional fields if they are set. Currently, the only
// optional field is max HTLC.
if a.MessageFlags.HasMaxHtlc() {
if err := WriteElements(&w, a.HtlcMaximumMsat); err != nil {
return nil, err
}
}
// Finally, append any extra opaque data.
if err := WriteElements(&w, a.ExtraOpaqueData); err != nil {
return nil, err
}
return w.Bytes(), nil
}
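// As an illustrative sketch (not part of the original file), the digest
// covered by the Signature field is the double-SHA256 of DataToSign's
// output; the signer below is hypothetical:
//
//	data, err := upd.DataToSign()
//	if err != nil {
//		return err
//	}
//	digest := chainhash.DoubleHashB(data)
//	sig, err := hypotheticalSigner.SignDigest(digest)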
package lnwire
import (
"io"
"github.com/btcsuite/btcd/btcutil"
)
// ClosingSigned is sent by both parties to a channel once the channel is clear
// of HTLCs, and is primarily concerned with negotiating fees for the close
// transaction. Each party provides a signature for a transaction with a fee
// that they believe is fair. The process terminates when both sides agree on
// the same fee, or when one side force closes the channel.
//
// NOTE: The responder is able to send a signature without any additional
// messages as all transactions are assembled observing BIP 69 which defines a
// canonical ordering for input/outputs. Therefore, both sides are able to
// arrive at an identical closure transaction as they know the order of the
// inputs/outputs.
type ClosingSigned struct {
// ChannelID serves to identify which channel is to be closed.
ChannelID ChannelID
// FeeSatoshis is the total fee in satoshis that the party to the
// channel would like to propose for the close transaction.
FeeSatoshis btcutil.Amount
// Signature is for the proposed channel close transaction.
Signature Sig
}
// NewClosingSigned creates a new empty ClosingSigned message.
func NewClosingSigned(cid ChannelID, fs btcutil.Amount,
sig Sig) *ClosingSigned {
return &ClosingSigned{
ChannelID: cid,
FeeSatoshis: fs,
Signature: sig,
}
}
// A compile time check to ensure ClosingSigned implements the lnwire.Message
// interface.
var _ Message = (*ClosingSigned)(nil)
// Decode deserializes a serialized ClosingSigned message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error {
return ReadElements(r, &c.ChannelID, &c.FeeSatoshis, &c.Signature)
}
// Encode serializes the target ClosingSigned into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *ClosingSigned) Encode(w io.Writer, pver uint32) error {
return WriteElements(w, c.ChannelID, c.FeeSatoshis, c.Signature)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *ClosingSigned) MsgType() MessageType {
return MsgClosingSigned
}
// MaxPayloadLength returns the maximum allowed payload size for a
// ClosingSigned complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ClosingSigned) MaxPayloadLength(uint32) uint32 {
var length uint32
// ChannelID - 32 bytes
length += 32
// FeeSatoshis - 8 bytes
length += 8
// Signature - 64 bytes
length += 64
return length
}
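// As an illustrative sketch (not part of the original file), each round of
// fee negotiation simply re-sends ClosingSigned with a new proposal; the
// values below are hypothetical:
//
//	cs := NewClosingSigned(chanID, btcutil.Amount(500), sig)
//	// If the peer echoes the same FeeSatoshis back, negotiation has
//	// converged and the co-op close transaction can be broadcast.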
package lnwire
import (
"io"
)
// CommitSig is sent by either side to stage any pending HTLC's in the
// receiver's pending set into a new commitment state. Implicitly, the new
// commitment transaction constructed and signed via CommitSig includes all
// HTLC's in the remote node's pending set. A CommitSig message may be sent
// after a series of UpdateAddHTLC/UpdateFulfillHTLC messages in order to
// batch add several HTLC's with a single signature covering all implicitly
// accepted HTLC's.
type CommitSig struct {
// ChanID uniquely identifies to which currently active channel this
// CommitSig applies to.
ChanID ChannelID
// CommitSig is Alice's signature for Bob's new commitment transaction.
// Alice is able to send this signature without requesting any
// additional data due to the piggybacking of Bob's next revocation
// hash in his prior RevokeAndAck message, as well as the canonical
// ordering used for all inputs/outputs within commitment transactions.
// If initiating a new commitment state, this signature should ONLY
// cover all of the sending party's pending log updates, and the log
// updates of the remote party that have been ACK'd.
CommitSig Sig
// HtlcSigs is a signature for each relevant HTLC output within the
// created commitment. The order of the signatures is expected to be
// identical to the placement of the HTLC's within the BIP 69 sorted
// commitment transaction. For each outgoing HTLC (from the PoV of the
// sender of this message), a signature for an HTLC timeout transaction
// should be signed, and for each incoming HTLC the HTLC success
// transaction should be signed.
HtlcSigs []Sig
}
// NewCommitSig creates a new empty CommitSig message.
func NewCommitSig() *CommitSig {
return &CommitSig{}
}
// A compile time check to ensure CommitSig implements the lnwire.Message
// interface.
var _ Message = (*CommitSig)(nil)
// Decode deserializes a serialized CommitSig message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *CommitSig) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.CommitSig,
&c.HtlcSigs,
)
}
// Encode serializes the target CommitSig into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *CommitSig) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.CommitSig,
c.HtlcSigs,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *CommitSig) MsgType() MessageType {
return MsgCommitSig
}
// MaxPayloadLength returns the maximum allowed payload size for a
// CommitSig complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *CommitSig) MaxPayloadLength(uint32) uint32 {
// 32 + 64 + 2 + max_allowed_htlcs
return MaxMessagePayload
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *CommitSig) TargetChanID() ChannelID {
return c.ChanID
}
package lnwire
import (
"bytes"
"fmt"
"io"
"sort"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// MinCustomRecordsTlvType is the minimum custom records TLV type as
// defined in BOLT 01.
MinCustomRecordsTlvType = 65536
)
// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types
// which must be greater than or equal to MinCustomRecordsTlvType.
type CustomRecords map[uint64][]byte
// NewCustomRecords creates a new CustomRecords instance from a
// tlv.TypeMap.
func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) {
// Make comparisons in unit tests easy by returning nil if the map is
// empty.
if len(tlvMap) == 0 {
return nil, nil
}
customRecords := make(CustomRecords, len(tlvMap))
for k, v := range tlvMap {
customRecords[uint64(k)] = v
}
// Validate the custom records.
err := customRecords.Validate()
if err != nil {
return nil, fmt.Errorf("custom records from tlv map "+
"validation error: %w", err)
}
return customRecords, nil
}
// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob.
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {
return ParseCustomRecordsFrom(bytes.NewReader(b))
}
// ParseCustomRecordsFrom creates a new CustomRecords instance from a reader.
func ParseCustomRecordsFrom(r io.Reader) (CustomRecords, error) {
typeMap, err := DecodeRecords(r)
if err != nil {
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
}
return NewCustomRecords(typeMap)
}
// Validate checks that all custom records are in the custom type range.
func (c CustomRecords) Validate() error {
if c == nil {
return nil
}
for key := range c {
if key < MinCustomRecordsTlvType {
return fmt.Errorf("custom records entry with TLV "+
"type below min: %d", MinCustomRecordsTlvType)
}
}
return nil
}
// Copy returns a copy of the custom records.
func (c CustomRecords) Copy() CustomRecords {
if c == nil {
return nil
}
customRecords := make(CustomRecords, len(c))
for k, v := range c {
customRecords[k] = v
}
return customRecords
}
// ExtendRecordProducers extends the given records slice with the custom
// records. The resultant records slice is sorted by TLV type so that the
// canonical ascending ordering is preserved.
func (c CustomRecords) ExtendRecordProducers(
producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) {
// If the custom records are nil or empty, there is nothing to do.
if len(c) == 0 {
return producers, nil
}
// Validate the custom records.
err := c.Validate()
if err != nil {
return nil, err
}
// Ensure that the existing records slice TLV types are not also present
// in the custom records. If they are, the resultant extended records
// slice would erroneously contain duplicate TLV types.
for _, rp := range producers {
record := rp.Record()
recordTlvType := uint64(record.Type())
_, foundDuplicateTlvType := c[recordTlvType]
if foundDuplicateTlvType {
return nil, fmt.Errorf("custom records contains a TLV "+
"type that is already present in the "+
"existing records: %d", recordTlvType)
}
}
// Convert the custom records map to a TLV record producer slice and
// append them to the existing records slice.
customRecordProducers := RecordsAsProducers(tlv.MapToRecords(c))
producers = append(producers, customRecordProducers...)
// Sort the extended records slice to restore canonical ascending TLV
// type ordering, since the appended custom records may otherwise be out
// of order relative to the existing records.
SortProducers(producers)
return producers, nil
}
// RecordProducers returns a slice of record producers for the custom records.
func (c CustomRecords) RecordProducers() []tlv.RecordProducer {
// If the custom records are nil or empty, return an empty slice.
if len(c) == 0 {
return nil
}
// Convert the custom records map to a TLV record producer slice.
records := tlv.MapToRecords(c)
return RecordsAsProducers(records)
}
// Serialize serializes the custom records into a byte slice.
func (c CustomRecords) Serialize() ([]byte, error) {
records := tlv.MapToRecords(c)
return EncodeRecords(records)
}
// SerializeTo serializes the custom records into the given writer.
func (c CustomRecords) SerializeTo(w io.Writer) error {
records := tlv.MapToRecords(c)
return EncodeRecordsTo(w, records)
}
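// As an illustrative sketch (not part of the original file), a custom
// records round trip; the key offset is hypothetical:
//
//	records := CustomRecords{
//		MinCustomRecordsTlvType + 7: []byte{0x01},
//	}
//	blob, err := records.Serialize()
//	if err != nil {
//		return err
//	}
//	parsed, err := ParseCustomRecords(blob)
//	// parsed mirrors records; a key below 65536 would fail Validate.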
// ProduceRecordsSorted converts a slice of record producers into a slice of
// records and then sorts it by type.
func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record {
records := fn.Map(
recordProducers,
func(producer tlv.RecordProducer) tlv.Record {
return producer.Record()
},
)
// Ensure that the set of records is sorted by type before use, so that
// the resulting TLV stream is canonical.
tlv.SortRecords(records)
return records
}
// SortProducers sorts the given record producers by their type.
func SortProducers(producers []tlv.RecordProducer) {
sort.Slice(producers, func(i, j int) bool {
recordI := producers[i].Record()
recordJ := producers[j].Record()
return recordI.Type() < recordJ.Type()
})
}
// TlvMapToRecords converts a TLV map into a slice of records.
func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
tlvMapGeneric := make(map[uint64][]byte)
for k, v := range tlvMap {
tlvMapGeneric[uint64(k)] = v
}
return tlv.MapToRecords(tlvMapGeneric)
}
// RecordsAsProducers converts a slice of records into a slice of record
// producers.
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer {
return fn.Map(records, func(record tlv.Record) tlv.RecordProducer {
return &record
})
}
// EncodeRecords encodes the given records into a byte slice.
func EncodeRecords(records []tlv.Record) ([]byte, error) {
var buf bytes.Buffer
if err := EncodeRecordsTo(&buf, records); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// EncodeRecordsTo encodes the given records into the given writer.
func EncodeRecordsTo(w io.Writer, records []tlv.Record) error {
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
// DecodeRecords decodes data from the given reader into the given records
// and returns the remaining unparsed types as a TLV type map.
func DecodeRecords(r io.Reader,
records ...tlv.Record) (tlv.TypeMap, error) {
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}
return tlvStream.DecodeWithParsedTypes(r)
}
// DecodeRecordsP2P decodes data from the given reader into the given
// records and returns the remaining unparsed types as a TLV type map. This
// function is identical to DecodeRecords except that the record size is
// capped at 65535.
func DecodeRecordsP2P(r *bytes.Reader,
records ...tlv.Record) (tlv.TypeMap, error) {
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}
return tlvStream.DecodeWithParsedTypesP2P(r)
}
// AssertUniqueTypes asserts that the given records have unique types.
func AssertUniqueTypes(r []tlv.Record) error {
seen := make(fn.Set[tlv.Type], len(r))
for _, record := range r {
t := record.Type()
if seen.Contains(t) {
return fmt.Errorf("duplicate record type: %d", t)
}
seen.Add(t)
}
return nil
}
package lnwire
import (
"fmt"
"io"
)
// FundingError represents a set of errors that can be encountered and sent
// during the funding workflow.
type FundingError uint8
const (
// ErrMaxPendingChannels is returned by remote peer when the number of
// active pending channels exceeds their maximum policy limit.
ErrMaxPendingChannels FundingError = 1
// ErrSynchronizingChain is returned by a remote peer that receives a
// channel update or a funding request while it's still syncing to the
// latest state of the blockchain.
ErrSynchronizingChain FundingError = 2
// ErrChanTooLarge is returned by a remote peer that receives a
// FundingOpen request for a channel that is above their current
// soft-limit.
ErrChanTooLarge FundingError = 3
)
// String returns a human readable version of the target FundingError.
func (e FundingError) String() string {
switch e {
case ErrMaxPendingChannels:
return "Number of pending channels exceed maximum"
case ErrSynchronizingChain:
return "Synchronizing blockchain"
case ErrChanTooLarge:
return "channel too large"
default:
return "unknown error"
}
}
// Error returns the human readable version of the target FundingError.
//
// NOTE: Satisfies the Error interface.
func (e FundingError) Error() string {
return e.String()
}
// ErrorData is a set of bytes associated with a particular sent error. A
// receiving node SHOULD only print out data verbatim if the string is
// composed solely of printable ASCII characters. For reference, the
// printable character set includes byte values 32 through 126 inclusive.
type ErrorData []byte
// Error represents a generic error bound to an exact channel. The message
// format is purposefully general in order to allow expression of a wide array
// of possible errors. Each Error message is directed at a particular open
// channel referenced by ChanID.
type Error struct {
// ChanID references the active channel in which the error occurred
// within. If the ChanID is all zeros, then this error applies to the
// entire established connection.
ChanID ChannelID
// Data is the attached error data that describes the exact failure
// which caused the error message to be sent.
Data ErrorData
}
// NewError creates a new Error message.
func NewError() *Error {
return &Error{}
}
// A compile time check to ensure Error implements the lnwire.Message
// interface.
var _ Message = (*Error)(nil)
// Error returns the string representation of the Error.
//
// NOTE: Satisfies the error interface.
func (c *Error) Error() string {
errMsg := "non-ascii data"
if isASCII(c.Data) {
errMsg = string(c.Data)
}
return fmt.Sprintf("chan_id=%v, err=%v", c.ChanID, errMsg)
}
// Decode deserializes a serialized Error message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *Error) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.Data,
)
}
// Encode serializes the target Error into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *Error) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.Data,
)
}
// MsgType returns the integer uniquely identifying an Error message on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *Error) MsgType() MessageType {
return MsgError
}
// MaxPayloadLength returns the maximum allowed payload size for an Error
// complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *Error) MaxPayloadLength(uint32) uint32 {
// 32 + 2 + 65501
return MaxMessagePayload
}
// isASCII is a helper method that checks whether all bytes in `data` would be
// printable ASCII characters if interpreted as a string.
func isASCII(data []byte) bool {
for _, c := range data {
if c < 32 || c > 126 {
return false
}
}
return true
}
package lnwire
import (
"encoding/binary"
"errors"
"io"
)
var (
// ErrFeaturePairExists signals an error in feature vector construction
// where the opposing bit in a feature pair has already been set.
ErrFeaturePairExists = errors.New("feature pair exists")
)
// FeatureBit represents a feature that can be enabled in either a local or
// global feature vector at a specific bit position. Feature bits follow the
// "it's OK to be odd" rule, where features at even bit positions must be known
// to a node receiving them from a peer while odd bits do not. In accordance,
// feature bits are usually assigned in pairs, first being assigned an odd bit
// position which may later be changed to the preceding even position once
// knowledge of the feature becomes required on the network.
type FeatureBit uint16
const (
// DataLossProtectRequired is a feature bit that indicates that a peer
// *requires* the other party know about the data-loss-protect optional
// feature. If the remote peer does not know of such a feature, then
// the sending peer SHOULD disconnect them. The data-loss-protect
// feature allows a peer that's lost partial data to recover their
// settled funds of the latest commitment state.
DataLossProtectRequired FeatureBit = 0
// DataLossProtectOptional is an optional feature bit that indicates
// that the sending peer knows of this new feature and can activate it.
// The data-loss-protect feature allows a peer that's lost partial data
// to recover their settled funds of the latest commitment state.
DataLossProtectOptional FeatureBit = 1
// InitialRoutingSync is a local feature bit meaning that the receiving
// node should send a complete dump of routing information when a new
// connection is established.
InitialRoutingSync FeatureBit = 3
// UpfrontShutdownScriptRequired is a feature bit which indicates that a
// peer *requires* that the remote peer accept an upfront shutdown script to
// which payout is enforced on cooperative closes.
UpfrontShutdownScriptRequired FeatureBit = 4
// UpfrontShutdownScriptOptional is an optional feature bit which indicates
// that the peer will accept an upfront shutdown script to which payout is
// enforced on cooperative closes.
UpfrontShutdownScriptOptional FeatureBit = 5
// GossipQueriesRequired is a feature bit that indicates that the
// receiving peer MUST know of the set of features that allows nodes to
// more efficiently query the network view of peers on the network for
// reconciliation purposes.
GossipQueriesRequired FeatureBit = 6
// GossipQueriesOptional is an optional feature bit that signals that
// the setting peer knows of the set of features that allows more
// efficient network view reconciliation.
GossipQueriesOptional FeatureBit = 7
// TLVOnionPayloadRequired is a feature bit that indicates a node is
// able to decode the new TLV information included in the onion packet.
TLVOnionPayloadRequired FeatureBit = 8
// TLVOnionPayloadOptional is an optional feature bit that indicates a
// node is able to decode the new TLV information included in the onion
// packet.
TLVOnionPayloadOptional FeatureBit = 9
// StaticRemoteKeyRequired is a required feature bit that signals that
// within one's commitment transaction, the key used for the remote
// party's non-delay output should not be tweaked.
StaticRemoteKeyRequired FeatureBit = 12
// StaticRemoteKeyOptional is an optional feature bit that signals that
// within one's commitment transaction, the key used for the remote
// party's non-delay output should not be tweaked.
StaticRemoteKeyOptional FeatureBit = 13
// PaymentAddrRequired is a required feature bit that signals that a
// node requires payment addresses, which are used to mitigate probing
// attacks on the receiver of a payment.
PaymentAddrRequired FeatureBit = 14
// PaymentAddrOptional is an optional feature bit that signals that a
// node supports payment addresses, which are used to mitigate probing
// attacks on the receiver of a payment.
PaymentAddrOptional FeatureBit = 15
// MPPRequired is a required feature bit that signals that the receiver
// of a payment requires settlement of an invoice with more than one
// HTLC.
MPPRequired FeatureBit = 16
// MPPOptional is an optional feature bit that signals that the receiver
// of a payment supports settlement of an invoice with more than one
// HTLC.
MPPOptional FeatureBit = 17
// WumboChannelsRequired is a required feature bit that signals that a
// node is willing to accept channels larger than 2^24 satoshis.
WumboChannelsRequired FeatureBit = 18
// WumboChannelsOptional is an optional feature bit that signals that a
// node is willing to accept channels larger than 2^24 satoshis.
WumboChannelsOptional FeatureBit = 19
// AnchorsRequired is a required feature bit that signals that the node
// requires channels to be made using commitments having anchor
// outputs.
AnchorsRequired FeatureBit = 20
// AnchorsOptional is an optional feature bit that signals that the
// node supports channels to be made using commitments having anchor
// outputs.
AnchorsOptional FeatureBit = 21
// AnchorsZeroFeeHtlcTxRequired is a required feature bit that signals
// that the node requires channels having zero-fee second-level HTLC
// transactions, which also imply anchor commitments.
AnchorsZeroFeeHtlcTxRequired FeatureBit = 22
// AnchorsZeroFeeHtlcTxOptional is an optional feature bit that signals
// that the node supports channels having zero-fee second-level HTLC
// transactions, which also imply anchor commitments.
AnchorsZeroFeeHtlcTxOptional FeatureBit = 23
// maxAllowedSize is a maximum allowed size of feature vector.
//
// NOTE: Within the protocol, the maximum allowed message size is 65535
// bytes for all messages. Accounting for the overhead within the feature
// message to signal the type of message, that leaves us with 65533 bytes
// for the init message itself. Next, we reserve 4 bytes to encode the
// lengths of both the local and global feature vectors, so 65529 bytes
// for the local and global features. Knocking off one byte for the sake
// of the calculation, that leads us to 32764 bytes for each feature
// vector, or 131056 different features.
maxAllowedSize = 32764
)
// IsRequired returns true if the feature bit is even, and false otherwise.
func (b FeatureBit) IsRequired() bool {
return b&0x01 == 0x00
}
// Features is a mapping of known feature bits to a descriptive name. All known
// feature bits must be assigned a name in this mapping, and feature bit pairs
// must be assigned together for correct behavior.
var Features = map[FeatureBit]string{
DataLossProtectRequired: "data-loss-protect",
DataLossProtectOptional: "data-loss-protect",
InitialRoutingSync: "initial-routing-sync",
UpfrontShutdownScriptRequired: "upfront-shutdown-script",
UpfrontShutdownScriptOptional: "upfront-shutdown-script",
GossipQueriesRequired: "gossip-queries",
GossipQueriesOptional: "gossip-queries",
TLVOnionPayloadRequired: "tlv-onion",
TLVOnionPayloadOptional: "tlv-onion",
StaticRemoteKeyOptional: "static-remote-key",
StaticRemoteKeyRequired: "static-remote-key",
PaymentAddrOptional: "payment-addr",
PaymentAddrRequired: "payment-addr",
MPPOptional: "multi-path-payments",
MPPRequired: "multi-path-payments",
AnchorsRequired: "anchor-commitments",
AnchorsOptional: "anchor-commitments",
AnchorsZeroFeeHtlcTxRequired: "anchors-zero-fee-htlc-tx",
AnchorsZeroFeeHtlcTxOptional: "anchors-zero-fee-htlc-tx",
WumboChannelsRequired: "wumbo-channels",
WumboChannelsOptional: "wumbo-channels",
}
// RawFeatureVector represents a set of feature bits as defined in BOLT-09. A
// RawFeatureVector itself just stores a set of bit flags but can be used to
// construct a FeatureVector which binds meaning to each bit. Feature vectors
// can be serialized and deserialized to/from a byte representation that is
// transmitted in Lightning network messages.
type RawFeatureVector struct {
features map[FeatureBit]bool
}
// NewRawFeatureVector creates a feature vector with all of the feature bits
// given as arguments enabled.
func NewRawFeatureVector(bits ...FeatureBit) *RawFeatureVector {
fv := &RawFeatureVector{features: make(map[FeatureBit]bool)}
for _, bit := range bits {
fv.Set(bit)
}
return fv
}
// Merge sets all feature bits in other on the receiver's feature vector.
func (fv *RawFeatureVector) Merge(other *RawFeatureVector) error {
for bit := range other.features {
err := fv.SafeSet(bit)
if err != nil {
return err
}
}
return nil
}
// Clone makes a copy of a feature vector.
func (fv *RawFeatureVector) Clone() *RawFeatureVector {
newFeatures := NewRawFeatureVector()
for bit := range fv.features {
newFeatures.Set(bit)
}
return newFeatures
}
// IsSet returns whether a particular feature bit is enabled in the vector.
func (fv *RawFeatureVector) IsSet(feature FeatureBit) bool {
return fv.features[feature]
}
// Set marks a feature as enabled in the vector.
func (fv *RawFeatureVector) Set(feature FeatureBit) {
fv.features[feature] = true
}
// SafeSet sets the chosen feature bit in the feature vector, but returns an
// error if the opposing feature bit is already set. This ensures both that we
// are creating properly structured feature vectors, and in some cases, that
// peers are sending properly encoded ones, i.e. it can't be both optional and
// required.
func (fv *RawFeatureVector) SafeSet(feature FeatureBit) error {
if _, ok := fv.features[feature^1]; ok {
return ErrFeaturePairExists
}
fv.Set(feature)
return nil
}
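// As an illustrative sketch (not part of the original file), SafeSet
// refuses to set both halves of a feature pair:
//
//	fv := NewRawFeatureVector()
//	_ = fv.SafeSet(DataLossProtectOptional)    // bit 1: ok
//	err := fv.SafeSet(DataLossProtectRequired) // bit 0 == 1^1
//	// err == ErrFeaturePairExists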
// Unset marks a feature as disabled in the vector.
func (fv *RawFeatureVector) Unset(feature FeatureBit) {
delete(fv.features, feature)
}
// SerializeSize returns the number of bytes needed to represent the feature
// vector in byte format.
func (fv *RawFeatureVector) SerializeSize() int {
// We calculate byte-length via the largest bit index.
return fv.serializeSize(8)
}
// SerializeSize32 returns the number of bytes needed to represent the
// feature vector in base32 format.
func (fv *RawFeatureVector) SerializeSize32() int {
// We calculate base32-length via the largest bit index.
return fv.serializeSize(5)
}
// serializeSize returns the number of bytes required to encode the feature
// vector using at most width bits per encoded byte.
func (fv *RawFeatureVector) serializeSize(width int) int {
// Find the largest feature bit index
max := -1
for feature := range fv.features {
index := int(feature)
if index > max {
max = index
}
}
if max == -1 {
return 0
}
return max/width + 1
}
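// For instance (an illustrative note, not part of the original file), a
// vector whose highest set bit is 13 occupies 13/8+1 == 2 bytes in byte
// format but 13/5+1 == 3 bytes when packed 5 bits per byte for base32:
//
//	fv := NewRawFeatureVector(StaticRemoteKeyOptional) // bit 13
//	_ = fv.SerializeSize()   // 2
//	_ = fv.SerializeSize32() // 3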
// Encode writes the feature vector in byte representation. Every feature
// is encoded as a bit, and the bit vector is serialized using the least
// number of bytes. Since the bit vector length is variable, the first two
// bytes of the serialization represent the length.
func (fv *RawFeatureVector) Encode(w io.Writer) error {
// Write length of feature vector.
var l [2]byte
length := fv.SerializeSize()
binary.BigEndian.PutUint16(l[:], uint16(length))
if _, err := w.Write(l[:]); err != nil {
return err
}
return fv.encode(w, length, 8)
}
// EncodeBase256 writes the feature vector in base256 representation. Every
// feature is encoded as a bit, and the bit vector is serialized using the least
// number of bytes.
func (fv *RawFeatureVector) EncodeBase256(w io.Writer) error {
length := fv.SerializeSize()
return fv.encode(w, length, 8)
}
// EncodeBase32 writes the feature vector in base32 representation. Every feature
// is encoded as a bit, and the bit vector is serialized using the least number of
// bytes.
func (fv *RawFeatureVector) EncodeBase32(w io.Writer) error {
length := fv.SerializeSize32()
return fv.encode(w, length, 5)
}
// encode writes the feature vector to w, packing width feature bits into
// each of the length output bytes.
func (fv *RawFeatureVector) encode(w io.Writer, length, width int) error {
// Generate the data and write it.
data := make([]byte, length)
for feature := range fv.features {
byteIndex := int(feature) / width
bitIndex := int(feature) % width
data[length-byteIndex-1] |= 1 << uint(bitIndex)
}
_, err := w.Write(data)
return err
}
// Decode reads the feature vector from its byte representation. Every feature
// is encoded as a bit, and the bit vector is serialized using the least number
// of bytes. Since the bit vector length is variable, the first two bytes of the
// serialization represent the length.
func (fv *RawFeatureVector) Decode(r io.Reader) error {
// Read the length of the feature vector.
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
length := binary.BigEndian.Uint16(l[:])
return fv.decode(r, int(length), 8)
}
// DecodeBase256 reads the feature vector from its base256 representation.
// Every feature is encoded as a bit, and the bit vector is serialized using
// the least number of bytes.
func (fv *RawFeatureVector) DecodeBase256(r io.Reader, length int) error {
return fv.decode(r, length, 8)
}
// DecodeBase32 reads the feature vector from its base32 representation.
// Every feature is encoded as a bit, and the bit vector is serialized using
// the least number of bytes.
func (fv *RawFeatureVector) DecodeBase32(r io.Reader, length int) error {
return fv.decode(r, length, 5)
}
// decode reads a feature vector from the next length bytes of the io.Reader,
// assuming width feature bits are encoded in each byte.
func (fv *RawFeatureVector) decode(r io.Reader, length, width int) error {
// Read the feature vector data.
data := make([]byte, length)
if _, err := io.ReadFull(r, data); err != nil {
return err
}
// Set feature bits from parsed data.
bitsNumber := len(data) * width
for i := 0; i < bitsNumber; i++ {
byteIndex := int(i / width)
bitIndex := uint(i % width)
if (data[length-byteIndex-1]>>bitIndex)&1 == 1 {
fv.Set(FeatureBit(i))
}
}
return nil
}
// FeatureVector represents a set of enabled features. The set stores
// information on enabled flags and metadata about the feature names. A feature
// vector is serializable to a compact byte representation that is included in
// Lightning network messages.
type FeatureVector struct {
*RawFeatureVector
featureNames map[FeatureBit]string
}
// NewFeatureVector constructs a new FeatureVector from a raw feature vector
// and mapping of feature definitions. If the feature vector argument is nil, a
// new one will be constructed with no enabled features.
func NewFeatureVector(featureVector *RawFeatureVector,
featureNames map[FeatureBit]string) *FeatureVector {
if featureVector == nil {
featureVector = NewRawFeatureVector()
}
return &FeatureVector{
RawFeatureVector: featureVector,
featureNames: featureNames,
}
}
// EmptyFeatureVector returns a feature vector with no bits set.
func EmptyFeatureVector() *FeatureVector {
return NewFeatureVector(nil, Features)
}
// HasFeature returns whether a particular feature is included in the set. The
// feature can be seen as set either if the bit is set directly OR the queried
// bit has the same meaning as its corresponding even/odd bit, which is set
// instead. The second case is because feature bits are generally assigned in
// pairs where both the even and odd position represent the same feature.
func (fv *FeatureVector) HasFeature(feature FeatureBit) bool {
return fv.IsSet(feature) ||
(fv.isFeatureBitPair(feature) && fv.IsSet(feature^1))
}
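// As an illustrative sketch (not part of the original file), the pairing
// rule in action for the static-remote-key pair (bits 12/13):
//
//	fv := NewFeatureVector(
//		NewRawFeatureVector(StaticRemoteKeyOptional), Features,
//	)
//	_ = fv.HasFeature(StaticRemoteKeyRequired)      // true via the pair
//	_ = fv.RequiresFeature(StaticRemoteKeyRequired) // false: bit 12 unset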
// RequiresFeature returns true if the referenced feature vector *requires*
// that the given required bit be set. This method can be used with both
// optional and required feature bits as a parameter.
func (fv *FeatureVector) RequiresFeature(feature FeatureBit) bool {
// If we weren't passed a required feature bit, then we'll flip the
// lowest bit to query for the required version of the feature. This
// lets callers pass in both the optional and required bits.
if !feature.IsRequired() {
feature ^= 1
}
return fv.IsSet(feature)
}
// UnknownRequiredFeatures returns a list of feature bits set in the vector
// that are unknown and in an even bit position. Feature bits with an even
// index must be known to a node receiving the feature vector in a message.
func (fv *FeatureVector) UnknownRequiredFeatures() []FeatureBit {
var unknown []FeatureBit
for feature := range fv.features {
if feature%2 == 0 && !fv.IsKnown(feature) {
unknown = append(unknown, feature)
}
}
return unknown
}
// Name returns a string identifier for the feature represented by this bit. If
// the bit does not represent a known feature, this returns a string indicating
// as such.
func (fv *FeatureVector) Name(bit FeatureBit) string {
name, known := fv.featureNames[bit]
if !known {
return "unknown"
}
return name
}
// IsKnown returns whether this feature bit represents a known feature.
func (fv *FeatureVector) IsKnown(bit FeatureBit) bool {
_, known := fv.featureNames[bit]
return known
}
// isFeatureBitPair returns whether this feature bit and its corresponding
// even/odd bit both represent the same feature. This may often be the case as
// bits are generally assigned in pairs, first being assigned an odd bit
// position then being promoted to an even bit position once the network is
// ready.
func (fv *FeatureVector) isFeatureBitPair(bit FeatureBit) bool {
name1, known1 := fv.featureNames[bit]
name2, known2 := fv.featureNames[bit^1]
return known1 && known2 && name1 == name2
}
// Features returns the set of raw features contained in the feature vector.
func (fv *FeatureVector) Features() map[FeatureBit]struct{} {
fs := make(map[FeatureBit]struct{}, len(fv.RawFeatureVector.features))
for b := range fv.RawFeatureVector.features {
fs[b] = struct{}{}
}
return fs
}
// Clone copies a feature vector, carrying over its feature bits. The feature
// names are not copied.
func (fv *FeatureVector) Clone() *FeatureVector {
features := fv.RawFeatureVector.Clone()
return NewFeatureVector(features, fv.featureNames)
}
package lnwire
import (
"io"
"github.com/btcsuite/btcd/wire"
)
// FundingCreated is sent from Alice (the initiator) to Bob (the responder),
// once Alice receives Bob's contributions as well as his channel constraints.
// Once Bob receives this message, he'll gain access to an immediately
// broadcastable commitment transaction and will reply with a signature for
// Alice's version of the commitment transaction.
type FundingCreated struct {
// PendingChannelID serves to uniquely identify the future channel
// created by the initiated single funder workflow.
PendingChannelID [32]byte
// FundingPoint is the outpoint of the funding transaction created by
// Alice. With this, Bob is able to generate both his version and
// Alice's version of the commitment transaction.
FundingPoint wire.OutPoint
// CommitSig is Alice's signature from Bob's version of the commitment
// transaction.
CommitSig Sig
}
// A compile time check to ensure FundingCreated implements the lnwire.Message
// interface.
var _ Message = (*FundingCreated)(nil)
// Encode serializes the target FundingCreated into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
//
// This is part of the lnwire.Message interface.
func (f *FundingCreated) Encode(w io.Writer, pver uint32) error {
return WriteElements(w, f.PendingChannelID[:], f.FundingPoint, f.CommitSig)
}
// Decode deserializes the serialized FundingCreated stored in the passed
// io.Reader into the target FundingCreated using the deserialization rules
// defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (f *FundingCreated) Decode(r io.Reader, pver uint32) error {
return ReadElements(r, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig)
}
// MsgType returns the uint32 code which uniquely identifies this message as a
// FundingCreated on the wire.
//
// This is part of the lnwire.Message interface.
func (f *FundingCreated) MsgType() MessageType {
return MsgFundingCreated
}
// MaxPayloadLength returns the maximum allowed payload length for a
// FundingCreated message.
//
// This is part of the lnwire.Message interface.
func (f *FundingCreated) MaxPayloadLength(uint32) uint32 {
// 32 + 32 + 2 + 64
return 130
}
package lnwire
import (
"io"
"github.com/btcsuite/btcd/btcec/v2"
)
// FundingLocked is the message that both parties to a new channel creation
// send once they have observed the funding transaction being confirmed on the
// blockchain. FundingLocked contains the next per-commitment point, which is
// needed before the channel participants can begin operating the channel and
// advertise its existence to the rest of the network.
type FundingLocked struct {
// ChanID is the outpoint of the channel's funding transaction. This
// can be used to query for the channel in the database.
ChanID ChannelID
// NextPerCommitmentPoint is the commitment point to be used for the
// channel's next commitment transaction; its corresponding secret is
// what later revokes that state.
NextPerCommitmentPoint *btcec.PublicKey
}
// NewFundingLocked creates a new FundingLocked message, populating it with
// the necessary IDs and the next per-commitment point.
func NewFundingLocked(cid ChannelID, npcp *btcec.PublicKey) *FundingLocked {
return &FundingLocked{
ChanID: cid,
NextPerCommitmentPoint: npcp,
}
}
// A compile time check to ensure FundingLocked implements the lnwire.Message
// interface.
var _ Message = (*FundingLocked)(nil)
// Decode deserializes the serialized FundingLocked message stored in the
// passed io.Reader into the target FundingLocked using the deserialization
// rules defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (c *FundingLocked) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.NextPerCommitmentPoint)
}
// Encode serializes the target FundingLocked message into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
//
// This is part of the lnwire.Message interface.
func (c *FundingLocked) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.NextPerCommitmentPoint)
}
// MsgType returns the uint32 code which uniquely identifies this message as a
// FundingLocked message on the wire.
//
// This is part of the lnwire.Message interface.
func (c *FundingLocked) MsgType() MessageType {
return MsgFundingLocked
}
// MaxPayloadLength returns the maximum allowed payload length for a
// FundingLocked message. This is calculated by summing the max length of all
// the fields within a FundingLocked message.
//
// This is part of the lnwire.Message interface.
func (c *FundingLocked) MaxPayloadLength(uint32) uint32 {
var length uint32
// ChanID - 32 bytes
length += 32
// NextPerCommitmentPoint - 33 bytes
length += 33
// 65 bytes
return length
}
package lnwire
import "io"
// FundingSigned is sent from Bob (the responder) to Alice (the initiator)
// after receiving the funding outpoint and her signature for Bob's version of
// the commitment transaction.
type FundingSigned struct {
// ChanID is the particular active channel that this FundingSigned is
// bound to.
ChanID ChannelID
// CommitSig is Bob's signature for Alice's version of the commitment
// transaction.
CommitSig Sig
}
// A compile time check to ensure FundingSigned implements the lnwire.Message
// interface.
var _ Message = (*FundingSigned)(nil)
// Encode serializes the target FundingSigned into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
//
// This is part of the lnwire.Message interface.
func (f *FundingSigned) Encode(w io.Writer, pver uint32) error {
return WriteElements(w, f.ChanID, f.CommitSig)
}
// Decode deserializes the serialized FundingSigned stored in the passed
// io.Reader into the target FundingSigned using the deserialization rules
// defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (f *FundingSigned) Decode(r io.Reader, pver uint32) error {
return ReadElements(r, &f.ChanID, &f.CommitSig)
}
// MsgType returns the uint32 code which uniquely identifies this message as a
// FundingSigned on the wire.
//
// This is part of the lnwire.Message interface.
func (f *FundingSigned) MsgType() MessageType {
return MsgFundingSigned
}
// MaxPayloadLength returns the maximum allowed payload length for a
// FundingSigned message.
//
// This is part of the lnwire.Message interface.
func (f *FundingSigned) MaxPayloadLength(uint32) uint32 {
// 32 + 64
return 96
}
package lnwire
import (
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// GossipTimestampRange is a message that allows the sender to restrict the set
// of future gossip announcements sent by the receiver. Nodes should send this
// if they have the gossip-queries feature bit active. Nodes are able to send
// new GossipTimestampRange messages to replace the prior window.
type GossipTimestampRange struct {
// ChainHash denotes the chain for which the sender wishes to restrict
// the set of received announcements.
ChainHash chainhash.Hash
// FirstTimestamp is the timestamp of the earliest announcement message
// that should be sent by the receiver.
FirstTimestamp uint32
// TimestampRange is the horizon beyond the FirstTimestamp that any
// announcement messages should be sent for. The receiving node MUST
// NOT send any announcements that have a timestamp greater than
// FirstTimestamp + TimestampRange.
TimestampRange uint32
}
// NewGossipTimestampRange creates a new empty GossipTimestampRange message.
func NewGossipTimestampRange() *GossipTimestampRange {
return &GossipTimestampRange{}
}
// A compile time check to ensure GossipTimestampRange implements the
// lnwire.Message interface.
var _ Message = (*GossipTimestampRange)(nil)
// Decode deserializes a serialized GossipTimestampRange message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
g.ChainHash[:],
&g.FirstTimestamp,
&g.TimestampRange,
)
}
// Encode serializes the target GossipTimestampRange into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (g *GossipTimestampRange) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
g.ChainHash[:],
g.FirstTimestamp,
g.TimestampRange,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (g *GossipTimestampRange) MsgType() MessageType {
return MsgGossipTimestampRange
}
// MaxPayloadLength returns the maximum allowed payload size for a
// GossipTimestampRange complete message observing the specified protocol
// version.
//
// This is part of the lnwire.Message interface.
func (g *GossipTimestampRange) MaxPayloadLength(uint32) uint32 {
// 32 + 4 + 4
//
// TODO(roasbeef): update to 8 byte timestamps?
return 40
}
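// As an illustrative sketch (not part of the original file), restricting
// gossip to a hypothetical one-day window:
//
//	gtr := NewGossipTimestampRange()
//	gtr.FirstTimestamp = uint32(time.Now().Add(-24 * time.Hour).Unix())
//	gtr.TimestampRange = uint32((24 * time.Hour).Seconds())
//	// The receiver must not send announcements timestamped outside
//	// [FirstTimestamp, FirstTimestamp+TimestampRange].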
package lnwire
import "io"
// Init is the first message sent between peers and reveals the features
// supported or required by this node. Nodes wait for receipt of the other's
// features to simplify error diagnosis where features are incompatible. Each
// node MUST wait to receive init before sending any other messages.
type Init struct {
// GlobalFeatures is a legacy feature vector used for backwards
// compatibility with older nodes. Any features defined here should be
// merged with those presented in Features.
GlobalFeatures *RawFeatureVector
// Features is a feature vector containing the features supported by
// the remote node.
//
// NOTE: Older nodes may place some features in GlobalFeatures, but all
// new features are to be added in Features. When handling an Init
// message, any GlobalFeatures should be merged into the unified
// Features field.
Features *RawFeatureVector
}
// NewInitMessage creates a new instance of the init message object.
func NewInitMessage(gf *RawFeatureVector, f *RawFeatureVector) *Init {
return &Init{
GlobalFeatures: gf,
Features: f,
}
}
// A compile time check to ensure Init implements the lnwire.Message
// interface.
var _ Message = (*Init)(nil)
// Decode deserializes a serialized Init message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (msg *Init) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&msg.GlobalFeatures,
&msg.Features,
)
}
// Encode serializes the target Init into the passed io.Writer observing
// the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (msg *Init) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
msg.GlobalFeatures,
msg.Features,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (msg *Init) MsgType() MessageType {
return MsgInit
}
// MaxPayloadLength returns the maximum allowed payload size for an Init
// complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (msg *Init) MaxPayloadLength(uint32) uint32 {
return 2 + 2 + maxAllowedSize + 2 + maxAllowedSize
}
package lnwire
import (
"bytes"
"encoding/binary"
"fmt"
"image/color"
"io"
"math"
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/tor"
)
// MaxSliceLength is the maximum allowed length for any opaque byte slices in
// the wire protocol.
const MaxSliceLength = 65535
// PkScript is simple type definition which represents a raw serialized public
// key script.
type PkScript []byte
// addressType specifies the network protocol and version that should be used
// when connecting to a node at a particular address.
type addressType uint8
const (
// noAddr denotes a blank address. An address of this type indicates
// that a node doesn't have any advertised addresses.
noAddr addressType = 0
// tcp4Addr denotes an IPv4 TCP address.
tcp4Addr addressType = 1
// tcp6Addr denotes an IPv6 TCP address.
tcp6Addr addressType = 2
// v2OnionAddr denotes a version 2 Tor onion service address.
v2OnionAddr addressType = 3
// v3OnionAddr denotes a version 3 Tor (prop224) onion service address.
v3OnionAddr addressType = 4
)
// AddrLen returns the number of bytes that it takes to encode the target
// address.
func (a addressType) AddrLen() uint16 {
switch a {
case noAddr:
return 0
case tcp4Addr:
return 6
case tcp6Addr:
return 18
case v2OnionAddr:
return 12
case v3OnionAddr:
return 37
default:
return 0
}
}
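// A small added arithmetic helper (illustrative only, not original code):
// AddrLen covers the address body alone, so the full on-wire footprint also
// includes the one-byte type descriptor written before it. A tcp4 address
// therefore occupies 1 + 6 = 7 bytes (descriptor, 4 IP bytes, 2-byte port).
func addrWireSizeSketch(a addressType) uint16 {
	return 1 + a.AddrLen()
}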
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for the wire protocol. The passed
// io.Writer should be backed by an appropriately sized byte slice, or be able
// to dynamically expand to accommodate additional data.
//
// TODO(roasbeef): this should eventually draw from a buffer pool for
// serialization.
func WriteElement(w io.Writer, element interface{}) error {
switch e := element.(type) {
case NodeAlias:
if _, err := w.Write(e[:]); err != nil {
return err
}
case ShortChanIDEncoding:
var b [1]byte
b[0] = uint8(e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint8:
var b [1]byte
b[0] = e
if _, err := w.Write(b[:]); err != nil {
return err
}
case FundingFlag:
var b [1]byte
b[0] = uint8(e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint16:
var b [2]byte
binary.BigEndian.PutUint16(b[:], e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case ChanUpdateMsgFlags:
var b [1]byte
b[0] = uint8(e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case ChanUpdateChanFlags:
var b [1]byte
b[0] = uint8(e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case MilliSatoshi:
var b [8]byte
binary.BigEndian.PutUint64(b[:], uint64(e))
if _, err := w.Write(b[:]); err != nil {
return err
}
case btcutil.Amount:
var b [8]byte
binary.BigEndian.PutUint64(b[:], uint64(e))
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint32:
var b [4]byte
binary.BigEndian.PutUint32(b[:], e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint64:
var b [8]byte
binary.BigEndian.PutUint64(b[:], e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case *btcec.PublicKey:
if e == nil {
return fmt.Errorf("cannot write nil pubkey")
}
var b [33]byte
serializedPubkey := e.SerializeCompressed()
copy(b[:], serializedPubkey)
if _, err := w.Write(b[:]); err != nil {
return err
}
case []Sig:
var b [2]byte
numSigs := uint16(len(e))
binary.BigEndian.PutUint16(b[:], numSigs)
if _, err := w.Write(b[:]); err != nil {
return err
}
for _, sig := range e {
if err := WriteElement(w, sig); err != nil {
return err
}
}
case Sig:
// Write buffer
if _, err := w.Write(e[:]); err != nil {
return err
}
case PingPayload:
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(e)))
if _, err := w.Write(l[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case PongPayload:
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(e)))
if _, err := w.Write(l[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case ErrorData:
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(e)))
if _, err := w.Write(l[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case OpaqueReason:
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(e)))
if _, err := w.Write(l[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case [33]byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case []byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case PkScript:
// The largest script we'll accept is a p2wsh which is exactly
// 34 bytes long.
scriptLength := len(e)
if scriptLength > 34 {
return fmt.Errorf("'PkScript' too long")
}
if err := wire.WriteVarBytes(w, 0, e); err != nil {
return err
}
case *RawFeatureVector:
if e == nil {
return fmt.Errorf("cannot write nil feature vector")
}
if err := e.Encode(w); err != nil {
return err
}
case wire.OutPoint:
var h [32]byte
copy(h[:], e.Hash[:])
if _, err := w.Write(h[:]); err != nil {
return err
}
if e.Index > math.MaxUint16 {
return fmt.Errorf("index for outpoint (%v) is "+
"greater than max index of %v", e.Index,
math.MaxUint16)
}
var idx [2]byte
binary.BigEndian.PutUint16(idx[:], uint16(e.Index))
if _, err := w.Write(idx[:]); err != nil {
return err
}
case ChannelID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case FailCode:
if err := WriteElement(w, uint16(e)); err != nil {
return err
}
case ShortChannelID:
// Check that the field fits in 3 bytes and write the blockHeight
if e.BlockHeight > ((1 << 24) - 1) {
return errors.New("block height should fit in 3 bytes")
}
var blockHeight [4]byte
binary.BigEndian.PutUint32(blockHeight[:], e.BlockHeight)
if _, err := w.Write(blockHeight[1:]); err != nil {
return err
}
// Check that the field fits in 3 bytes and write the txIndex
if e.TxIndex > ((1 << 24) - 1) {
return errors.New("tx index should fit in 3 bytes")
}
var txIndex [4]byte
binary.BigEndian.PutUint32(txIndex[:], e.TxIndex)
if _, err := w.Write(txIndex[1:]); err != nil {
return err
}
// Write the txPosition
var txPosition [2]byte
binary.BigEndian.PutUint16(txPosition[:], e.TxPosition)
if _, err := w.Write(txPosition[:]); err != nil {
return err
}
case *net.TCPAddr:
if e == nil {
return fmt.Errorf("cannot write nil TCPAddr")
}
if e.IP.To4() != nil {
var descriptor [1]byte
descriptor[0] = uint8(tcp4Addr)
if _, err := w.Write(descriptor[:]); err != nil {
return err
}
var ip [4]byte
copy(ip[:], e.IP.To4())
if _, err := w.Write(ip[:]); err != nil {
return err
}
} else {
var descriptor [1]byte
descriptor[0] = uint8(tcp6Addr)
if _, err := w.Write(descriptor[:]); err != nil {
return err
}
var ip [16]byte
copy(ip[:], e.IP.To16())
if _, err := w.Write(ip[:]); err != nil {
return err
}
}
var port [2]byte
binary.BigEndian.PutUint16(port[:], uint16(e.Port))
if _, err := w.Write(port[:]); err != nil {
return err
}
case *tor.OnionAddr:
if e == nil {
return errors.New("cannot write nil onion address")
}
var suffixIndex int
switch len(e.OnionService) {
case tor.V2Len:
descriptor := []byte{byte(v2OnionAddr)}
if _, err := w.Write(descriptor); err != nil {
return err
}
suffixIndex = tor.V2Len - tor.OnionSuffixLen
case tor.V3Len:
descriptor := []byte{byte(v3OnionAddr)}
if _, err := w.Write(descriptor); err != nil {
return err
}
suffixIndex = tor.V3Len - tor.OnionSuffixLen
default:
return errors.New("unknown onion service length")
}
host, err := tor.Base32Encoding.DecodeString(
e.OnionService[:suffixIndex],
)
if err != nil {
return err
}
if _, err := w.Write(host); err != nil {
return err
}
var port [2]byte
binary.BigEndian.PutUint16(port[:], uint16(e.Port))
if _, err := w.Write(port[:]); err != nil {
return err
}
case []net.Addr:
// First, we'll encode all the addresses into an intermediate
// buffer. We need to do this in order to compute the total
// length of the addresses.
var addrBuf bytes.Buffer
for _, address := range e {
if err := WriteElement(&addrBuf, address); err != nil {
return err
}
}
// With the addresses fully encoded, we can now write out the
// number of bytes needed to encode them.
addrLen := addrBuf.Len()
if err := WriteElement(w, uint16(addrLen)); err != nil {
return err
}
// Finally, we'll write out the raw addresses themselves, but
// only if we have any bytes to write.
if addrLen > 0 {
if _, err := w.Write(addrBuf.Bytes()); err != nil {
return err
}
}
case color.RGBA:
if err := WriteElements(w, e.R, e.G, e.B); err != nil {
return err
}
case DeliveryAddress:
var length [2]byte
binary.BigEndian.PutUint16(length[:], uint16(len(e)))
if _, err := w.Write(length[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case bool:
var b [1]byte
if e {
b[0] = 1
}
if _, err := w.Write(b[:]); err != nil {
return err
}
default:
return fmt.Errorf("unknown type in WriteElement: %T", e)
}
return nil
}
// WriteElements writes each element in the elements slice to the passed
// io.Writer using WriteElement.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
err := WriteElement(w, element)
if err != nil {
return err
}
}
return nil
}
// ReadElement is a one-stop utility function to deserialize any data
// structure encoded using the serialization format of lnwire.
func ReadElement(r io.Reader, element interface{}) error {
var err error
switch e := element.(type) {
case *bool:
var b [1]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
if b[0] == 1 {
*e = true
}
case *NodeAlias:
var a [32]byte
if _, err := io.ReadFull(r, a[:]); err != nil {
return err
}
alias, err := NewNodeAlias(string(a[:]))
if err != nil {
return err
}
*e = alias
case *ShortChanIDEncoding:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = ShortChanIDEncoding(b[0])
case *uint8:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = b[0]
case *FundingFlag:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = FundingFlag(b[0])
case *uint16:
var b [2]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = binary.BigEndian.Uint16(b[:])
case *ChanUpdateMsgFlags:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = ChanUpdateMsgFlags(b[0])
case *ChanUpdateChanFlags:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = ChanUpdateChanFlags(b[0])
case *uint32:
var b [4]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = binary.BigEndian.Uint32(b[:])
case *uint64:
var b [8]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = binary.BigEndian.Uint64(b[:])
case *MilliSatoshi:
var b [8]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = MilliSatoshi(int64(binary.BigEndian.Uint64(b[:])))
case *btcutil.Amount:
var b [8]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = btcutil.Amount(int64(binary.BigEndian.Uint64(b[:])))
case **btcec.PublicKey:
var b [btcec.PubKeyBytesLenCompressed]byte
if _, err = io.ReadFull(r, b[:]); err != nil {
return err
}
pubKey, err := btcec.ParsePubKey(b[:])
if err != nil {
return err
}
*e = pubKey
case **RawFeatureVector:
f := NewRawFeatureVector()
err = f.Decode(r)
if err != nil {
return err
}
*e = f
case *[]Sig:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
numSigs := binary.BigEndian.Uint16(l[:])
var sigs []Sig
if numSigs > 0 {
sigs = make([]Sig, numSigs)
for i := 0; i < int(numSigs); i++ {
if err := ReadElement(r, &sigs[i]); err != nil {
return err
}
}
}
*e = sigs
case *Sig:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *OpaqueReason:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
reasonLen := binary.BigEndian.Uint16(l[:])
*e = OpaqueReason(make([]byte, reasonLen))
if _, err := io.ReadFull(r, *e); err != nil {
return err
}
case *ErrorData:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
errorLen := binary.BigEndian.Uint16(l[:])
*e = ErrorData(make([]byte, errorLen))
if _, err := io.ReadFull(r, *e); err != nil {
return err
}
case *PingPayload:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
pingLen := binary.BigEndian.Uint16(l[:])
*e = PingPayload(make([]byte, pingLen))
if _, err := io.ReadFull(r, *e); err != nil {
return err
}
case *PongPayload:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
pongLen := binary.BigEndian.Uint16(l[:])
*e = PongPayload(make([]byte, pongLen))
if _, err := io.ReadFull(r, *e); err != nil {
return err
}
case *[33]byte:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case []byte:
if _, err := io.ReadFull(r, e); err != nil {
return err
}
case *PkScript:
pkScript, err := wire.ReadVarBytes(r, 0, 34, "pkscript")
if err != nil {
return err
}
*e = pkScript
case *wire.OutPoint:
var h [32]byte
if _, err = io.ReadFull(r, h[:]); err != nil {
return err
}
hash, err := chainhash.NewHash(h[:])
if err != nil {
return err
}
var idxBytes [2]byte
_, err = io.ReadFull(r, idxBytes[:])
if err != nil {
return err
}
index := binary.BigEndian.Uint16(idxBytes[:])
*e = wire.OutPoint{
Hash: *hash,
Index: uint32(index),
}
case *FailCode:
if err := ReadElement(r, (*uint16)(e)); err != nil {
return err
}
case *ChannelID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *ShortChannelID:
var blockHeight [4]byte
if _, err = io.ReadFull(r, blockHeight[1:]); err != nil {
return err
}
var txIndex [4]byte
if _, err = io.ReadFull(r, txIndex[1:]); err != nil {
return err
}
var txPosition [2]byte
if _, err = io.ReadFull(r, txPosition[:]); err != nil {
return err
}
*e = ShortChannelID{
BlockHeight: binary.BigEndian.Uint32(blockHeight[:]),
TxIndex: binary.BigEndian.Uint32(txIndex[:]),
TxPosition: binary.BigEndian.Uint16(txPosition[:]),
}
case *[]net.Addr:
// First, we'll read the number of total bytes that have been
// used to encode the set of addresses.
var numAddrsBytes [2]byte
if _, err = io.ReadFull(r, numAddrsBytes[:]); err != nil {
return err
}
addrsLen := binary.BigEndian.Uint16(numAddrsBytes[:])
// With the number of address bytes read, we'll now pull the buffer
// of encoded addresses into memory.
addrs := make([]byte, addrsLen)
if _, err := io.ReadFull(r, addrs[:]); err != nil {
return err
}
addrBuf := bytes.NewReader(addrs)
// Finally, we'll parse the remaining address payload in
// series, using the first byte to denote how to decode the
// address itself.
var (
addresses []net.Addr
addrBytesRead uint16
)
for addrBytesRead < addrsLen {
var descriptor [1]byte
if _, err = io.ReadFull(addrBuf, descriptor[:]); err != nil {
return err
}
addrBytesRead++
var address net.Addr
switch aType := addressType(descriptor[0]); aType {
case noAddr:
addrBytesRead += aType.AddrLen()
continue
case tcp4Addr:
var ip [4]byte
if _, err := io.ReadFull(addrBuf, ip[:]); err != nil {
return err
}
var port [2]byte
if _, err := io.ReadFull(addrBuf, port[:]); err != nil {
return err
}
address = &net.TCPAddr{
IP: net.IP(ip[:]),
Port: int(binary.BigEndian.Uint16(port[:])),
}
addrBytesRead += aType.AddrLen()
case tcp6Addr:
var ip [16]byte
if _, err := io.ReadFull(addrBuf, ip[:]); err != nil {
return err
}
var port [2]byte
if _, err := io.ReadFull(addrBuf, port[:]); err != nil {
return err
}
address = &net.TCPAddr{
IP: net.IP(ip[:]),
Port: int(binary.BigEndian.Uint16(port[:])),
}
addrBytesRead += aType.AddrLen()
case v2OnionAddr:
var h [tor.V2DecodedLen]byte
if _, err := io.ReadFull(addrBuf, h[:]); err != nil {
return err
}
var p [2]byte
if _, err := io.ReadFull(addrBuf, p[:]); err != nil {
return err
}
onionService := tor.Base32Encoding.EncodeToString(h[:])
onionService += tor.OnionSuffix
port := int(binary.BigEndian.Uint16(p[:]))
address = &tor.OnionAddr{
OnionService: onionService,
Port: port,
}
addrBytesRead += aType.AddrLen()
case v3OnionAddr:
var h [tor.V3DecodedLen]byte
if _, err := io.ReadFull(addrBuf, h[:]); err != nil {
return err
}
var p [2]byte
if _, err := io.ReadFull(addrBuf, p[:]); err != nil {
return err
}
onionService := tor.Base32Encoding.EncodeToString(h[:])
onionService += tor.OnionSuffix
port := int(binary.BigEndian.Uint16(p[:]))
address = &tor.OnionAddr{
OnionService: onionService,
Port: port,
}
addrBytesRead += aType.AddrLen()
default:
return &ErrUnknownAddrType{aType}
}
addresses = append(addresses, address)
}
*e = addresses
case *color.RGBA:
err := ReadElements(r,
&e.R,
&e.G,
&e.B,
)
if err != nil {
return err
}
case *DeliveryAddress:
var addrLen [2]byte
if _, err = io.ReadFull(r, addrLen[:]); err != nil {
return err
}
length := binary.BigEndian.Uint16(addrLen[:])
var addrBytes [deliveryAddressMaxSize]byte
if length > deliveryAddressMaxSize {
return fmt.Errorf("cannot read %d bytes into addrBytes", length)
}
if _, err = io.ReadFull(r, addrBytes[:length]); err != nil {
return err
}
*e = addrBytes[:length]
default:
return fmt.Errorf("unknown type in ReadElement: %T", e)
}
return nil
}
// ReadElements deserializes a variable number of elements from the passed
// io.Reader, with each element being deserialized according to the ReadElement
// function.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
err := ReadElement(r, element)
if err != nil {
return err
}
}
return nil
}
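// An added round-trip sketch (not original code): WriteElements and
// ReadElements are symmetric, so writing a handful of supported types and
// reading them back into pointers, in the same order, recovers the values.
func elementsRoundTripSketch() error {
	var buf bytes.Buffer
	scid := ShortChannelID{BlockHeight: 700_000, TxIndex: 12, TxPosition: 1}
	err := WriteElements(&buf, uint16(42), scid, MilliSatoshi(1_000))
	if err != nil {
		return err
	}

	var (
		n   uint16
		out ShortChannelID
		amt MilliSatoshi
	)
	// The reads must mirror the write order exactly: the format carries
	// no type or length metadata for these fixed-size elements.
	return ReadElements(&buf, &n, &out, &amt)
}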
// Copyright (c) 2013-2017 The btcsuite developers
// Copyright (c) 2015-2016 The Decred developers
// code derived from https://github.com/btcsuite/btcd/blob/master/wire/message.go
// Copyright (C) 2015-2022 The Lightning Network Developers
package lnwire
import (
"bytes"
"encoding/binary"
"fmt"
"io"
)
// MaxMessagePayload is the maximum bytes a message can be regardless of other
// individual limits imposed by messages themselves.
const MaxMessagePayload = 65535 // 65KB
// MessageType is the unique 2 byte big-endian integer that indicates the type
// of message on the wire. All messages have a very simple header consisting
// solely of the 2-byte message type. We omit a length field and checksum as
// the Lightning Protocol is intended to be encapsulated within a
// confidential+authenticated cryptographic messaging protocol.
type MessageType uint16
// The currently defined message types within this version of the Lightning
// protocol.
const (
MsgInit MessageType = 16
MsgError = 17
MsgPing = 18
MsgPong = 19
MsgOpenChannel = 32
MsgAcceptChannel = 33
MsgFundingCreated = 34
MsgFundingSigned = 35
MsgFundingLocked = 36
MsgShutdown = 38
MsgClosingSigned = 39
MsgUpdateAddHTLC = 128
MsgUpdateFulfillHTLC = 130
MsgUpdateFailHTLC = 131
MsgCommitSig = 132
MsgRevokeAndAck = 133
MsgUpdateFee = 134
MsgUpdateFailMalformedHTLC = 135
MsgChannelReestablish = 136
MsgChannelAnnouncement = 256
MsgNodeAnnouncement = 257
MsgChannelUpdate = 258
MsgAnnounceSignatures = 259
MsgQueryShortChanIDs = 261
MsgReplyShortChanIDsEnd = 262
MsgQueryChannelRange = 263
MsgReplyChannelRange = 264
MsgGossipTimestampRange = 265
)
// String returns the string representation of the message type.
func (t MessageType) String() string {
switch t {
case MsgInit:
return "Init"
case MsgOpenChannel:
return "MsgOpenChannel"
case MsgAcceptChannel:
return "MsgAcceptChannel"
case MsgFundingCreated:
return "MsgFundingCreated"
case MsgFundingSigned:
return "MsgFundingSigned"
case MsgFundingLocked:
return "FundingLocked"
case MsgShutdown:
return "Shutdown"
case MsgClosingSigned:
return "ClosingSigned"
case MsgUpdateAddHTLC:
return "UpdateAddHTLC"
case MsgUpdateFailHTLC:
return "UpdateFailHTLC"
case MsgUpdateFulfillHTLC:
return "UpdateFulfillHTLC"
case MsgCommitSig:
return "CommitSig"
case MsgRevokeAndAck:
return "RevokeAndAck"
case MsgUpdateFailMalformedHTLC:
return "UpdateFailMalformedHTLC"
case MsgChannelReestablish:
return "ChannelReestablish"
case MsgError:
return "Error"
case MsgChannelAnnouncement:
return "ChannelAnnouncement"
case MsgChannelUpdate:
return "ChannelUpdate"
case MsgNodeAnnouncement:
return "NodeAnnouncement"
case MsgPing:
return "Ping"
case MsgAnnounceSignatures:
return "AnnounceSignatures"
case MsgPong:
return "Pong"
case MsgUpdateFee:
return "UpdateFee"
case MsgQueryShortChanIDs:
return "QueryShortChanIDs"
case MsgReplyShortChanIDsEnd:
return "ReplyShortChanIDsEnd"
case MsgQueryChannelRange:
return "QueryChannelRange"
case MsgReplyChannelRange:
return "ReplyChannelRange"
case MsgGossipTimestampRange:
return "GossipTimestampRange"
default:
return "<unknown>"
}
}
// UnknownMessage is an implementation of the error interface that allows the
// creation of an error in response to an unknown message.
type UnknownMessage struct {
messageType MessageType
}
// Error returns a human readable string describing the error.
//
// This is part of the error interface.
func (u *UnknownMessage) Error() string {
return fmt.Sprintf("unable to parse message of unknown type: %v",
u.messageType)
}
// Serializable is an interface which defines a lightning wire serializable
// object.
type Serializable interface {
// Decode reads the bytes stream and converts it to the object.
Decode(io.Reader, uint32) error
// Encode converts object to the bytes stream and write it into the
// writer.
Encode(io.Writer, uint32) error
}
// Message is an interface that defines a lightning wire protocol message. The
// interface is general in order to allow implementing types full control over
// the representation of its data.
type Message interface {
Serializable
MsgType() MessageType
MaxPayloadLength(uint32) uint32
}
// makeEmptyMessage creates a new empty message of the proper concrete type
// based on the passed message type.
func makeEmptyMessage(msgType MessageType) (Message, error) {
var msg Message
switch msgType {
case MsgInit:
msg = &Init{}
case MsgOpenChannel:
msg = &OpenChannel{}
case MsgAcceptChannel:
msg = &AcceptChannel{}
case MsgFundingCreated:
msg = &FundingCreated{}
case MsgFundingSigned:
msg = &FundingSigned{}
case MsgFundingLocked:
msg = &FundingLocked{}
case MsgShutdown:
msg = &Shutdown{}
case MsgClosingSigned:
msg = &ClosingSigned{}
case MsgUpdateAddHTLC:
msg = &UpdateAddHTLC{}
case MsgUpdateFailHTLC:
msg = &UpdateFailHTLC{}
case MsgUpdateFulfillHTLC:
msg = &UpdateFulfillHTLC{}
case MsgCommitSig:
msg = &CommitSig{}
case MsgRevokeAndAck:
msg = &RevokeAndAck{}
case MsgUpdateFee:
msg = &UpdateFee{}
case MsgUpdateFailMalformedHTLC:
msg = &UpdateFailMalformedHTLC{}
case MsgChannelReestablish:
msg = &ChannelReestablish{}
case MsgError:
msg = &Error{}
case MsgChannelAnnouncement:
msg = &ChannelAnnouncement{}
case MsgChannelUpdate:
msg = &ChannelUpdate{}
case MsgNodeAnnouncement:
msg = &NodeAnnouncement{}
case MsgPing:
msg = &Ping{}
case MsgAnnounceSignatures:
msg = &AnnounceSignatures{}
case MsgPong:
msg = &Pong{}
case MsgQueryShortChanIDs:
msg = &QueryShortChanIDs{}
case MsgReplyShortChanIDsEnd:
msg = &ReplyShortChanIDsEnd{}
case MsgQueryChannelRange:
msg = &QueryChannelRange{}
case MsgReplyChannelRange:
msg = &ReplyChannelRange{}
case MsgGossipTimestampRange:
msg = &GossipTimestampRange{}
default:
return nil, &UnknownMessage{msgType}
}
return msg, nil
}
// WriteMessage writes a lightning Message to w including the necessary header
// information and returns the number of bytes written.
func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
totalBytes := 0
// Encode the message payload itself into a temporary buffer.
// TODO(roasbeef): create buffer pool
var bw bytes.Buffer
if err := msg.Encode(&bw, pver); err != nil {
return totalBytes, err
}
payload := bw.Bytes()
lenp := len(payload)
// Enforce maximum overall message payload.
if lenp > MaxMessagePayload {
return totalBytes, fmt.Errorf("message payload is too large - "+
"encoded %d bytes, but maximum message payload is %d bytes",
lenp, MaxMessagePayload)
}
// Enforce maximum message payload on the message type.
mpl := msg.MaxPayloadLength(pver)
if uint32(lenp) > mpl {
return totalBytes, fmt.Errorf("message payload is too large - "+
"encoded %d bytes, but maximum message payload of "+
"type %v is %d bytes", lenp, msg.MsgType(), mpl)
}
// With the initial sanity checks complete, we'll now write out the
// message type itself.
var mType [2]byte
binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType()))
n, err := w.Write(mType[:])
totalBytes += n
if err != nil {
return totalBytes, err
}
// With the message type written, we'll now write out the raw payload
// itself.
n, err = w.Write(payload)
totalBytes += n
return totalBytes, err
}
// ReadMessage reads, validates, and parses the next Lightning message from r
// for the provided protocol version.
func ReadMessage(r io.Reader, pver uint32) (Message, error) {
// First, we'll read out the first two bytes of the message so we can
// create the proper empty message.
var mType [2]byte
if _, err := io.ReadFull(r, mType[:]); err != nil {
return nil, err
}
msgType := MessageType(binary.BigEndian.Uint16(mType[:]))
// Now that we know the target message type, we can create the proper
// empty message type and decode the message into it.
msg, err := makeEmptyMessage(msgType)
if err != nil {
return nil, err
}
if err := msg.Decode(r, pver); err != nil {
return nil, err
}
return msg, nil
}
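// An added wire round-trip sketch (not original code): frame an Init with
// WriteMessage, then recover it with ReadMessage. The 2-byte type header is
// what lets ReadMessage pick the concrete message type to decode into.
func messageRoundTripSketch() (Message, error) {
	initMsg := NewInitMessage(NewRawFeatureVector(), NewRawFeatureVector())

	var buf bytes.Buffer
	if _, err := WriteMessage(&buf, initMsg, 0); err != nil {
		return nil, err
	}

	// ReadMessage consumes the type header, allocates the matching empty
	// message via makeEmptyMessage, and decodes the payload into it.
	return ReadMessage(&buf, 0)
}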
package lnwire
import (
"fmt"
"io"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// mSatScale is a value that's used to scale satoshis to milli-satoshis, and
// the other way around.
mSatScale uint64 = 1000
// MaxMilliSatoshi is the maximum number of msats that can be expressed
// in this data type.
MaxMilliSatoshi = ^MilliSatoshi(0)
)
// MilliSatoshi is the native unit of the Lightning Network. A milli-satoshi
// is simply 1/1000th of a satoshi. There are 1000 milli-satoshis in a single
// satoshi. Within the network, all HTLC payments are denominated in
// milli-satoshis. As milli-satoshis aren't deliverable on the native
// blockchain, the values are rounded down to the nearest satoshi before
// being settled via an on-chain broadcast.
type MilliSatoshi uint64
// NewMSatFromSatoshis creates a new MilliSatoshi instance from a target amount
// of satoshis.
func NewMSatFromSatoshis(sat btcutil.Amount) MilliSatoshi {
return MilliSatoshi(uint64(sat) * mSatScale)
}
// ToBTC converts the target MilliSatoshi amount to its corresponding value
// when expressed in BTC.
func (m MilliSatoshi) ToBTC() float64 {
sat := m.ToSatoshis()
return sat.ToBTC()
}
// ToSatoshis converts the target MilliSatoshi amount to satoshis. Simply, this
// sheds a factor of 1000 from the mSAT amount in order to convert it to SAT.
func (m MilliSatoshi) ToSatoshis() btcutil.Amount {
return btcutil.Amount(uint64(m) / mSatScale)
}
// String returns the string representation of the mSAT amount.
func (m MilliSatoshi) String() string {
return fmt.Sprintf("%v mSAT", uint64(m))
}
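// An added arithmetic sketch (not original code): conversions to satoshis
// shed sub-satoshi precision, so amounts always round down.
func msatConversionSketch() btcutil.Amount {
	amt := NewMSatFromSatoshis(5) // 5 sat == 5,000 msat
	amt += 999                    // 5,999 msat

	// ToSatoshis drops the 999 msat remainder, yielding 5 sat.
	return amt.ToSatoshis()
}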
// TODO(roasbeef): extend with arithmetic operations?
// Record returns a TLV record that can be used to encode/decode a MilliSatoshi
// to/from a TLV stream.
func (m *MilliSatoshi) Record() tlv.Record {
return tlv.MakeDynamicRecord(
0, m, tlv.SizeBigSize(m), encodeMilliSatoshis,
decodeMilliSatoshis,
)
}
func encodeMilliSatoshis(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*MilliSatoshi); ok {
bigSize := uint64(*v)
return tlv.EBigSize(w, &bigSize, buf)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.MilliSatoshi")
}
func decodeMilliSatoshis(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*MilliSatoshi); ok {
var bigSize uint64
err := tlv.DBigSize(r, &bigSize, buf, l)
if err != nil {
return err
}
*v = MilliSatoshi(bigSize)
return nil
}
return tlv.NewTypeForDecodingErr(val, "lnwire.MilliSatoshi", l, l)
}
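// An added usage sketch (hedged, not original code): embedding the
// MilliSatoshi record in a TLV stream via the tlv package imported above.
func msatTLVSketch(w io.Writer) error {
	amt := MilliSatoshi(42_000)

	// The stream encodes the amount as a BigSize integer under record
	// type 0, matching the Record definition above.
	stream, err := tlv.NewStream(amt.Record())
	if err != nil {
		return err
	}
	return stream.Encode(w)
}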
package lnwire
import (
"fmt"
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
)
// NetAddress represents information pertaining to the identity and network
// reachability of a peer. Information stored includes the node's identity
// public key for establishing a confidential+authenticated connection, the
// service bits it supports, and a TCP address the node is reachable at.
//
// TODO(roasbeef): merge with LinkNode in some fashion
type NetAddress struct {
// IdentityKey is the long-term static public key for a node. This key is
// used throughout the network as a node's identity key. It is used to
// authenticate any data sent to the network on behalf of the node, and
// additionally to establish a confidential+authenticated connection with
// the node.
IdentityKey *btcec.PublicKey
// Address is the IP address and port of the node. This is left
// general so that multiple implementations can be used.
Address net.Addr
// ChainNet is the Bitcoin network this node is associated with.
// TODO(roasbeef): make a slice in the future for multi-chain
ChainNet wire.BitcoinNet
}
// A compile time assertion to ensure that NetAddress meets the net.Addr
// interface.
var _ net.Addr = (*NetAddress)(nil)
// String returns a human readable string describing the target NetAddress. The
// current string format is: <pubkey>@host.
//
// This is part of the net.Addr interface.
func (n *NetAddress) String() string {
// TODO(roasbeef): use base58?
pubkey := n.IdentityKey.SerializeCompressed()
return fmt.Sprintf("%x@%v", pubkey, n.Address)
}
// Network returns the name of the network this address is bound to.
//
// This is part of the net.Addr interface.
func (n *NetAddress) Network() string {
return n.Address.Network()
}
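// An added construction sketch (not original code): pairing a fresh identity
// key with a TCP endpoint. A real node would use its long-term identity key
// rather than the throwaway one generated here via btcec.NewPrivateKey.
func netAddressSketch() (*NetAddress, error) {
	priv, err := btcec.NewPrivateKey()
	if err != nil {
		return nil, err
	}

	addr := &NetAddress{
		IdentityKey: priv.PubKey(),
		Address: &net.TCPAddr{
			IP:   net.ParseIP("203.0.113.1"),
			Port: 9735,
		},
		ChainNet: wire.MainNet,
	}

	// String renders this as <hex-pubkey>@203.0.113.1:9735.
	return addr, nil
}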
package lnwire
import (
"bytes"
"fmt"
"image/color"
"io"
"net"
"unicode/utf8"
)
// ErrUnknownAddrType is an error returned if we encounter an unknown address type
// when parsing addresses.
type ErrUnknownAddrType struct {
addrType addressType
}
// Error returns a human readable string describing the error.
//
// NOTE: implements the error interface.
func (e ErrUnknownAddrType) Error() string {
return fmt.Sprintf("unknown address type: %v", e.addrType)
}
// ErrInvalidNodeAlias is an error returned if a node alias we parse on the
// wire is invalid, i.e. it contains non-UTF-8 characters.
type ErrInvalidNodeAlias struct{}
// Error returns a human readable string describing the error.
//
// NOTE: implements the error interface.
func (e ErrInvalidNodeAlias) Error() string {
return "node alias has non-utf8 characters"
}
// NodeAlias is a hex encoded UTF-8 string that may be displayed as an
// alternative to the node's ID. Notice that aliases are not unique and may be
// freely chosen by the node operators.
type NodeAlias [32]byte
// NewNodeAlias creates a new instance of a NodeAlias. Verification is
// performed on the passed string to ensure it meets the alias requirements.
func NewNodeAlias(s string) (NodeAlias, error) {
var n NodeAlias
if len(s) > 32 {
return n, fmt.Errorf("alias too large: max is %v, got %v", 32,
len(s))
}
if !utf8.ValidString(s) {
return n, &ErrInvalidNodeAlias{}
}
copy(n[:], []byte(s))
return n, nil
}
// String returns a utf8 string representation of the alias bytes.
func (n NodeAlias) String() string {
// Trim trailing zero-bytes for presentation
return string(bytes.Trim(n[:], "\x00"))
}
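// An added usage sketch (not original code): NewNodeAlias enforces the
// 32-byte limit and UTF-8 validity, while String trims the zero padding
// back off for display.
func nodeAliasSketch() (string, error) {
	alias, err := NewNodeAlias("satoshi")
	if err != nil {
		// Only fires for aliases over 32 bytes or invalid UTF-8.
		return "", err
	}
	return alias.String(), nil // "satoshi", trailing zeroes trimmed
}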
// NodeAnnouncement message is used to announce the presence of a Lightning
// node and also to signal that the node is accepting incoming connections.
// Each NodeAnnouncement authenticates the advertised information within the
// announcement via a signature using the advertised node pubkey.
type NodeAnnouncement struct {
// Signature is used to prove the ownership of node id.
Signature Sig
// Features is the list of protocol features this node supports.
Features *RawFeatureVector
// Timestamp allows ordering in the case of multiple announcements.
Timestamp uint32
// NodeID is a public key which is used as node identification.
NodeID [33]byte
// RGBColor is used to customize the node's appearance in maps and
// graphs.
RGBColor color.RGBA
// Alias is used to customize the node's appearance in maps and
// graphs.
Alias NodeAlias
// Addresses is the set of network addresses on which the node is
// accepting incoming connections.
Addresses []net.Addr
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
}
// A compile time check to ensure NodeAnnouncement implements the
// lnwire.Message interface.
var _ Message = (*NodeAnnouncement)(nil)
// Decode deserializes a serialized NodeAnnouncement stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
&a.Signature,
&a.Features,
&a.Timestamp,
&a.NodeID,
&a.RGBColor,
&a.Alias,
&a.Addresses,
)
if err != nil {
return err
}
// Now that we've read out all the fields that we explicitly know of,
// we'll collect the remainder into the ExtraOpaqueData field. If there
// aren't any bytes, then we'll snip off the slice to avoid carrying
// around excess capacity.
a.ExtraOpaqueData, err = io.ReadAll(r)
if err != nil {
return err
}
if len(a.ExtraOpaqueData) == 0 {
a.ExtraOpaqueData = nil
}
return nil
}
// Encode serializes the target NodeAnnouncement into the passed io.Writer
// observing the protocol version specified.
func (a *NodeAnnouncement) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
a.Signature,
a.Features,
a.Timestamp,
a.NodeID,
a.RGBColor,
a.Alias,
a.Addresses,
a.ExtraOpaqueData,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *NodeAnnouncement) MsgType() MessageType {
return MsgNodeAnnouncement
}
// MaxPayloadLength returns the maximum allowed payload size for this message
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *NodeAnnouncement) MaxPayloadLength(pver uint32) uint32 {
return 65533
}
// DataToSign returns the part of the message that should be signed.
func (a *NodeAnnouncement) DataToSign() ([]byte, error) {
// We should not include the signature itself.
var w bytes.Buffer
err := WriteElements(&w,
a.Features,
a.Timestamp,
a.NodeID,
a.RGBColor,
a.Alias[:],
a.Addresses,
a.ExtraOpaqueData,
)
if err != nil {
return nil, err
}
return w.Bytes(), nil
}
package lnwire
import (
"bufio"
"bytes"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/tlv"
)
// FailureMessage represents the onion failure object identified by its unique
// failure code.
type FailureMessage interface {
// Code returns a failure code describing the exact nature of the
// error.
Code() FailCode
// Error returns a human readable string describing the error. With
// this method, the FailureMessage interface meets the built-in error
// interface.
Error() string
}
// FailureMessageLength is the size of the failure message plus the size of
// padding. The FailureMessage message should always be EXACTLY this size.
const FailureMessageLength = 256
const (
// FlagBadOnion error flag describes an unparsable onion that was
// encrypted by the previous node.
FlagBadOnion FailCode = 0x8000
// FlagPerm error flag indicates a permanent failure.
FlagPerm FailCode = 0x4000
// FlagNode error flag indicates a node failure.
FlagNode FailCode = 0x2000
// FlagUpdate error flag indicates a new channel update is enclosed
// within the error.
FlagUpdate FailCode = 0x1000
)
// FailCode specifies the precise reason that an upstream HTLC was canceled.
// Each UpdateFailHTLC message carries a FailCode which is to be passed
// backwards, encrypted at each step back to the source of the HTLC within the
// route.
type FailCode uint16
// The currently defined onion failure types within this version of the
// Lightning protocol.
//
//nolint:ll
const (
CodeNone FailCode = 0
CodeInvalidRealm = FlagBadOnion | 1
CodeTemporaryNodeFailure = FlagNode | 2
CodePermanentNodeFailure = FlagPerm | FlagNode | 2
CodeRequiredNodeFeatureMissing = FlagPerm | FlagNode | 3
CodeInvalidOnionVersion = FlagBadOnion | FlagPerm | 4
CodeInvalidOnionHmac = FlagBadOnion | FlagPerm | 5
CodeInvalidOnionKey = FlagBadOnion | FlagPerm | 6
CodeTemporaryChannelFailure = FlagUpdate | 7
CodePermanentChannelFailure = FlagPerm | 8
CodeRequiredChannelFeatureMissing = FlagPerm | 9
CodeUnknownNextPeer = FlagPerm | 10
CodeAmountBelowMinimum = FlagUpdate | 11
CodeFeeInsufficient = FlagUpdate | 12
CodeIncorrectCltvExpiry = FlagUpdate | 13
CodeExpiryTooSoon = FlagUpdate | 14
CodeChannelDisabled = FlagUpdate | 20
CodeIncorrectOrUnknownPaymentDetails = FlagPerm | 15
CodeIncorrectPaymentAmount = FlagPerm | 16
CodeFinalExpiryTooSoon FailCode = 17
CodeFinalIncorrectCltvExpiry FailCode = 18
CodeFinalIncorrectHtlcAmount FailCode = 19
CodeExpiryTooFar FailCode = 21
CodeInvalidOnionPayload = FlagPerm | 22
CodeMPPTimeout FailCode = 23
CodeInvalidBlinding = FlagBadOnion | FlagPerm | 24
)
// String returns the string representation of the failure code.
func (c FailCode) String() string {
switch c {
case CodeInvalidRealm:
return "InvalidRealm"
case CodeTemporaryNodeFailure:
return "TemporaryNodeFailure"
case CodePermanentNodeFailure:
return "PermanentNodeFailure"
case CodeRequiredNodeFeatureMissing:
return "RequiredNodeFeatureMissing"
case CodeInvalidOnionVersion:
return "InvalidOnionVersion"
case CodeInvalidOnionHmac:
return "InvalidOnionHmac"
case CodeInvalidOnionKey:
return "InvalidOnionKey"
case CodeTemporaryChannelFailure:
return "TemporaryChannelFailure"
case CodePermanentChannelFailure:
return "PermanentChannelFailure"
case CodeRequiredChannelFeatureMissing:
return "RequiredChannelFeatureMissing"
case CodeUnknownNextPeer:
return "UnknownNextPeer"
case CodeAmountBelowMinimum:
return "AmountBelowMinimum"
case CodeFeeInsufficient:
return "FeeInsufficient"
case CodeIncorrectCltvExpiry:
return "IncorrectCltvExpiry"
case CodeIncorrectPaymentAmount:
return "IncorrectPaymentAmount"
case CodeExpiryTooSoon:
return "ExpiryTooSoon"
case CodeChannelDisabled:
return "ChannelDisabled"
case CodeIncorrectOrUnknownPaymentDetails:
return "IncorrectOrUnknownPaymentDetails"
case CodeFinalExpiryTooSoon:
return "FinalExpiryTooSoon"
case CodeFinalIncorrectCltvExpiry:
return "FinalIncorrectCltvExpiry"
case CodeFinalIncorrectHtlcAmount:
return "FinalIncorrectHtlcAmount"
case CodeExpiryTooFar:
return "ExpiryTooFar"
case CodeInvalidOnionPayload:
return "InvalidOnionPayload"
case CodeMPPTimeout:
return "MPPTimeout"
case CodeInvalidBlinding:
return "InvalidBlinding"
default:
return "<unknown>"
}
}
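// An added sketch (not original code) showing how the flag bits compose
// with a base code: masking a FailCode against the Flag* constants recovers
// the failure's routing semantics. For CodeTemporaryChannelFailure
// (FlagUpdate|7) this reports permanent=false, hasUpdate=true.
func failCodeFlagsSketch(code FailCode) (permanent, hasUpdate bool) {
	permanent = code&FlagPerm != 0
	hasUpdate = code&FlagUpdate != 0
	return permanent, hasUpdate
}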
// FailInvalidRealm is returned if the realm byte is unknown.
//
// NOTE: May be returned by any node in the payment route.
type FailInvalidRealm struct{}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidRealm) Error() string {
return f.Code().String()
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidRealm) Code() FailCode {
return CodeInvalidRealm
}
// FailTemporaryNodeFailure is returned if an otherwise unspecified transient
// error occurs for the entire node.
//
// NOTE: May be returned by any node in the payment route.
type FailTemporaryNodeFailure struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailTemporaryNodeFailure) Code() FailCode {
return CodeTemporaryNodeFailure
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailTemporaryNodeFailure) Error() string {
return f.Code().String()
}
// FailPermanentNodeFailure is returned if an otherwise unspecified permanent
// error occurs for the entire node.
//
// NOTE: May be returned by any node in the payment route.
type FailPermanentNodeFailure struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailPermanentNodeFailure) Code() FailCode {
return CodePermanentNodeFailure
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailPermanentNodeFailure) Error() string {
return f.Code().String()
}
// FailRequiredNodeFeatureMissing is returned if a node has a requirement
// advertised in its node_announcement features which was not present in the
// onion.
//
// NOTE: May be returned by any node in the payment route.
type FailRequiredNodeFeatureMissing struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailRequiredNodeFeatureMissing) Code() FailCode {
return CodeRequiredNodeFeatureMissing
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailRequiredNodeFeatureMissing) Error() string {
return f.Code().String()
}
// FailPermanentChannelFailure is returned if an otherwise unspecified
// permanent error occurs for the outgoing channel (e.g. channel recently
// closed).
//
// NOTE: May be returned by any node in the payment route.
type FailPermanentChannelFailure struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailPermanentChannelFailure) Code() FailCode {
return CodePermanentChannelFailure
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailPermanentChannelFailure) Error() string {
return f.Code().String()
}
// FailRequiredChannelFeatureMissing is returned if the outgoing channel has a
// requirement advertised in its channel announcement features which was not
// present in the onion.
//
// NOTE: May only be returned by intermediate nodes.
type FailRequiredChannelFeatureMissing struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailRequiredChannelFeatureMissing) Code() FailCode {
return CodeRequiredChannelFeatureMissing
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailRequiredChannelFeatureMissing) Error() string {
return f.Code().String()
}
// FailUnknownNextPeer is returned if the next peer specified by the onion is
// not known.
//
// NOTE: May only be returned by intermediate nodes.
type FailUnknownNextPeer struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailUnknownNextPeer) Code() FailCode {
return CodeUnknownNextPeer
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailUnknownNextPeer) Error() string {
return f.Code().String()
}
// FailIncorrectPaymentAmount is returned if the amount paid is less than the
// amount expected; in that case the final node MUST fail the HTLC. If the
// amount paid is more than twice the amount expected, the final node SHOULD
// fail the HTLC. This allows the sender to reduce information leakage by
// altering the amount, without allowing accidental gross overpayment.
//
// NOTE: May only be returned by the final node in the path.
type FailIncorrectPaymentAmount struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailIncorrectPaymentAmount) Code() FailCode {
return CodeIncorrectPaymentAmount
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailIncorrectPaymentAmount) Error() string {
return f.Code().String()
}
// FailIncorrectDetails is returned for two reasons:
//
// 1) if the payment hash has already been paid, the final node MAY treat the
// payment hash as unknown, or may succeed in accepting the HTLC. If the
// payment hash is unknown, the final node MUST fail the HTLC.
//
// 2) if the amount paid is less than the amount expected, the final node MUST
// fail the HTLC. If the amount paid is more than twice the amount expected,
// the final node SHOULD fail the HTLC. This allows the sender to reduce
// information leakage by altering the amount, without allowing accidental
// gross overpayment.
//
// NOTE: May only be returned by the final node in the path.
type FailIncorrectDetails struct {
// amount is the value of the extended HTLC.
amount MilliSatoshi
// height is the block height when the htlc was received.
height uint32
}
// NewFailIncorrectDetails makes a new instance of the FailIncorrectDetails
// error bound to the specified HTLC amount and acceptance height.
func NewFailIncorrectDetails(amt MilliSatoshi,
height uint32) *FailIncorrectDetails {
return &FailIncorrectDetails{
amount: amt,
height: height,
}
}
// Amount is the value of the extended HTLC.
func (f *FailIncorrectDetails) Amount() MilliSatoshi {
return f.amount
}
// Height is the block height when the htlc was received.
func (f *FailIncorrectDetails) Height() uint32 {
return f.height
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailIncorrectDetails) Code() FailCode {
return CodeIncorrectOrUnknownPaymentDetails
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailIncorrectDetails) Error() string {
return fmt.Sprintf(
"%v(amt=%v, height=%v)", CodeIncorrectOrUnknownPaymentDetails,
f.amount, f.height,
)
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error {
err := ReadElement(r, &f.amount)
switch {
// This is an optional field that was tacked on later in the protocol.
// As a result, older nodes may not include this value. We'll account
// for this by checking for io.EOF here, which means that no bytes were
// read at all.
case err == io.EOF:
return nil
case err != nil:
return err
}
// At a later stage, the height field was also tacked on. We need to
// check for io.EOF here as well.
err = ReadElement(r, &f.height)
switch {
case err == io.EOF:
return nil
case err != nil:
return err
}
return nil
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailIncorrectDetails) Encode(w io.Writer, pver uint32) error {
return WriteElements(w, f.amount, f.height)
}
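// An added compatibility sketch (not original code): older senders omitted
// the amount and height fields entirely, and Decode tolerates this by
// treating io.EOF as success, leaving both fields zero.
func incorrectDetailsCompatSketch() error {
	var f FailIncorrectDetails
	// An empty payload decodes cleanly: amount and height stay 0.
	return f.Decode(bytes.NewReader(nil), 0)
}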
// FailFinalExpiryTooSoon is returned if the cltv_expiry is too low; in that
// case the final node MUST fail the HTLC.
//
// NOTE: May only be returned by the final node in the path.
type FailFinalExpiryTooSoon struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailFinalExpiryTooSoon) Code() FailCode {
return CodeFinalExpiryTooSoon
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailFinalExpiryTooSoon) Error() string {
return f.Code().String()
}
// NewFinalExpiryTooSoon creates a new instance of the FailFinalExpiryTooSoon.
func NewFinalExpiryTooSoon() *FailFinalExpiryTooSoon {
return &FailFinalExpiryTooSoon{}
}
// FailInvalidOnionVersion is returned if the onion version byte is unknown.
//
// NOTE: May be returned only by intermediate nodes.
type FailInvalidOnionVersion struct {
// OnionSHA256 is the hash of the onion blob which couldn't be processed.
OnionSHA256 [sha256.Size]byte
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidOnionVersion) Error() string {
return fmt.Sprintf("InvalidOnionVersion(onion_sha=%x)", f.OnionSHA256[:])
}
// NewInvalidOnionVersion creates a new instance of the FailInvalidOnionVersion.
func NewInvalidOnionVersion(onion []byte) *FailInvalidOnionVersion {
return &FailInvalidOnionVersion{OnionSHA256: sha256.Sum256(onion)}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidOnionVersion) Code() FailCode {
return CodeInvalidOnionVersion
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionVersion) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionVersion) Encode(w io.Writer, pver uint32) error {
return WriteElement(w, f.OnionSHA256[:])
}
// FailInvalidOnionHmac is returned if the onion HMAC is incorrect.
//
// NOTE: May only be returned by intermediate nodes.
type FailInvalidOnionHmac struct {
// OnionSHA256 is the hash of the onion blob which couldn't be processed.
OnionSHA256 [sha256.Size]byte
}
// NewInvalidOnionHmac creates a new instance of the FailInvalidOnionHmac.
func NewInvalidOnionHmac(onion []byte) *FailInvalidOnionHmac {
return &FailInvalidOnionHmac{OnionSHA256: sha256.Sum256(onion)}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidOnionHmac) Code() FailCode {
return CodeInvalidOnionHmac
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionHmac) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionHmac) Encode(w io.Writer, pver uint32) error {
return WriteElement(w, f.OnionSHA256[:])
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidOnionHmac) Error() string {
return fmt.Sprintf("InvalidOnionHMAC(onion_sha=%x)", f.OnionSHA256[:])
}
// FailInvalidOnionKey is returned if the ephemeral key in the onion is
// unparsable.
//
// NOTE: May only be returned by intermediate nodes.
type FailInvalidOnionKey struct {
// OnionSHA256 is the hash of the onion blob which couldn't be processed.
OnionSHA256 [sha256.Size]byte
}
// NewInvalidOnionKey creates a new instance of the FailInvalidOnionKey.
func NewInvalidOnionKey(onion []byte) *FailInvalidOnionKey {
return &FailInvalidOnionKey{OnionSHA256: sha256.Sum256(onion)}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidOnionKey) Code() FailCode {
return CodeInvalidOnionKey
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionKey) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionKey) Encode(w io.Writer, pver uint32) error {
return WriteElement(w, f.OnionSHA256[:])
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidOnionKey) Error() string {
return fmt.Sprintf("InvalidOnionKey(onion_sha=%x)", f.OnionSHA256[:])
}
// FailInvalidBlinding is returned if there has been a route blinding related
// error.
type FailInvalidBlinding struct {
OnionSHA256 [sha256.Size]byte
}
// A compile-time check to ensure that FailInvalidBlinding implements the
// Serializable interface.
var _ Serializable = (*FailInvalidBlinding)(nil)
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidBlinding) Code() FailCode {
return CodeInvalidBlinding
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidBlinding) Error() string {
return f.Code().String()
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidBlinding) Decode(r io.Reader, _ uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidBlinding) Encode(w io.Writer, _ uint32) error {
return WriteElement(w, f.OnionSHA256[:])
}
// NewInvalidBlinding creates a new instance of FailInvalidBlinding.
func NewInvalidBlinding(onion []byte) *FailInvalidBlinding {
// The spec allows empty onion hashes for invalid blinding, so we only
// include our onion hash if it's provided.
if onion == nil {
return &FailInvalidBlinding{}
}
return &FailInvalidBlinding{OnionSHA256: sha256.Sum256(onion)}
}
// parseChannelUpdateCompatabilityMode will attempt to parse a channel update
// encoded into an onion error payload in two ways. First, we'll try the
// compatibility oriented version wherein we'll _skip_ the length prefixing on
// the channel update message. Older versions of c-lightning do this so we'll
// attempt to parse these messages in order to retain compatibility. If we're
// unable to pull out a fully valid version, then we'll fall back to the
// regular parsing mechanism which includes the length prefix and NO type
// byte.
func parseChannelUpdateCompatabilityMode(r *bufio.Reader,
chanUpdate *ChannelUpdate, pver uint32) error {
// We'll peek out two bytes from the buffer without advancing the
// buffer so we can decide how to parse the remainder of it.
maybeTypeBytes, err := r.Peek(2)
if err != nil {
return err
}
// Some nodes will prefix an additional set of bytes in front of their
// channel updates. These bytes will _almost_ always be 258, the type
// of the ChannelUpdate message.
typeInt := binary.BigEndian.Uint16(maybeTypeBytes)
if typeInt == MsgChannelUpdate {
// At this point it's likely the case that this is a channel
// update message with its type prefixed, so we'll snip off the
// first two bytes and parse it as normal.
var throwAwayTypeBytes [2]byte
_, err := r.Read(throwAwayTypeBytes[:])
if err != nil {
return err
}
}
// At this point, we've either decided to keep the entire thing, or snip
// off the first two bytes. In either case, we can just read it as
// normal.
return chanUpdate.Decode(r, pver)
}
// FailTemporaryChannelFailure is returned if an otherwise unspecified
// transient error occurs for the outgoing channel (e.g. channel capacity
// reached, too many in-flight HTLCs).
//
// NOTE: May only be returned by intermediate nodes.
type FailTemporaryChannelFailure struct {
// Update is used to update information about state of the channel
// which caused the failure.
//
// NOTE: This field is optional.
Update *ChannelUpdate
}
// NewTemporaryChannelFailure creates a new instance of the FailTemporaryChannelFailure.
func NewTemporaryChannelFailure(update *ChannelUpdate) *FailTemporaryChannelFailure {
return &FailTemporaryChannelFailure{Update: update}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailTemporaryChannelFailure) Code() FailCode {
return CodeTemporaryChannelFailure
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailTemporaryChannelFailure) Error() string {
if f.Update == nil {
return f.Code().String()
}
return fmt.Sprintf("TemporaryChannelFailure(update=%v)",
spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error {
var length uint16
err := ReadElement(r, &length)
if err != nil {
return err
}
if length != 0 {
f.Update = &ChannelUpdate{}
return parseChannelUpdateCompatabilityMode(
bufio.NewReader(r), f.Update, pver,
)
}
return nil
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailTemporaryChannelFailure) Encode(w io.Writer, pver uint32) error {
var payload []byte
if f.Update != nil {
var bw bytes.Buffer
if err := f.Update.Encode(&bw, pver); err != nil {
return err
}
payload = bw.Bytes()
}
if err := WriteElement(w, uint16(len(payload))); err != nil {
return err
}
_, err := w.Write(payload)
return err
}
// FailAmountBelowMinimum is returned if the HTLC does not reach the current
// minimum amount. We tell them the amount of the incoming HTLC and the
// current channel setting for the outgoing channel.
//
// NOTE: May only be returned by the intermediate nodes in the path.
type FailAmountBelowMinimum struct {
// HtlcMsat is the wrong amount of the incoming HTLC.
HtlcMsat MilliSatoshi
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate
}
// NewAmountBelowMinimum creates a new instance of the FailAmountBelowMinimum.
func NewAmountBelowMinimum(htlcMsat MilliSatoshi,
update ChannelUpdate) *FailAmountBelowMinimum {
return &FailAmountBelowMinimum{
HtlcMsat: htlcMsat,
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailAmountBelowMinimum) Code() FailCode {
return CodeAmountBelowMinimum
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailAmountBelowMinimum) Error() string {
return fmt.Sprintf("AmountBelowMinimum(amt=%v, update=%v", f.HtlcMsat,
spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error {
if err := ReadElement(r, &f.HtlcMsat); err != nil {
return err
}
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate{}
return parseChannelUpdateCompatabilityMode(
bufio.NewReader(r), &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailAmountBelowMinimum) Encode(w io.Writer, pver uint32) error {
if err := WriteElement(w, f.HtlcMsat); err != nil {
return err
}
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailFeeInsufficient is returned if the HTLC does not pay sufficient fee.
// We tell them the amount of the incoming HTLC and the current channel
// setting for the outgoing channel.
//
// NOTE: May only be returned by intermediate nodes.
type FailFeeInsufficient struct {
// HtlcMsat is the wrong amount of the incoming HTLC.
HtlcMsat MilliSatoshi
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate
}
// NewFeeInsufficient creates a new instance of the FailFeeInsufficient.
func NewFeeInsufficient(htlcMsat MilliSatoshi,
update ChannelUpdate) *FailFeeInsufficient {
return &FailFeeInsufficient{
HtlcMsat: htlcMsat,
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailFeeInsufficient) Code() FailCode {
return CodeFeeInsufficient
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailFeeInsufficient) Error() string {
return fmt.Sprintf("FeeInsufficient(htlc_amt==%v, update=%v", f.HtlcMsat,
spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error {
if err := ReadElement(r, &f.HtlcMsat); err != nil {
return err
}
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate{}
return parseChannelUpdateCompatabilityMode(
bufio.NewReader(r), &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFeeInsufficient) Encode(w io.Writer, pver uint32) error {
if err := WriteElement(w, f.HtlcMsat); err != nil {
return err
}
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailIncorrectCltvExpiry is returned if the outgoing cltv value does not
// match the update add htlc's cltv expiry minus the cltv expiry delta for the
// outgoing channel. We tell them the cltv expiry and the current channel
// setting for the outgoing channel.
//
// NOTE: May only be returned by intermediate nodes.
type FailIncorrectCltvExpiry struct {
// CltvExpiry is the wrong absolute timeout in blocks, after which
// outgoing HTLC expires.
CltvExpiry uint32
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate
}
// NewIncorrectCltvExpiry creates a new instance of the
// FailIncorrectCltvExpiry.
func NewIncorrectCltvExpiry(cltvExpiry uint32,
update ChannelUpdate) *FailIncorrectCltvExpiry {
return &FailIncorrectCltvExpiry{
CltvExpiry: cltvExpiry,
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailIncorrectCltvExpiry) Code() FailCode {
return CodeIncorrectCltvExpiry
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailIncorrectCltvExpiry) Error() string {
	return fmt.Sprintf("IncorrectCltvExpiry(expiry=%v, update=%v)",
		f.CltvExpiry, spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error {
if err := ReadElement(r, &f.CltvExpiry); err != nil {
return err
}
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate{}
return parseChannelUpdateCompatabilityMode(
bufio.NewReader(r), &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error {
if err := WriteElement(w, f.CltvExpiry); err != nil {
return err
}
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailExpiryTooSoon is returned if the cltv-expiry is too near. We tell them
// the current channel setting for the outgoing channel.
//
// NOTE: May only be returned by intermediate nodes.
type FailExpiryTooSoon struct {
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate
}
// NewExpiryTooSoon creates a new instance of the FailExpiryTooSoon.
func NewExpiryTooSoon(update ChannelUpdate) *FailExpiryTooSoon {
return &FailExpiryTooSoon{
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailExpiryTooSoon) Code() FailCode {
return CodeExpiryTooSoon
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailExpiryTooSoon) Error() string {
return fmt.Sprintf("ExpiryTooSoon(update=%v", spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error {
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate{}
return parseChannelUpdateCompatabilityMode(
bufio.NewReader(r), &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailExpiryTooSoon) Encode(w io.Writer, pver uint32) error {
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailChannelDisabled is returned if the channel is disabled. We tell them
// the current channel setting for the outgoing channel.
//
// NOTE: May only be returned by intermediate nodes.
type FailChannelDisabled struct {
// Flags least-significant bit must be set to 0 if the creating node
// corresponds to the first node in the previously sent channel
// announcement and 1 otherwise.
Flags uint16
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate
}
// NewChannelDisabled creates a new instance of the FailChannelDisabled.
func NewChannelDisabled(flags uint16, update ChannelUpdate) *FailChannelDisabled {
return &FailChannelDisabled{
Flags: flags,
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailChannelDisabled) Code() FailCode {
return CodeChannelDisabled
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailChannelDisabled) Error() string {
return fmt.Sprintf("ChannelDisabled(flags=%v, update=%v", f.Flags,
spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error {
if err := ReadElement(r, &f.Flags); err != nil {
return err
}
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate{}
return parseChannelUpdateCompatabilityMode(
bufio.NewReader(r), &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailChannelDisabled) Encode(w io.Writer, pver uint32) error {
if err := WriteElement(w, f.Flags); err != nil {
return err
}
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not
// match the cltv_expiry of the HTLC at the final hop.
//
// NOTE: May only be returned by the final node.
type FailFinalIncorrectCltvExpiry struct {
// CltvExpiry is the wrong absolute timeout in blocks, after which
// outgoing HTLC expires.
CltvExpiry uint32
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailFinalIncorrectCltvExpiry) Error() string {
return fmt.Sprintf("FinalIncorrectCltvExpiry(expiry=%v)", f.CltvExpiry)
}
// NewFinalIncorrectCltvExpiry creates a new instance of the
// FailFinalIncorrectCltvExpiry.
func NewFinalIncorrectCltvExpiry(cltvExpiry uint32) *FailFinalIncorrectCltvExpiry {
return &FailFinalIncorrectCltvExpiry{
CltvExpiry: cltvExpiry,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailFinalIncorrectCltvExpiry) Code() FailCode {
return CodeFinalIncorrectCltvExpiry
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFinalIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, &f.CltvExpiry)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFinalIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error {
return WriteElement(w, f.CltvExpiry)
}
// FailFinalIncorrectHtlcAmount is returned if the amt_to_forward is higher
// than incoming_htlc_amt of the HTLC at the final hop.
//
// NOTE: May only be returned by the final node.
type FailFinalIncorrectHtlcAmount struct {
// IncomingHTLCAmount is the wrong forwarded htlc amount.
IncomingHTLCAmount MilliSatoshi
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailFinalIncorrectHtlcAmount) Error() string {
return fmt.Sprintf("FinalIncorrectHtlcAmount(amt=%v)",
f.IncomingHTLCAmount)
}
// NewFinalIncorrectHtlcAmount creates a new instance of the
// FailFinalIncorrectHtlcAmount.
func NewFinalIncorrectHtlcAmount(amount MilliSatoshi) *FailFinalIncorrectHtlcAmount {
return &FailFinalIncorrectHtlcAmount{
IncomingHTLCAmount: amount,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailFinalIncorrectHtlcAmount) Code() FailCode {
return CodeFinalIncorrectHtlcAmount
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFinalIncorrectHtlcAmount) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, &f.IncomingHTLCAmount)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFinalIncorrectHtlcAmount) Encode(w io.Writer, pver uint32) error {
return WriteElement(w, f.IncomingHTLCAmount)
}
// FailExpiryTooFar is returned if the CLTV expiry in the HTLC is too far in the
// future.
//
// NOTE: May be returned by any node in the payment route.
type FailExpiryTooFar struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailExpiryTooFar) Code() FailCode {
return CodeExpiryTooFar
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailExpiryTooFar) Error() string {
return f.Code().String()
}
// InvalidOnionPayload is returned if the hop could not process the TLV payload
// enclosed in the onion.
type InvalidOnionPayload struct {
// Type is the TLV type that caused the specific failure.
Type uint64
// Offset is the byte offset within the payload where the failure
// occurred.
Offset uint16
}
// NewInvalidOnionPayload initializes a new InvalidOnionPayload failure.
func NewInvalidOnionPayload(typ uint64, offset uint16) *InvalidOnionPayload {
return &InvalidOnionPayload{
Type: typ,
Offset: offset,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *InvalidOnionPayload) Code() FailCode {
return CodeInvalidOnionPayload
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *InvalidOnionPayload) Error() string {
return fmt.Sprintf("%v(type=%v, offset=%d)",
f.Code(), f.Type, f.Offset)
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *InvalidOnionPayload) Decode(r io.Reader, pver uint32) error {
var buf [8]byte
typ, err := tlv.ReadVarInt(r, &buf)
if err != nil {
return err
}
f.Type = typ
return ReadElements(r, &f.Offset)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *InvalidOnionPayload) Encode(w io.Writer, pver uint32) error {
var buf [8]byte
if err := tlv.WriteVarInt(w, f.Type, &buf); err != nil {
return err
}
return WriteElements(w, f.Offset)
}
// FailMPPTimeout is returned if the complete amount for a multi part payment
// was not received within a reasonable time.
//
// NOTE: May only be returned by the final node in the path.
type FailMPPTimeout struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailMPPTimeout) Code() FailCode {
return CodeMPPTimeout
}
// Error returns a human-readable string describing the target
// FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailMPPTimeout) Error() string {
return f.Code().String()
}
// DecodeFailure decodes, validates, and parses the lnwire onion failure, for
// the provided protocol version.
func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) {
// First, we'll parse out the encapsulated failure message itself. This
// is a 2 byte length followed by the payload itself.
var failureLength uint16
if err := ReadElement(r, &failureLength); err != nil {
return nil, fmt.Errorf("unable to read error len: %w", err)
}
if failureLength > FailureMessageLength {
return nil, fmt.Errorf("failure message is too "+
"long: %v", failureLength)
}
failureData := make([]byte, failureLength)
if _, err := io.ReadFull(r, failureData); err != nil {
return nil, fmt.Errorf("unable to full read payload of "+
"%v: %v", failureLength, err)
}
dataReader := bytes.NewReader(failureData)
return DecodeFailureMessage(dataReader, pver)
}
// DecodeFailureMessage decodes just the failure message, ignoring any padding
// that may be present at the end.
func DecodeFailureMessage(r io.Reader, pver uint32) (FailureMessage, error) {
// Once we have the failure data, we can obtain the failure code from
// the first two bytes of the buffer.
var codeBytes [2]byte
if _, err := io.ReadFull(r, codeBytes[:]); err != nil {
return nil, fmt.Errorf("unable to read failure code: %w", err)
}
failCode := FailCode(binary.BigEndian.Uint16(codeBytes[:]))
	// Create an empty failure for the given code and populate the failure
	// with additional data if needed.
failure, err := makeEmptyOnionError(failCode)
if err != nil {
return nil, fmt.Errorf("unable to make empty error: %w", err)
}
// Finally, if this failure has a payload, then we'll read that now as
// well.
switch f := failure.(type) {
case Serializable:
if err := f.Decode(r, pver); err != nil {
return nil, fmt.Errorf("unable to decode error "+
"update (type=%T): %v", failure, err)
}
}
return failure, nil
}
// EncodeFailure encodes a failure message, including the necessary onion
// failure header information.
func EncodeFailure(w io.Writer, failure FailureMessage, pver uint32) error {
var failureMessageBuffer bytes.Buffer
err := EncodeFailureMessage(&failureMessageBuffer, failure, pver)
if err != nil {
return err
}
// The combined size of this message must be below the max allowed
// failure message length.
failureMessage := failureMessageBuffer.Bytes()
if len(failureMessage) > FailureMessageLength {
return fmt.Errorf("failure message exceed max "+
"available size: %v", len(failureMessage))
}
// Finally, we'll add some padding in order to ensure that all failure
// messages are fixed size.
pad := make([]byte, FailureMessageLength-len(failureMessage))
return WriteElements(w,
uint16(len(failureMessage)),
failureMessage,
uint16(len(pad)),
pad,
)
}
// EncodeFailureMessage encodes just the failure message without adding a length
// and padding the message for the onion protocol.
func EncodeFailureMessage(w io.Writer, failure FailureMessage, pver uint32) error {
// First, we'll write out the error code itself into the failure
// buffer.
var codeBytes [2]byte
code := uint16(failure.Code())
binary.BigEndian.PutUint16(codeBytes[:], code)
_, err := w.Write(codeBytes[:])
if err != nil {
return err
}
	// Next, some messages have an additional payload. If this is one of
	// those types, then we'll also encode the error payload as well.
switch failure := failure.(type) {
case Serializable:
if err := failure.Encode(w, pver); err != nil {
return err
}
}
return nil
}
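
// exampleFailureRoundTrip is a minimal, hypothetical sketch (not used
// elsewhere in the package) of the full onion failure framing: EncodeFailure
// length-prefixes and pads the message out to the fixed failure size, and
// DecodeFailure recovers the concrete failure type from that blob.
func exampleFailureRoundTrip() (FailureMessage, error) {
	var b bytes.Buffer
	if err := EncodeFailure(&b, &FailExpiryTooFar{}, 0); err != nil {
		return nil, err
	}

	// The decoded message's Code() reports CodeExpiryTooFar.
	return DecodeFailure(&b, 0)
}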
// makeEmptyOnionError creates a new empty onion error of the proper concrete
// type based on the passed failure code.
func makeEmptyOnionError(code FailCode) (FailureMessage, error) {
switch code {
case CodeInvalidRealm:
return &FailInvalidRealm{}, nil
case CodeTemporaryNodeFailure:
return &FailTemporaryNodeFailure{}, nil
case CodePermanentNodeFailure:
return &FailPermanentNodeFailure{}, nil
case CodeRequiredNodeFeatureMissing:
return &FailRequiredNodeFeatureMissing{}, nil
case CodePermanentChannelFailure:
return &FailPermanentChannelFailure{}, nil
case CodeRequiredChannelFeatureMissing:
return &FailRequiredChannelFeatureMissing{}, nil
case CodeUnknownNextPeer:
return &FailUnknownNextPeer{}, nil
case CodeIncorrectOrUnknownPaymentDetails:
return &FailIncorrectDetails{}, nil
case CodeIncorrectPaymentAmount:
return &FailIncorrectPaymentAmount{}, nil
case CodeFinalExpiryTooSoon:
return &FailFinalExpiryTooSoon{}, nil
case CodeInvalidOnionVersion:
return &FailInvalidOnionVersion{}, nil
case CodeInvalidOnionHmac:
return &FailInvalidOnionHmac{}, nil
case CodeInvalidOnionKey:
return &FailInvalidOnionKey{}, nil
case CodeTemporaryChannelFailure:
return &FailTemporaryChannelFailure{}, nil
case CodeAmountBelowMinimum:
return &FailAmountBelowMinimum{}, nil
case CodeFeeInsufficient:
return &FailFeeInsufficient{}, nil
case CodeIncorrectCltvExpiry:
return &FailIncorrectCltvExpiry{}, nil
case CodeExpiryTooSoon:
return &FailExpiryTooSoon{}, nil
case CodeChannelDisabled:
return &FailChannelDisabled{}, nil
case CodeFinalIncorrectCltvExpiry:
return &FailFinalIncorrectCltvExpiry{}, nil
case CodeFinalIncorrectHtlcAmount:
return &FailFinalIncorrectHtlcAmount{}, nil
case CodeExpiryTooFar:
return &FailExpiryTooFar{}, nil
case CodeInvalidOnionPayload:
return &InvalidOnionPayload{}, nil
case CodeMPPTimeout:
return &FailMPPTimeout{}, nil
case CodeInvalidBlinding:
return &FailInvalidBlinding{}, nil
default:
return nil, errors.Errorf("unknown error code: %v", code)
}
}
// writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error
// format. The format is that we first write out the true serialized length of
// the channel update, followed by the serialized channel update itself.
func writeOnionErrorChanUpdate(w io.Writer, chanUpdate *ChannelUpdate,
pver uint32) error {
// First, we encode the channel update in a temporary buffer in order
// to get the exact serialized size.
var b bytes.Buffer
if err := chanUpdate.Encode(&b, pver); err != nil {
return err
}
// Now that we know the size, we can write the length out in the main
// writer.
updateLen := b.Len()
if err := WriteElement(w, uint16(updateLen)); err != nil {
return err
}
// With the length written, we'll then write out the serialized channel
// update.
if _, err := w.Write(b.Bytes()); err != nil {
return err
}
return nil
}
package lnwire
import (
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// FundingFlag represents the possible bit mask values for the ChannelFlags
// field within the OpenChannel struct.
type FundingFlag uint8
const (
// FFAnnounceChannel is a FundingFlag that when set, indicates the
// initiator of a funding flow wishes to announce the channel to the
// greater network.
FFAnnounceChannel FundingFlag = 1 << iota
)
// OpenChannel is the message Alice sends to Bob if she would like to create a
// channel with Bob where she's the sole provider of funds to the channel.
// Single funder channels simplify the initial funding workflow, are supported
// by nodes backed by SPV Bitcoin clients, and have a simpler security model
// than dual funded channels.
type OpenChannel struct {
// ChainHash is the target chain that the initiator wishes to open a
// channel within.
ChainHash chainhash.Hash
// PendingChannelID serves to uniquely identify the future channel
// created by the initiated single funder workflow.
PendingChannelID [32]byte
// FundingAmount is the amount of satoshis that the initiator of the
// channel wishes to use as the total capacity of the channel. The
// initial balance of the funding will be this value minus the push
// amount (if set).
FundingAmount btcutil.Amount
// PushAmount is the value that the initiating party wishes to "push"
	// to the responder as part of the first commitment state. If the
// responder accepts, then this will be their initial balance.
PushAmount MilliSatoshi
// DustLimit is the specific dust limit the sender of this message
// would like enforced on their version of the commitment transaction.
// Any output below this value will be "trimmed" from the commitment
// transaction, with the amount of the HTLC going to dust.
DustLimit btcutil.Amount
// MaxValueInFlight represents the maximum amount of coins that can be
// pending within the channel at any given time. If the amount of funds
// in limbo exceeds this amount, then the channel will be failed.
MaxValueInFlight MilliSatoshi
// ChannelReserve is the amount of BTC that the receiving party MUST
// maintain a balance above at all times. This is a safety mechanism to
// ensure that both sides always have skin in the game during the
// channel's lifetime.
ChannelReserve btcutil.Amount
// HtlcMinimum is the smallest HTLC that the sender of this message
// will accept.
HtlcMinimum MilliSatoshi
// FeePerKiloWeight is the initial fee rate that the initiator suggests
	// for both commitment transactions. This value is expressed in sat per
// kilo-weight.
//
// TODO(halseth): make SatPerKWeight when fee estimation is in own
// package. Currently this will cause an import cycle.
FeePerKiloWeight uint32
// CsvDelay is the number of blocks to use for the relative time lock
// in the pay-to-self output of both commitment transactions.
CsvDelay uint16
// MaxAcceptedHTLCs is the total number of incoming HTLC's that the
// sender of this channel will accept.
MaxAcceptedHTLCs uint16
// FundingKey is the key that should be used on behalf of the sender
	// within the 2-of-2 multi-sig output that is contained within the
// funding transaction.
FundingKey *btcec.PublicKey
// RevocationPoint is the base revocation point for the sending party.
// Any commitment transaction belonging to the receiver of this message
// should use this key and their per-commitment point to derive the
// revocation key for the commitment transaction.
RevocationPoint *btcec.PublicKey
// PaymentPoint is the base payment point for the sending party. This
// key should be combined with the per commitment point for a
// particular commitment state in order to create the key that should
// be used in any output that pays directly to the sending party, and
// also within the HTLC covenant transactions.
PaymentPoint *btcec.PublicKey
// DelayedPaymentPoint is the delay point for the sending party. This
// key should be combined with the per commitment point to derive the
// keys that are used in outputs of the sender's commitment transaction
// where they claim funds.
DelayedPaymentPoint *btcec.PublicKey
// HtlcPoint is the base point used to derive the set of keys for this
// party that will be used within the HTLC public key scripts. This
// value is combined with the receiver's revocation base point in order
// to derive the keys that are used within HTLC scripts.
HtlcPoint *btcec.PublicKey
// FirstCommitmentPoint is the first commitment point for the sending
// party. This value should be combined with the receiver's revocation
// base point in order to derive the revocation keys that are placed
// within the commitment transaction of the sender.
FirstCommitmentPoint *btcec.PublicKey
// ChannelFlags is a bit-field which allows the initiator of the
// channel to specify further behavior surrounding the channel.
// Currently, the least significant bit of this bit field indicates the
// initiator of the channel wishes to advertise this channel publicly.
ChannelFlags FundingFlag
	// UpfrontShutdownScript is the script to which the channel funds should
	// be paid when mutually closing the channel. This field is optional,
	// and has a length prefix, so a zero will be written if it is not set
	// and its length followed by the script will be written if it is set.
UpfrontShutdownScript DeliveryAddress
}
// A compile time check to ensure OpenChannel implements the lnwire.Message
// interface.
var _ Message = (*OpenChannel)(nil)
// Encode serializes the target OpenChannel into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
//
// This is part of the lnwire.Message interface.
func (o *OpenChannel) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
o.ChainHash[:],
o.PendingChannelID[:],
o.FundingAmount,
o.PushAmount,
o.DustLimit,
o.MaxValueInFlight,
o.ChannelReserve,
o.HtlcMinimum,
o.FeePerKiloWeight,
o.CsvDelay,
o.MaxAcceptedHTLCs,
o.FundingKey,
o.RevocationPoint,
o.PaymentPoint,
o.DelayedPaymentPoint,
o.HtlcPoint,
o.FirstCommitmentPoint,
o.ChannelFlags,
o.UpfrontShutdownScript,
)
}
// Decode deserializes the serialized OpenChannel stored in the passed
// io.Reader into the target OpenChannel using the deserialization rules
// defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (o *OpenChannel) Decode(r io.Reader, pver uint32) error {
if err := ReadElements(r,
o.ChainHash[:],
o.PendingChannelID[:],
&o.FundingAmount,
&o.PushAmount,
&o.DustLimit,
&o.MaxValueInFlight,
&o.ChannelReserve,
&o.HtlcMinimum,
&o.FeePerKiloWeight,
&o.CsvDelay,
&o.MaxAcceptedHTLCs,
&o.FundingKey,
&o.RevocationPoint,
&o.PaymentPoint,
&o.DelayedPaymentPoint,
&o.HtlcPoint,
&o.FirstCommitmentPoint,
&o.ChannelFlags,
); err != nil {
return err
}
// Check for the optional upfront shutdown script field. If it is not there,
// silence the EOF error.
err := ReadElement(r, &o.UpfrontShutdownScript)
if err != nil && err != io.EOF {
return err
}
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as an OpenChannel on the wire.
//
// This is part of the lnwire.Message interface.
func (o *OpenChannel) MsgType() MessageType {
return MsgOpenChannel
}
// MaxPayloadLength returns the maximum allowed payload length for a
// OpenChannel message.
//
// This is part of the lnwire.Message interface.
func (o *OpenChannel) MaxPayloadLength(uint32) uint32 {
// (32 * 2) + (8 * 6) + (4 * 1) + (2 * 2) + (33 * 6) + 1
var length uint32 = 319 // base length
// Upfront shutdown script max length.
length += 2 + deliveryAddressMaxSize
return length
}
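
// openChanBaseLength is a hypothetical constant (for illustration only)
// spelling out the arithmetic behind the 319-byte base payload above: two
// 32-byte identifiers, six 8-byte amounts, a 4-byte fee rate, two 2-byte
// fields, six 33-byte compressed public keys, and the 1-byte channel flags.
const openChanBaseLength = 32*2 + 8*6 + 4*1 + 2*2 + 33*6 + 1 // = 319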
package lnwire
import "io"
// PingPayload is a set of opaque bytes used to pad out a ping message.
type PingPayload []byte
// Ping defines a message which is sent by peers periodically to determine if
// the connection is still valid. Each ping message carries the number of bytes
// to pad the pong response with, and also a number of bytes to be ignored at
// the end of the ping message (which is padding).
type Ping struct {
// NumPongBytes is the number of bytes the pong response to this
// message should carry.
NumPongBytes uint16
	// PaddingBytes is a set of opaque bytes used to pad out this ping
	// message. Using this field in conjunction with the one above, it's
	// possible for a node to generate fake cover traffic.
PaddingBytes PingPayload
}
// NewPing returns a new Ping message.
func NewPing(numBytes uint16) *Ping {
return &Ping{
NumPongBytes: numBytes,
}
}
// A compile time check to ensure Ping implements the lnwire.Message interface.
var _ Message = (*Ping)(nil)
// Decode deserializes a serialized Ping message stored in the passed io.Reader
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (p *Ping) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&p.NumPongBytes,
&p.PaddingBytes)
}
// Encode serializes the target Ping into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the lnwire.Message interface.
func (p *Ping) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
p.NumPongBytes,
p.PaddingBytes)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (p *Ping) MsgType() MessageType {
return MsgPing
}
// MaxPayloadLength returns the maximum allowed payload size for a Ping
// complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (p *Ping) MaxPayloadLength(uint32) uint32 {
return 65532
}
package lnwire
import "io"
// PongPayload is a set of opaque bytes sent in response to a ping message.
type PongPayload []byte
// Pong defines a message which is the direct response to a received Ping
// message. A Pong reply indicates that a connection is still active. The Pong
// reply to a Ping message should contain the nonce carried in the original
// Ping message.
type Pong struct {
// PongBytes is a set of opaque bytes that corresponds to the
// NumPongBytes defined in the ping message that this pong is
// replying to.
PongBytes PongPayload
}
// NewPong returns a new Pong message.
func NewPong(pongBytes []byte) *Pong {
return &Pong{
PongBytes: pongBytes,
}
}
// A compile time check to ensure Pong implements the lnwire.Message interface.
var _ Message = (*Pong)(nil)
// Decode deserializes a serialized Pong message stored in the passed io.Reader
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (p *Pong) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&p.PongBytes,
)
}
// Encode serializes the target Pong into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the lnwire.Message interface.
func (p *Pong) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
p.PongBytes,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (p *Pong) MsgType() MessageType {
return MsgPong
}
// MaxPayloadLength returns the maximum allowed payload size for a Pong
// complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (p *Pong) MaxPayloadLength(uint32) uint32 {
return 65532
}
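
// examplePingPong is a minimal, hypothetical sketch (not used elsewhere in
// the package) of the expected exchange: the responder answers a Ping by
// sending a Pong whose PongBytes length equals the ping's NumPongBytes.
func examplePingPong() *Pong {
	ping := NewPing(16)
	ping.PaddingBytes = make(PingPayload, 32) // optional cover traffic

	// A correct reply carries exactly ping.NumPongBytes opaque bytes.
	return NewPong(make([]byte, ping.NumPongBytes))
}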
package lnwire
import (
"io"
"math"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// QueryChannelRange is a message sent by a node in order to query the
// receiving node for the set of open channels they know of with short channel
// ID's after the specified block height, capped at the number of blocks beyond
// that block height. This will be used by nodes upon initial connect to
// synchronize their views of the network.
type QueryChannelRange struct {
// ChainHash denotes the target chain that we're trying to synchronize
// channel graph state for.
ChainHash chainhash.Hash
// FirstBlockHeight is the first block in the query range. The
// responder should send all new short channel IDs from this block
// until this block plus the specified number of blocks.
FirstBlockHeight uint32
// NumBlocks is the number of blocks beyond the first block that short
// channel ID's should be sent for.
NumBlocks uint32
}
// NewQueryChannelRange creates a new empty QueryChannelRange message.
func NewQueryChannelRange() *QueryChannelRange {
return &QueryChannelRange{}
}
// A compile time check to ensure QueryChannelRange implements the
// lnwire.Message interface.
var _ Message = (*QueryChannelRange)(nil)
// Decode deserializes a serialized QueryChannelRange message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
q.ChainHash[:],
&q.FirstBlockHeight,
&q.NumBlocks,
)
}
// Encode serializes the target QueryChannelRange into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (q *QueryChannelRange) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
q.ChainHash[:],
q.FirstBlockHeight,
q.NumBlocks,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (q *QueryChannelRange) MsgType() MessageType {
return MsgQueryChannelRange
}
// MaxPayloadLength returns the maximum allowed payload size for a
// QueryChannelRange complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (q *QueryChannelRange) MaxPayloadLength(uint32) uint32 {
// 32 + 4 + 4
return 40
}
// LastBlockHeight returns the last block height covered by the range of a
// QueryChannelRange message.
func (q *QueryChannelRange) LastBlockHeight() uint32 {
// Handle overflows by casting to uint64.
lastBlockHeight := uint64(q.FirstBlockHeight) + uint64(q.NumBlocks) - 1
if lastBlockHeight > math.MaxUint32 {
return math.MaxUint32
}
return uint32(lastBlockHeight)
}
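
// exampleLastBlockHeightClamp is a minimal, hypothetical sketch (not used
// elsewhere in the package) of the overflow handling above: a range that
// would run past the uint32 ceiling is clamped to math.MaxUint32 rather
// than silently wrapping around.
func exampleLastBlockHeightClamp() uint32 {
	q := QueryChannelRange{
		FirstBlockHeight: math.MaxUint32 - 10,
		NumBlocks:        100,
	}

	// Returns math.MaxUint32 instead of a small wrapped-around height.
	return q.LastBlockHeight()
}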
package lnwire
import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sort"
"sync"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// ShortChanIDEncoding is an enum-like type that represents exactly how a set
// of short channel ID's is encoded on the wire. The set of encodings allows us
// to take advantage of the structure of a list of short channel ID's to
// achieve a high degree of compression.
type ShortChanIDEncoding uint8
const (
// EncodingSortedPlain signals that the set of short channel ID's is
// encoded using the regular encoding, in a sorted order.
EncodingSortedPlain ShortChanIDEncoding = 0
// EncodingSortedZlib signals that the set of short channel ID's is
	// encoded by first sorting the set of channel ID's, and then
// compressing them using zlib.
EncodingSortedZlib ShortChanIDEncoding = 1
)
const (
// maxZlibBufSize is the max number of bytes that we'll accept from a
// zlib decoding instance. We do this in order to limit the total
// amount of memory allocated during a decoding instance.
maxZlibBufSize = 67413630
)
// ErrUnsortedSIDs is returned when decoding a QueryShortChanIDs request whose
// items were not sorted.
type ErrUnsortedSIDs struct {
prevSID ShortChannelID
curSID ShortChannelID
}
// Error returns a human-readable description of the error.
func (e ErrUnsortedSIDs) Error() string {
return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
e.curSID, e.prevSID)
}
// zlibDecodeMtx is a package level mutex that we'll use in order to ensure
// that we'll only attempt a single zlib decoding instance at a time. This
// allows us to also further bound our memory usage.
var zlibDecodeMtx sync.Mutex
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
// came across an unknown short channel ID encoding, and therefore were unable
// to continue parsing.
func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error {
return fmt.Errorf("unknown short chan id encoding: %v", encoding)
}
// QueryShortChanIDs is a message that allows the sender to query a set of
// channel announcement and channel update messages that correspond to the set
// of encoded short channel ID's. The encoding of the short channel ID's is
// detailed in the query message, ensuring that the receiver knows how to
// properly decode each encoded short channel ID, which may be encoded using a
// compression format. The receiver should respond with a series of channel
// announcement and channel update messages, finally sending a
// ReplyShortChanIDsEnd message.
type QueryShortChanIDs struct {
// ChainHash denotes the target chain that we're querying for the
// channel ID's of.
ChainHash chainhash.Hash
// EncodingType is a signal to the receiver of the message that
// indicates exactly how the set of short channel ID's that follow have
// been encoded.
EncodingType ShortChanIDEncoding
// ShortChanIDs is a slice of decoded short channel ID's.
ShortChanIDs []ShortChannelID
// noSort indicates whether or not to sort the short channel ids before
// writing them out.
//
// NOTE: This should only be used during testing.
noSort bool
}
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding,
s []ShortChannelID) *QueryShortChanIDs {
return &QueryShortChanIDs{
ChainHash: h,
EncodingType: e,
ShortChanIDs: s,
}
}
// A compile time check to ensure QueryShortChanIDs implements the
// lnwire.Message interface.
var _ Message = (*QueryShortChanIDs)(nil)
// Decode deserializes a serialized QueryShortChanIDs message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r, q.ChainHash[:])
if err != nil {
return err
}
q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
return err
}
// decodeShortChanIDs decodes a set of short channel ID's that have been
// encoded. The first byte of the body details how the short chan ID's were
// encoded. We'll use this type to govern exactly how we go about decoding the
// set of short channel ID's.
func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) {
// First, we'll attempt to read the number of bytes in the body of the
// set of encoded short channel ID's.
var numBytesResp uint16
err := ReadElements(r, &numBytesResp)
if err != nil {
return 0, nil, err
}
if numBytesResp == 0 {
return 0, nil, nil
}
queryBody := make([]byte, numBytesResp)
if _, err := io.ReadFull(r, queryBody); err != nil {
return 0, nil, err
}
// The first byte is the encoding type, so we'll extract that so we can
// continue our parsing.
encodingType := ShortChanIDEncoding(queryBody[0])
// Before continuing, we'll snip off the first byte of the query body
// as that was just the encoding type.
queryBody = queryBody[1:]
	// Otherwise, depending on the encoding type, we'll decode the encoded
	// short channel ID's in a different manner.
switch encodingType {
	// In this encoding, we'll simply read a sorted array of encoded short
	// channel ID's from the buffer.
case EncodingSortedPlain:
// If after extracting the encoding type, the number of
// remaining bytes is not a whole multiple of the size of an
// encoded short channel ID (8 bytes), then we'll return a
// parsing error.
if len(queryBody)%8 != 0 {
return 0, nil, fmt.Errorf("whole number of short "+
"chan ID's cannot be encoded in len=%v",
len(queryBody))
}
// As each short channel ID is encoded as 8 bytes, we can
// compute the number of bytes encoded based on the size of the
// query body.
numShortChanIDs := len(queryBody) / 8
if numShortChanIDs == 0 {
return encodingType, nil, nil
}
// Finally, we'll read out the exact number of short channel
// ID's to conclude our parsing.
shortChanIDs := make([]ShortChannelID, numShortChanIDs)
bodyReader := bytes.NewReader(queryBody)
var lastChanID ShortChannelID
for i := 0; i < numShortChanIDs; i++ {
if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
return 0, nil, fmt.Errorf("unable to parse "+
"short chan ID: %v", err)
}
// We'll ensure that this short chan ID is greater than
// the last one. This is a requirement within the
			// encoding, and if violated can aid us in detecting
// malicious payloads. This can only be true starting
// at the second chanID.
cid := shortChanIDs[i]
if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
}
lastChanID = cid
}
return encodingType, shortChanIDs, nil
// In this encoding, we'll use zlib to decode the compressed payload.
	// However, we'll pay attention to ensure that we don't open ourselves
// up to a memory exhaustion attack.
case EncodingSortedZlib:
		// We'll obtain and ultimately release the zlib decode mutex.
// This guards us against allocating too much memory to decode
// each instance from concurrent peers.
zlibDecodeMtx.Lock()
defer zlibDecodeMtx.Unlock()
		// At this point, if there's no body remaining, then only the
		// encoding type was specified, meaning that there are no
		// further bytes to be parsed.
if len(queryBody) == 0 {
return encodingType, nil, nil
}
// Before we start to decode, we'll create a limit reader over
// the current reader. This will ensure that we can control how
// much memory we're allocating during the decoding process.
limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
R: bytes.NewReader(queryBody),
N: maxZlibBufSize,
})
if err != nil {
return 0, nil, fmt.Errorf("unable to create zlib "+
"reader: %w", err)
}
var (
shortChanIDs []ShortChannelID
lastChanID ShortChannelID
i int
)
for {
// We'll now attempt to read the next short channel ID
// encoded in the payload.
var cid ShortChannelID
err := ReadElements(limitedDecompressor, &cid)
switch {
// If we get an EOF error, then that either means we've
// read all that's contained in the buffer, or have hit
// our limit on the number of bytes we'll read. In
// either case, we'll return what we have so far.
case err == io.ErrUnexpectedEOF || err == io.EOF:
return encodingType, shortChanIDs, nil
// Otherwise, we hit some other sort of error, possibly
// an invalid payload, so we'll exit early with the
// error.
case err != nil:
return 0, nil, fmt.Errorf("unable to "+
"deflate next short chan "+
"ID: %v", err)
}
// We successfully read the next ID, so we'll collect
// that in the set of final ID's to return.
shortChanIDs = append(shortChanIDs, cid)
// Finally, we'll ensure that this short chan ID is
// greater than the last one. This is a requirement
			// within the encoding, and if violated can aid us in
// detecting malicious payloads. This can only be true
// starting at the second chanID.
if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
}
lastChanID = cid
i++
}
default:
		// If we've been sent an encoding type that we don't know of,
		// then we'll return a parsing error as we can't continue if
		// we're unable to decode them.
return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
}
}
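
// exampleSortedPlainBody is a minimal, hypothetical sketch (not used
// elsewhere in the package) of the wire layout decodeShortChanIDs consumes
// for EncodingSortedPlain: a uint16 body length of 1+8*n, a single
// encoding-type byte, then n big-endian 8-byte short channel ID's in
// strictly ascending order.
func exampleSortedPlainBody(scids []ShortChannelID) (*bytes.Buffer, error) {
	var body bytes.Buffer

	numBytesBody := uint16(1 + 8*len(scids))
	if err := WriteElements(&body, numBytesBody); err != nil {
		return nil, err
	}
	if err := WriteElements(&body, EncodingSortedPlain); err != nil {
		return nil, err
	}
	for _, cid := range scids {
		if err := WriteElements(&body, cid); err != nil {
			return nil, err
		}
	}

	return &body, nil
}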
// Encode serializes the target QueryShortChanIDs into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error {
// First, we'll write out the chain hash.
err := WriteElements(w, q.ChainHash[:])
if err != nil {
return err
}
	// Based on our encoding type, we'll write out the set of short channel
// ID's.
return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
}
// encodeShortChanIDs encodes the passed short channel ID's into the passed
// io.Writer, respecting the specified encoding type.
func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
shortChanIDs []ShortChannelID, noSort bool) error {
// For both of the current encoding types, the channel ID's are to be
// sorted in place, so we'll do that now. The sorting is applied unless
// we were specifically requested not to for testing purposes.
if !noSort {
sort.Slice(shortChanIDs, func(i, j int) bool {
return shortChanIDs[i].ToUint64() <
shortChanIDs[j].ToUint64()
})
}
switch encodingType {
// In this encoding, we'll simply write a sorted array of encoded short
// channel ID's from the buffer.
case EncodingSortedPlain:
// First, we'll write out the number of bytes of the query
// body. We add 1 as the response will have the encoding type
// prepended to it.
numBytesBody := uint16(len(shortChanIDs)*8) + 1
if err := WriteElements(w, numBytesBody); err != nil {
return err
}
		// We'll then write out the encoding type that precedes the
		// actual encoded short channel ID's.
if err := WriteElements(w, encodingType); err != nil {
return err
}
// Now that we know they're sorted, we can write out each short
// channel ID to the buffer.
for _, chanID := range shortChanIDs {
if err := WriteElements(w, chanID); err != nil {
return fmt.Errorf("unable to write short chan "+
"ID: %v", err)
}
}
return nil
// For this encoding we'll first write out a serialized version of all
// the channel ID's into a buffer, then zlib encode that. The final
// payload is what we'll write out to the passed io.Writer.
//
// TODO(roasbeef): assumes the caller knows the proper chunk size to
// pass to avoid bin-packing here
case EncodingSortedZlib:
// We'll make a new buffer, then wrap that with a zlib writer
// so we can write directly to the buffer and encode in a
// streaming manner.
var buf bytes.Buffer
zlibWriter := zlib.NewWriter(&buf)
// If we don't have anything at all to write, then we'll write
// an empty payload so we don't include things like the zlib
// header when the remote party is expecting no actual short
// channel IDs.
var compressedPayload []byte
if len(shortChanIDs) > 0 {
// Next, we'll write out all the channel ID's directly
// into the zlib writer, which will do compressing on
// the fly.
for _, chanID := range shortChanIDs {
err := WriteElements(zlibWriter, chanID)
if err != nil {
return fmt.Errorf("unable to write short chan "+
"ID: %v", err)
}
}
// Now that we've written all the elements, we'll
// ensure the compressed stream is written to the
// underlying buffer.
if err := zlibWriter.Close(); err != nil {
return fmt.Errorf("unable to finalize "+
"compression: %v", err)
}
compressedPayload = buf.Bytes()
}
// Now that we have all the items compressed, we can compute
// what the total payload size will be. We add one to account
// for the byte to encode the type.
//
// If we don't have any actual bytes to write, then we'll end
// up emitting one byte for the length, followed by the
// encoding type, and nothing more. The spec isn't 100% clear
// in this area, but we do this as this is what most of the
// other implementations do.
numBytesBody := len(compressedPayload) + 1
// Finally, we can write out the number of bytes, the
// compression type, and finally the buffer itself.
if err := WriteElements(w, uint16(numBytesBody)); err != nil {
return err
}
if err := WriteElements(w, encodingType); err != nil {
return err
}
_, err := w.Write(compressedPayload)
return err
default:
// If we're trying to encode with an encoding type that we
// don't know of, then we'll return a parsing error as we can't
// continue if we're unable to encode them.
return ErrUnknownShortChanIDEncoding(encodingType)
}
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) MsgType() MessageType {
return MsgQueryShortChanIDs
}
// MaxPayloadLength returns the maximum allowed payload size for a
// QueryShortChanIDs complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 {
return MaxMessagePayload
}
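
// exampleQueryRoundTrip is a minimal, hypothetical sketch (not used
// elsewhere in the package) of a plain-encoded query surviving an
// encode/decode cycle intact.
func exampleQueryRoundTrip() (*QueryShortChanIDs, error) {
	query := NewQueryShortChanIDs(
		chainhash.Hash{}, // placeholder chain hash for illustration
		EncodingSortedPlain,
		[]ShortChannelID{
			NewShortChanIDFromInt(1),
			NewShortChanIDFromInt(2),
		},
	)

	var b bytes.Buffer
	if err := query.Encode(&b, 0); err != nil {
		return nil, err
	}

	decoded := &QueryShortChanIDs{}
	if err := decoded.Decode(&b, 0); err != nil {
		return nil, err
	}

	// decoded.ShortChanIDs now holds the same two ID's, still sorted.
	return decoded, nil
}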
package lnwire
import "io"
// ReplyChannelRange is the response to the QueryChannelRange message. It
// includes the original query, and the next streaming chunk of encoded short
// channel ID's as the response. We'll also include a byte that indicates if
// this is the last message in the set of streaming responses.
type ReplyChannelRange struct {
// QueryChannelRange is the corresponding query to this response.
QueryChannelRange
// Complete denotes if this is the conclusion of the set of streaming
// responses to the original query.
Complete uint8
// EncodingType is a signal to the receiver of the message that
// indicates exactly how the set of short channel ID's that follow have
// been encoded.
EncodingType ShortChanIDEncoding
// ShortChanIDs is a slice of decoded short channel ID's.
ShortChanIDs []ShortChannelID
// noSort indicates whether or not to sort the short channel ids before
// writing them out.
//
// NOTE: This should only be used for testing.
noSort bool
}
// NewReplyChannelRange creates a new empty ReplyChannelRange message.
func NewReplyChannelRange() *ReplyChannelRange {
return &ReplyChannelRange{}
}
// A compile time check to ensure ReplyChannelRange implements the
// lnwire.Message interface.
var _ Message = (*ReplyChannelRange)(nil)
// Decode deserializes a serialized ReplyChannelRange message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error {
err := c.QueryChannelRange.Decode(r, pver)
if err != nil {
return err
}
if err := ReadElements(r, &c.Complete); err != nil {
return err
}
c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r)
return err
}
// Encode serializes the target ReplyChannelRange into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error {
if err := c.QueryChannelRange.Encode(w, pver); err != nil {
return err
}
if err := WriteElements(w, c.Complete); err != nil {
return err
}
return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) MsgType() MessageType {
return MsgReplyChannelRange
}
// MaxPayloadLength returns the maximum allowed payload size for a
// ReplyChannelRange complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) MaxPayloadLength(uint32) uint32 {
return MaxMessagePayload
}
package lnwire
import (
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// ReplyShortChanIDsEnd is a message that marks the end of a streaming message
// response to an initial QueryShortChanIDs message. This marks that the
// receiver of the original QueryShortChanIDs for the target chain has either
// sent all adequate responses it knows of, or doesn't know of any short chan
// ID's for the target chain.
type ReplyShortChanIDsEnd struct {
	// ChainHash denotes the target chain that we're responding to a short
	// chan ID query for.
ChainHash chainhash.Hash
// Complete will be set to 0 if we don't know of the chain that the
// remote peer sent their query for. Otherwise, we'll set this to 1 in
// order to indicate that we've sent all known responses for the prior
// set of short chan ID's in the corresponding QueryShortChanIDs
// message.
Complete uint8
}
// NewReplyShortChanIDsEnd creates a new empty ReplyShortChanIDsEnd message.
func NewReplyShortChanIDsEnd() *ReplyShortChanIDsEnd {
return &ReplyShortChanIDsEnd{}
}
// A compile time check to ensure ReplyShortChanIDsEnd implements the
// lnwire.Message interface.
var _ Message = (*ReplyShortChanIDsEnd)(nil)
// Decode deserializes a serialized ReplyShortChanIDsEnd message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
c.ChainHash[:],
&c.Complete,
)
}
// Encode serializes the target ReplyShortChanIDsEnd into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *ReplyShortChanIDsEnd) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChainHash[:],
c.Complete,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *ReplyShortChanIDsEnd) MsgType() MessageType {
return MsgReplyShortChanIDsEnd
}
// MaxPayloadLength returns the maximum allowed payload size for a
// ReplyShortChanIDsEnd complete message observing the specified protocol
// version.
//
// This is part of the lnwire.Message interface.
func (c *ReplyShortChanIDsEnd) MaxPayloadLength(uint32) uint32 {
// 32 (chain hash) + 1 (complete)
return 33
}
package lnwire
import (
"io"
"github.com/btcsuite/btcd/btcec/v2"
)
// RevokeAndAck is sent by either side once a CommitSig message has been
// received, and validated. This message serves to revoke the prior commitment
// transaction, which was the most up to date version until a CommitSig message
// referencing the specified ChannelPoint was received. Additionally, this
// message also piggybacks the next revocation hash that Alice should use when
// constructing Bob's version of the next commitment transaction (which
// would be done before sending a CommitSig message). This piggybacking allows
// Alice to send the next CommitSig message modifying Bob's commitment
// transaction without first asking for a revocation hash initially.
type RevokeAndAck struct {
	// ChanID uniquely identifies the currently active channel to which
	// this RevokeAndAck applies.
ChanID ChannelID
// Revocation is the preimage to the revocation hash of the now prior
// commitment transaction.
Revocation [32]byte
// NextRevocationKey is the next commitment point which should be used
// for the next commitment transaction the remote peer creates for us.
	// This, in conjunction with the revocation base point, will be used to
// create the proper revocation key used within the commitment
// transaction.
NextRevocationKey *btcec.PublicKey
}
// NewRevokeAndAck creates a new RevokeAndAck message.
func NewRevokeAndAck() *RevokeAndAck {
return &RevokeAndAck{}
}
// A compile time check to ensure RevokeAndAck implements the lnwire.Message
// interface.
var _ Message = (*RevokeAndAck)(nil)
// Decode deserializes a serialized RevokeAndAck message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
c.Revocation[:],
&c.NextRevocationKey,
)
}
// Encode serializes the target RevokeAndAck into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *RevokeAndAck) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.Revocation[:],
c.NextRevocationKey,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *RevokeAndAck) MsgType() MessageType {
return MsgRevokeAndAck
}
// MaxPayloadLength returns the maximum allowed payload size for a RevokeAndAck
// complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *RevokeAndAck) MaxPayloadLength(uint32) uint32 {
// 32 + 32 + 33
return 97
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *RevokeAndAck) TargetChanID() ChannelID {
return c.ChanID
}
package lnwire
import (
"fmt"
)
// ShortChannelID represents the set of data which is needed to retrieve all
// necessary data to validate the channel's existence.
type ShortChannelID struct {
	// BlockHeight is the height of the block where the funding transaction
	// is located.
//
// NOTE: This field is limited to 3 bytes.
BlockHeight uint32
	// TxIndex is the position of the funding transaction within a block.
//
// NOTE: This field is limited to 3 bytes.
TxIndex uint32
	// TxPosition is the position of the transaction output which pays to
	// the channel.
TxPosition uint16
}
// NewShortChanIDFromInt returns a new ShortChannelID which is the decoded
// version of the compact channel ID encoded within the uint64. The format of
// the compact channel ID is as follows: 3 bytes for the block height, 3 bytes
// for the transaction index, and 2 bytes for the output index.
func NewShortChanIDFromInt(chanID uint64) ShortChannelID {
return ShortChannelID{
BlockHeight: uint32(chanID >> 40),
TxIndex: uint32(chanID>>16) & 0xFFFFFF,
TxPosition: uint16(chanID),
}
}
// ToUint64 converts the ShortChannelID into a compact format encoded within a
// uint64 (8 bytes).
func (c ShortChannelID) ToUint64() uint64 {
// TODO(roasbeef): explicit error on overflow?
return ((uint64(c.BlockHeight) << 40) | (uint64(c.TxIndex) << 16) |
(uint64(c.TxPosition)))
}
// String generates a human-readable representation of the channel ID.
func (c ShortChannelID) String() string {
return fmt.Sprintf("%d:%d:%d", c.BlockHeight, c.TxIndex, c.TxPosition)
}
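// scidRoundTripSketch is a hedged illustration and not part of the original
// file; the function name and sample values are assumptions. It demonstrates
// the 3-byte/3-byte/2-byte packing: for any in-range SCID, ToUint64 followed
// by NewShortChanIDFromInt is the identity.
func scidRoundTripSketch() bool {
scid := ShortChannelID{
BlockHeight: 700123, // Must fit in 3 bytes.
TxIndex: 1024, // Must fit in 3 bytes.
TxPosition: 1, // Must fit in 2 bytes.
}
return NewShortChanIDFromInt(scid.ToUint64()) == scid
}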
package lnwire
import (
"io"
)
// Shutdown is sent by either side in order to initiate the cooperative closure
// of a channel. This message is sparse as both sides implicitly have the
// information necessary to construct a transaction that will send the settled
// funds of both parties to the final delivery addresses negotiated during the
// funding workflow.
type Shutdown struct {
// ChannelID serves to identify which channel is to be closed.
ChannelID ChannelID
// Address is the script to which the channel funds will be paid.
Address DeliveryAddress
}
// DeliveryAddress is used to communicate the address to which funds from a
// closed channel should be sent. The address can be a p2wsh, p2pkh, p2sh or
// p2wpkh.
type DeliveryAddress []byte
// deliveryAddressMaxSize is the maximum expected size in bytes of a
// DeliveryAddress based on the types of scripts we know.
// Following are the known scripts and their sizes in bytes.
// - pay to witness script hash: 34
// - pay to pubkey hash: 25
// - pay to script hash: 23
// - pay to witness pubkey hash: 22.
const deliveryAddressMaxSize = 34
// NewShutdown creates a new Shutdown message.
func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown {
return &Shutdown{
ChannelID: cid,
Address: addr,
}
}
// A compile-time check to ensure Shutdown implements the lnwire.Message
// interface.
var _ Message = (*Shutdown)(nil)
// Decode deserializes a serialized Shutdown stored in the passed io.Reader
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
return ReadElements(r, &s.ChannelID, &s.Address)
}
// Encode serializes the target Shutdown into the passed io.Writer observing
// the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (s *Shutdown) Encode(w io.Writer, pver uint32) error {
return WriteElements(w, s.ChannelID, s.Address)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (s *Shutdown) MsgType() MessageType {
return MsgShutdown
}
// MaxPayloadLength returns the maximum allowed payload size for this message
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (s *Shutdown) MaxPayloadLength(pver uint32) uint32 {
var length uint32
// ChannelID - 32bytes
length += 32
// Len - 2 bytes
length += 2
// ScriptPubKey - maximum delivery address size.
length += deliveryAddressMaxSize
return length
}
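// newShutdownSketch is a hedged example and not part of the original file;
// the helper name is an assumption. It builds a Shutdown for a P2WPKH
// delivery address, the smallest of the known script forms at 22 bytes:
// OP_0 followed by a 20-byte pubkey-hash push.
func newShutdownSketch(cid ChannelID, pkHash [20]byte) *Shutdown {
addr := make(DeliveryAddress, 0, 22)
addr = append(addr, 0x00, 0x14) // OP_0, then a 20-byte data push.
addr = append(addr, pkHash[:]...)
return NewShutdown(cid, addr)
}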
package lnwire
import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/lightningnetwork/lnd/input"
)
// Sig is a fixed-sized ECDSA signature. Unlike Bitcoin, we use fixed sized
// signatures on the wire, instead of DER encoded signatures. This type
// provides several methods to convert to/from a regular Bitcoin DER encoded
// signature (raw bytes and *ecdsa.Signature).
type Sig [64]byte
// NewSigFromRawSignature returns a Sig from a Bitcoin raw signature encoded in
// the canonical DER encoding.
func NewSigFromRawSignature(sig []byte) (Sig, error) {
var b Sig
if len(sig) == 0 {
return b, fmt.Errorf("cannot decode empty signature")
}
// Extract lengths of R and S. The DER representation is laid out as
// 0x30 <length> 0x02 <length r> r 0x02 <length s> s
// which means the length of R is the 4th byte and the length of S
// is the second byte after R ends. 0x02 signifies a length-prefixed,
// zero-padded, big-endian bigint. 0x30 signifies a DER signature.
// See the Serialize() method for ecdsa.Signature for details.
rLen := sig[3]
sLen := sig[5+rLen]
// Check to make sure R and S can both fit into their intended buffers.
// We check S first because these code blocks decrement sLen and rLen
// in the case of a 33-byte 0-padded integer returned from Serialize()
// and rLen is used in calculating array indices for S. We can track
// this with additional variables, but it's more efficient to just
// check S first.
if sLen > 32 {
if (sLen > 33) || (sig[6+rLen] != 0x00) {
return b, fmt.Errorf("S is over 32 bytes long " +
"without padding")
}
sLen--
copy(b[64-sLen:], sig[7+rLen:])
} else {
copy(b[64-sLen:], sig[6+rLen:])
}
// Do the same for R as we did for S
if rLen > 32 {
if (rLen > 33) || (sig[4] != 0x00) {
return b, fmt.Errorf("R is over 32 bytes long " +
"without padding")
}
rLen--
copy(b[32-rLen:], sig[5:5+rLen])
} else {
copy(b[32-rLen:], sig[4:4+rLen])
}
return b, nil
}
// NewSigFromSignature creates a new signature as used on the wire, from an
// existing ecdsa.Signature.
func NewSigFromSignature(e input.Signature) (Sig, error) {
if e == nil {
return Sig{}, fmt.Errorf("cannot decode empty signature")
}
// Serialize the signature with all the checks that entails.
return NewSigFromRawSignature(e.Serialize())
}
// ToSignature converts the fixed-sized signature to an ecdsa.Signature
// object which can be used for signature validation checks.
func (b *Sig) ToSignature() (*ecdsa.Signature, error) {
// Parse the signature with strict checks.
sigBytes := b.ToSignatureBytes()
sig, err := ecdsa.ParseDERSignature(sigBytes)
if err != nil {
return nil, err
}
return sig, nil
}
// ToSignatureBytes serializes the target fixed-sized signature into the raw
// bytes of a DER encoding.
func (b *Sig) ToSignatureBytes() []byte {
// Extract canonically-padded bigint representations from buffer
r := extractCanonicalPadding(b[0:32])
s := extractCanonicalPadding(b[32:64])
rLen := uint8(len(r))
sLen := uint8(len(s))
// Create a canonical serialized signature. DER format is:
// 0x30 <length> 0x02 <length r> r 0x02 <length s> s
sigBytes := make([]byte, 6+rLen+sLen)
sigBytes[0] = 0x30 // DER signature magic value
sigBytes[1] = 4 + rLen + sLen // Length of rest of signature
sigBytes[2] = 0x02 // Big integer magic value
sigBytes[3] = rLen // Length of R
sigBytes[rLen+4] = 0x02 // Big integer magic value
sigBytes[rLen+5] = sLen // Length of S
copy(sigBytes[4:], r) // Copy R
copy(sigBytes[rLen+6:], s) // Copy S
return sigBytes
}
// extractCanonicalPadding is a utility function to extract the canonical
// padding of a big-endian integer from the wire encoding (a 0-padded
// big-endian integer) such that it passes btcec.canonicalPadding test.
func extractCanonicalPadding(b []byte) []byte {
for i := 0; i < len(b); i++ {
// Found first non-zero byte.
if b[i] > 0 {
// If the MSB is set, we need zero padding.
if b[i]&0x80 == 0x80 {
return append([]byte{0x00}, b[i:]...)
}
return b[i:]
}
}
return []byte{0x00}
}
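// sigRoundTripSketch is a hedged illustration and not part of the original
// file; the function name is an assumption. Because ToSignatureBytes emits
// canonical DER and NewSigFromRawSignature strips only canonical padding,
// converting a fixed-size signature to DER and back recovers the identical
// 64-byte form.
func sigRoundTripSketch(sig Sig) (bool, error) {
der := sig.ToSignatureBytes()
recovered, err := NewSigFromRawSignature(der)
if err != nil {
return false, err
}
return recovered == sig, nil
}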
package lnwire
import (
"io"
"github.com/lightningnetwork/lnd/tlv"
)
// TrueBoolean is a record that indicates true or false using the presence of
// the record. If the record is absent, it indicates false. If it is present,
// it indicates true.
type TrueBoolean struct{}
// Record returns the tlv record for the boolean entry.
func (b *TrueBoolean) Record() tlv.Record {
return tlv.MakeStaticRecord(
0, b, 0, booleanEncoder, booleanDecoder,
)
}
func booleanEncoder(_ io.Writer, val interface{}, _ *[8]byte) error {
if _, ok := val.(*TrueBoolean); ok {
return nil
}
return tlv.NewTypeForEncodingErr(val, "TrueBoolean")
}
func booleanDecoder(_ io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if _, ok := val.(*TrueBoolean); ok && (l == 0 || l == 1) {
return nil
}
return tlv.NewTypeForDecodingErr(val, "TrueBoolean", l, 0)
}
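// encodeTrueBooleanSketch is a hedged example and not part of the original
// file; the helper name is an assumption. It shows that the record carries
// no value bytes: including it in a stream writes only the type and a zero
// length, and that presence alone is what signals "true" to a decoder.
func encodeTrueBooleanSketch(w io.Writer) error {
b := &TrueBoolean{}
stream, err := tlv.NewStream(b.Record())
if err != nil {
return err
}
return stream.Encode(w)
}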
package lnwire
import (
"io"
)
// OnionPacketSize is the size of the serialized Sphinx onion packet included
// in each UpdateAddHTLC message. The breakdown of the onion packet is as
// follows: 1-byte version, 33-byte ephemeral public key (for ECDH), 1300-bytes
// of per-hop data, and a 32-byte HMAC over the entire packet.
const OnionPacketSize = 1366
// UpdateAddHTLC is the message sent by Alice to Bob when she wishes to add an
// HTLC to his remote commitment transaction. In addition to information
// detailing the value, the ID, expiry, and onion blob are also included,
// which allows Bob to derive the next hop in the route. The HTLC added by
// this message is to be added to the remote node's "pending" HTLC's. A
// subsequent CommitSig message will move the pending HTLC to the newly
// created commitment transaction, marking them as "staged".
type UpdateAddHTLC struct {
// ChanID is the particular active channel that this UpdateAddHTLC is
// bound to.
ChanID ChannelID
// ID is the identification number for this HTLC. This value is
// explicitly included as it allows nodes to survive single-sided
// restarts. The ID value for this side starts at zero, and increases
// with each offered HTLC.
ID uint64
// Amount is the amount of millisatoshis this HTLC is worth.
Amount MilliSatoshi
// PaymentHash is the payment hash to be included in the HTLC this
// request creates. The pre-image to this HTLC must be revealed by the
// upstream peer in order to fully settle the HTLC.
PaymentHash [32]byte
// Expiry is the number of blocks after which this HTLC should expire.
// It is the receiver's duty to ensure that the outgoing HTLC has a
// sufficient expiry value to allow her to redeem the incoming HTLC.
Expiry uint32
// OnionBlob is the raw serialized mix header used to route an HTLC in
// a privacy-preserving manner. The mix header is defined currently to
// be parsed as a 4-tuple: (groupElement, routingInfo, headerMAC,
// body). First the receiving node should use the groupElement, and
// its current onion key to derive a shared secret with the source.
// Once the shared secret has been derived, the headerMAC should be
// checked FIRST. Note that the MAC only covers the routingInfo field.
// If the MAC matches, and the shared secret is fresh, then the node
// should strip off a layer of encryption, exposing the next hop to be
// used in the subsequent UpdateAddHTLC message.
OnionBlob [OnionPacketSize]byte
}
// NewUpdateAddHTLC returns a new empty UpdateAddHTLC message.
func NewUpdateAddHTLC() *UpdateAddHTLC {
return &UpdateAddHTLC{}
}
// A compile time check to ensure UpdateAddHTLC implements the lnwire.Message
// interface.
var _ Message = (*UpdateAddHTLC)(nil)
// Decode deserializes a serialized UpdateAddHTLC message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.ID,
&c.Amount,
c.PaymentHash[:],
&c.Expiry,
c.OnionBlob[:],
)
}
// Encode serializes the target UpdateAddHTLC into the passed io.Writer observing
// the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.ID,
c.Amount,
c.PaymentHash[:],
c.Expiry,
c.OnionBlob[:],
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) MsgType() MessageType {
return MsgUpdateAddHTLC
}
// MaxPayloadLength returns the maximum allowed payload size for an UpdateAddHTLC
// complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) MaxPayloadLength(uint32) uint32 {
// 1450
return 32 + 8 + 4 + 8 + 32 + 1366
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateAddHTLC) TargetChanID() ChannelID {
return c.ChanID
}
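// countingWriter is a small helper assumed for the sketch below; it is not
// part of the original file. It discards bytes while counting them.
type countingWriter struct{ n int }
func (c *countingWriter) Write(p []byte) (int, error) {
c.n += len(p)
return len(p), nil
}
// addEncodedSizeSketch is a hedged illustration and not part of the original
// file; the helper name is an assumption. Every UpdateAddHTLC field has a
// static size, so the encoded payload is always exactly the 1450 bytes
// returned by MaxPayloadLength.
func addEncodedSizeSketch() (bool, error) {
var cw countingWriter
if err := NewUpdateAddHTLC().Encode(&cw, 0); err != nil {
return false, err
}
return cw.n == int(NewUpdateAddHTLC().MaxPayloadLength(0)), nil
}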
package lnwire
import (
"io"
)
// OpaqueReason is an opaque encrypted byte slice that encodes the exact
// failure reason along with some supplemental data. The contents of this
// slice can only be decrypted by the sender of the original HTLC.
type OpaqueReason []byte
// UpdateFailHTLC is sent by Alice to Bob in order to remove a previously added
// HTLC. Upon receipt of an UpdateFailHTLC the HTLC should be removed from the
// next commitment transaction, with the UpdateFailHTLC propagated backwards in
// the route to fully undo the HTLC.
type UpdateFailHTLC struct {
// ChanID is the particular active channel that this
// UpdateFailHTLC is bound to.
ChanID ChannelID
// ID references which HTLC on the remote node's commitment transaction
// is being failed.
ID uint64
// Reason is an onion-encrypted blob that details why the HTLC was
// failed. This blob is only fully decryptable by the initiator of the
// HTLC message.
Reason OpaqueReason
}
// A compile time check to ensure UpdateFailHTLC implements the lnwire.Message
// interface.
var _ Message = (*UpdateFailHTLC)(nil)
// Decode deserializes a serialized UpdateFailHTLC message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.ID,
&c.Reason,
)
}
// Encode serializes the target UpdateFailHTLC into the passed io.Writer observing
// the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.ID,
c.Reason,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) MsgType() MessageType {
return MsgUpdateFailHTLC
}
// MaxPayloadLength returns the maximum allowed payload size for an UpdateFailHTLC
// complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) MaxPayloadLength(uint32) uint32 {
var length uint32
// Length of the ChanID
length += 32
// Length of the ID
length += 8
// Length of the length prefix of the Reason.
length += 2
// Length of the Reason
length += 292
return length
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateFailHTLC) TargetChanID() ChannelID {
return c.ChanID
}
package lnwire
import (
"crypto/sha256"
"io"
)
// UpdateFailMalformedHTLC is sent by either the payment forwarder or the
// payment receiver to the payment sender in order to notify it that the
// onion blob can't be parsed. For that reason we send this message instead
// of an obfuscated onion failure.
type UpdateFailMalformedHTLC struct {
// ChanID is the particular active channel that this
// UpdateFailMalformedHTLC is bound to.
ChanID ChannelID
// ID references which HTLC on the remote node's commitment transaction
// is being failed.
ID uint64
// ShaOnionBlob is the SHA256 hash of the onion blob which couldn't be
// parsed by the node in the payment path.
ShaOnionBlob [sha256.Size]byte
// FailureCode is the exact reason why the onion blob couldn't be
// parsed.
FailureCode FailCode
}
// A compile time check to ensure UpdateFailMalformedHTLC implements the
// lnwire.Message interface.
var _ Message = (*UpdateFailMalformedHTLC)(nil)
// Decode deserializes a serialized UpdateFailMalformedHTLC message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.ID,
c.ShaOnionBlob[:],
&c.FailureCode,
)
}
// Encode serializes the target UpdateFailMalformedHTLC into the passed
// io.Writer observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailMalformedHTLC) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.ID,
c.ShaOnionBlob[:],
c.FailureCode,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailMalformedHTLC) MsgType() MessageType {
return MsgUpdateFailMalformedHTLC
}
// MaxPayloadLength returns the maximum allowed payload size for a
// UpdateFailMalformedHTLC complete message observing the specified protocol
// version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailMalformedHTLC) MaxPayloadLength(uint32) uint32 {
// 32 + 8 + 32 + 2
return 74
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateFailMalformedHTLC) TargetChanID() ChannelID {
return c.ChanID
}
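// newMalformedSketch is a hedged example and not part of the original file;
// the helper name is an assumption. The failing node hashes the add's onion
// blob so that the sender, who holds the original packet, can verify exactly
// which onion could not be parsed.
func newMalformedSketch(add *UpdateAddHTLC, code FailCode) *UpdateFailMalformedHTLC {
return &UpdateFailMalformedHTLC{
ChanID: add.ChanID,
ID: add.ID,
ShaOnionBlob: sha256.Sum256(add.OnionBlob[:]),
FailureCode: code,
}
}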
package lnwire
import (
"io"
)
// UpdateFee is the message the channel initiator sends to the other peer if
// the channel commitment fee needs to be updated.
type UpdateFee struct {
// ChanID is the channel that this UpdateFee is meant for.
ChanID ChannelID
// FeePerKw is the fee-per-kw on commit transactions that the sender of
// this message wants to use for this channel.
//
// TODO(halseth): make SatPerKWeight when fee estimation is moved to
// own package. Currently this will cause an import cycle.
FeePerKw uint32
}
// NewUpdateFee creates a new UpdateFee message.
func NewUpdateFee(chanID ChannelID, feePerKw uint32) *UpdateFee {
return &UpdateFee{
ChanID: chanID,
FeePerKw: feePerKw,
}
}
// A compile time check to ensure UpdateFee implements the lnwire.Message
// interface.
var _ Message = (*UpdateFee)(nil)
// Decode deserializes a serialized UpdateFee message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFee) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.FeePerKw,
)
}
// Encode serializes the target UpdateFee into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFee) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.FeePerKw,
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFee) MsgType() MessageType {
return MsgUpdateFee
}
// MaxPayloadLength returns the maximum allowed payload size for an UpdateFee
// complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFee) MaxPayloadLength(uint32) uint32 {
// 32 + 4
return 36
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateFee) TargetChanID() ChannelID {
return c.ChanID
}
package lnwire
import (
"io"
)
// UpdateFulfillHTLC is sent by Alice to Bob when she wishes to settle a
// particular HTLC referenced by its HTLCKey within a specific active channel
// referenced by ChannelPoint. A subsequent CommitSig message will be sent by
// Alice to "lock-in" the removal of the specified HTLC, possibly containing
// a batch signature covering several settled HTLC's.
type UpdateFulfillHTLC struct {
// ChanID references an active channel which holds the HTLC to be
// settled.
ChanID ChannelID
// ID denotes the exact HTLC stage within the receiving node's
// commitment transaction to be removed.
ID uint64
// PaymentPreimage is the R-value preimage required to fully settle an
// HTLC.
PaymentPreimage [32]byte
}
// NewUpdateFulfillHTLC returns a new UpdateFulfillHTLC populated with the
// given fields.
func NewUpdateFulfillHTLC(chanID ChannelID, id uint64,
preimage [32]byte) *UpdateFulfillHTLC {
return &UpdateFulfillHTLC{
ChanID: chanID,
ID: id,
PaymentPreimage: preimage,
}
}
// A compile time check to ensure UpdateFulfillHTLC implements the lnwire.Message
// interface.
var _ Message = (*UpdateFulfillHTLC)(nil)
// Decode deserializes a serialized UpdateFulfillHTLC message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.ID,
c.PaymentPreimage[:],
)
}
// Encode serializes the target UpdateFulfillHTLC into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) Encode(w io.Writer, pver uint32) error {
return WriteElements(w,
c.ChanID,
c.ID,
c.PaymentPreimage[:],
)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) MsgType() MessageType {
return MsgUpdateFulfillHTLC
}
// MaxPayloadLength returns the maximum allowed payload size for an UpdateFulfillHTLC
// complete message observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) MaxPayloadLength(uint32) uint32 {
// 32 + 8 + 32
return 72
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateFulfillHTLC) TargetChanID() ChannelID {
return c.ChanID
}
package migration
import "github.com/btcsuite/btclog/v2"
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration12
import (
"bytes"
"encoding/binary"
"io"
"time"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// MaxMemoSize is maximum size of the memo field within invoices stored
// in the database.
MaxMemoSize = 1024
// maxReceiptSize is the maximum size of the payment receipt stored
// within the database alongside incoming/outgoing invoices.
maxReceiptSize = 1024
// MaxPaymentRequestSize is the max size of a payment request for
// this invoice.
// TODO(halseth): determine the max length payment request when field
// lengths are final.
MaxPaymentRequestSize = 4096
memoType tlv.Type = 0
payReqType tlv.Type = 1
createTimeType tlv.Type = 2
settleTimeType tlv.Type = 3
addIndexType tlv.Type = 4
settleIndexType tlv.Type = 5
preimageType tlv.Type = 6
valueType tlv.Type = 7
cltvDeltaType tlv.Type = 8
expiryType tlv.Type = 9
paymentAddrType tlv.Type = 10
featuresType tlv.Type = 11
invStateType tlv.Type = 12
amtPaidType tlv.Type = 13
)
var (
// invoiceBucket is the name of the bucket within the database that
// stores all data related to invoices no matter their final state.
// Within the invoice bucket, each invoice is keyed by its invoice ID
// which is a monotonically increasing uint32.
invoiceBucket = []byte("invoices")
// Big endian is the preferred byte order, due to cursor scans over
// integer keys iterating in order.
byteOrder = binary.BigEndian
)
// ContractState describes the state the invoice is in.
type ContractState uint8
// ContractTerm is a companion struct to the Invoice struct. This struct houses
// the necessary conditions required before the invoice can be considered fully
// settled by the payee.
type ContractTerm struct {
// PaymentPreimage is the preimage which is to be revealed in the
// occasion that an HTLC paying to the hash of this preimage is
// extended.
PaymentPreimage lntypes.Preimage
// Value is the expected amount of milli-satoshis to be paid to an HTLC
// which can be satisfied by the above preimage.
Value lnwire.MilliSatoshi
// State describes the state the invoice is in.
State ContractState
// PaymentAddr is a randomly generated value included in the MPP record
// by the sender to prevent probing of the receiver.
PaymentAddr [32]byte
// Features is the feature vectors advertised on the payment request.
Features *lnwire.FeatureVector
}
// Invoice is a payment invoice generated by a payee in order to request
// payment for some good or service. The inclusion of invoices within Lightning
// creates a payment workflow for merchants very similar to that of the
// existing financial system within PayPal, etc. Invoices are added to the
// database when a payment is requested, then can be settled manually once the
// payment is received at the upper layer. For record keeping purposes,
// invoices are never deleted from the database, instead a bit is toggled
// denoting the invoice has been fully settled. Within the database, all
// invoices must have a unique payment hash which is generated by taking the
// sha256 of the payment preimage.
type Invoice struct {
// Memo is an optional memo to be stored alongside an invoice. The
// memo may contain further details pertaining to the invoice itself,
// or any other message which fits within the size constraints.
Memo []byte
// PaymentRequest is an optional field where a payment request created
// for this invoice can be stored.
PaymentRequest []byte
// FinalCltvDelta is the minimum required number of blocks before htlc
// expiry when the invoice is accepted.
FinalCltvDelta int32
// Expiry defines how long after creation this invoice should expire.
Expiry time.Duration
// CreationDate is the exact time the invoice was created.
CreationDate time.Time
// SettleDate is the exact time the invoice was settled.
SettleDate time.Time
// Terms are the contractual payment terms of the invoice. Once all the
// terms have been satisfied by the payer, then the invoice can be
// considered fully fulfilled.
//
// TODO(roasbeef): later allow for multiple terms to fulfill the final
// invoice: payment fragmentation, etc.
Terms ContractTerm
// AddIndex is an auto-incrementing integer that acts as a
// monotonically increasing sequence number for all invoices created.
// Clients can then use this field as a "checkpoint" of sorts when
// implementing a streaming RPC to notify consumers of instances where
// an invoice has been added before they re-connected.
//
// NOTE: This index starts at 1.
AddIndex uint64
// SettleIndex is an auto-incrementing integer that acts as a
// monotonically increasing sequence number for all settled invoices.
// Clients can then use this field as a "checkpoint" of sorts when
// implementing a streaming RPC to notify consumers of instances where
// an invoice has been settled before they re-connected.
//
// NOTE: This index starts at 1.
SettleIndex uint64
// AmtPaid is the final amount that we ultimately accepted as payment
// for this invoice. We specify this value independently as it's
// possible that the invoice originally didn't specify an amount, or
// the sender overpaid.
AmtPaid lnwire.MilliSatoshi
// Htlcs records all htlcs that paid to this invoice. Some of these
// htlcs may have been marked as canceled.
Htlcs []byte
}
// LegacyDeserializeInvoice decodes an invoice from the passed io.Reader using
// the pre-TLV serialization.
func LegacyDeserializeInvoice(r io.Reader) (Invoice, error) {
var err error
invoice := Invoice{}
// TODO(roasbeef): use read full everywhere
invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "")
if err != nil {
return invoice, err
}
_, err = wire.ReadVarBytes(r, 0, maxReceiptSize, "")
if err != nil {
return invoice, err
}
invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "")
if err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.FinalCltvDelta); err != nil {
return invoice, err
}
var expiry int64
if err := binary.Read(r, byteOrder, &expiry); err != nil {
return invoice, err
}
invoice.Expiry = time.Duration(expiry)
birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth")
if err != nil {
return invoice, err
}
if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil {
return invoice, err
}
settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled")
if err != nil {
return invoice, err
}
if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil {
return invoice, err
}
if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil {
return invoice, err
}
var scratch [8]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return invoice, err
}
invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil {
return invoice, err
}
invoice.Htlcs, err = deserializeHtlcs(r)
if err != nil {
return Invoice{}, err
}
return invoice, nil
}
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
// as a flattened byte slice.
func deserializeHtlcs(r io.Reader) ([]byte, error) {
var b bytes.Buffer
_, err := io.Copy(&b, r)
return b.Bytes(), err
}
// SerializeInvoice serializes an invoice to a writer.
//
// nolint: dupl
func SerializeInvoice(w io.Writer, i *Invoice) error {
creationDateBytes, err := i.CreationDate.MarshalBinary()
if err != nil {
return err
}
settleDateBytes, err := i.SettleDate.MarshalBinary()
if err != nil {
return err
}
var fb bytes.Buffer
err = i.Terms.Features.EncodeBase256(&fb)
if err != nil {
return err
}
featureBytes := fb.Bytes()
preimage := [32]byte(i.Terms.PaymentPreimage)
value := uint64(i.Terms.Value)
cltvDelta := uint32(i.FinalCltvDelta)
expiry := uint64(i.Expiry)
amtPaid := uint64(i.AmtPaid)
state := uint8(i.Terms.State)
tlvStream, err := tlv.NewStream(
// Memo and payreq.
tlv.MakePrimitiveRecord(memoType, &i.Memo),
tlv.MakePrimitiveRecord(payReqType, &i.PaymentRequest),
// Add/settle metadata.
tlv.MakePrimitiveRecord(createTimeType, &creationDateBytes),
tlv.MakePrimitiveRecord(settleTimeType, &settleDateBytes),
tlv.MakePrimitiveRecord(addIndexType, &i.AddIndex),
tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
// Terms.
tlv.MakePrimitiveRecord(preimageType, &preimage),
tlv.MakePrimitiveRecord(valueType, &value),
tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
tlv.MakePrimitiveRecord(expiryType, &expiry),
tlv.MakePrimitiveRecord(paymentAddrType, &i.Terms.PaymentAddr),
tlv.MakePrimitiveRecord(featuresType, &featureBytes),
// Invoice state.
tlv.MakePrimitiveRecord(invStateType, &state),
tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
)
if err != nil {
return err
}
var b bytes.Buffer
if err = tlvStream.Encode(&b); err != nil {
return err
}
err = binary.Write(w, byteOrder, uint64(b.Len()))
if err != nil {
return err
}
if _, err = w.Write(b.Bytes()); err != nil {
return err
}
return serializeHtlcs(w, i.Htlcs)
}
// serializeHtlcs writes a serialized list of invoice htlcs into a writer.
func serializeHtlcs(w io.Writer, htlcs []byte) error {
_, err := w.Write(htlcs)
return err
}
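// readInvoiceFramingSketch is a hedged counterpart to SerializeInvoice and
// not part of the original file; the helper name is an assumption. It reads
// back the framing written above: a big-endian uint64 length prefix, a TLV
// stream of exactly that many bytes, then the raw htlc bytes until EOF.
func readInvoiceFramingSketch(r io.Reader) (tlvBytes, htlcBytes []byte, err error) {
var n uint64
if err := binary.Read(r, byteOrder, &n); err != nil {
return nil, nil, err
}
tlvBytes = make([]byte, n)
if _, err := io.ReadFull(r, tlvBytes); err != nil {
return nil, nil, err
}
htlcBytes, err = deserializeHtlcs(r)
return tlvBytes, htlcBytes, err
}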
package migration12
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration12
import (
"bytes"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/kvdb"
)
var emptyFeatures = lnwire.NewFeatureVector(nil, nil)
// MigrateInvoiceTLV migrates all existing invoice bodies over to be serialized
// in a single TLV stream. In the process, we drop the Receipt field and add
// PaymentAddr and Features to the invoice Terms.
func MigrateInvoiceTLV(tx kvdb.RwTx) error {
log.Infof("Migrating invoice bodies to TLV, " +
"adding payment addresses and feature vectors.")
invoiceB := tx.ReadWriteBucket(invoiceBucket)
if invoiceB == nil {
return nil
}
type keyedInvoice struct {
key []byte
invoice Invoice
}
// Read in all existing invoices using the old format.
var invoices []keyedInvoice
err := invoiceB.ForEach(func(k, v []byte) error {
if v == nil {
return nil
}
invoiceReader := bytes.NewReader(v)
invoice, err := LegacyDeserializeInvoice(invoiceReader)
if err != nil {
return err
}
// Insert an empty feature vector on all old payments.
invoice.Terms.Features = emptyFeatures
invoices = append(invoices, keyedInvoice{
key: k,
invoice: invoice,
})
return nil
})
if err != nil {
return err
}
// Write out each one under its original key using TLV.
for _, ki := range invoices {
var b bytes.Buffer
err = SerializeInvoice(&b, &ki.invoice)
if err != nil {
return err
}
err = invoiceB.Put(ki.key, b.Bytes())
if err != nil {
return err
}
}
log.Infof("Migration to TLV invoice bodies, " +
"payment address, and features complete!")
return nil
}
package migration13
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration13
import (
"encoding/binary"
"fmt"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
paymentsRootBucket = []byte("payments-root-bucket")
// paymentCreationInfoKey is a key used in the payment's sub-bucket to
// store the creation info of the payment.
paymentCreationInfoKey = []byte("payment-creation-info")
// paymentFailInfoKey is a key used in the payment's sub-bucket to
// store information about the reason a payment failed.
paymentFailInfoKey = []byte("payment-fail-info")
// paymentAttemptInfoKey is a key used in the payment's sub-bucket to
// store the info about the latest attempt that was done for the
// payment in question.
paymentAttemptInfoKey = []byte("payment-attempt-info")
// paymentSettleInfoKey is a key used in the payment's sub-bucket to
// store the settle info of the payment.
paymentSettleInfoKey = []byte("payment-settle-info")
// paymentHtlcsBucket is a bucket where we'll store the information
// about the HTLCs that were attempted for a payment.
paymentHtlcsBucket = []byte("payment-htlcs-bucket")
// htlcAttemptInfoKey is a key used in a HTLC's sub-bucket to store the
// info about the attempt that was done for the HTLC in question.
htlcAttemptInfoKey = []byte("htlc-attempt-info")
// htlcSettleInfoKey is a key used in a HTLC's sub-bucket to store the
// settle info, if any.
htlcSettleInfoKey = []byte("htlc-settle-info")
// htlcFailInfoKey is a key used in a HTLC's sub-bucket to store
// failure information, if any.
htlcFailInfoKey = []byte("htlc-fail-info")
byteOrder = binary.BigEndian
)
// MigrateMPP migrates the payments to a new structure that accommodates
// MPP payments.
func MigrateMPP(tx kvdb.RwTx) error {
log.Infof("Migrating payments to mpp structure")
// Iterate over all payments and store their indexing keys. This is
// needed because no modifications are allowed inside a Bucket.ForEach
// loop.
paymentsBucket := tx.ReadWriteBucket(paymentsRootBucket)
if paymentsBucket == nil {
return nil
}
var paymentKeys [][]byte
err := paymentsBucket.ForEach(func(k, v []byte) error {
paymentKeys = append(paymentKeys, k)
return nil
})
if err != nil {
return err
}
// With all keys retrieved, start the migration.
for _, k := range paymentKeys {
bucket := paymentsBucket.NestedReadWriteBucket(k)
// We only expect sub-buckets to be found in
// this top-level bucket.
if bucket == nil {
return fmt.Errorf("non bucket element in " +
"payments bucket")
}
// Fetch old format creation info.
creationInfo := bucket.Get(paymentCreationInfoKey)
if creationInfo == nil {
return fmt.Errorf("creation info not found")
}
// Make a copy because bbolt doesn't allow this value to be
// changed in-place.
newCreationInfo := make([]byte, len(creationInfo))
copy(newCreationInfo, creationInfo)
// Convert to nanoseconds.
timeBytes := newCreationInfo[32+8 : 32+8+8]
time := byteOrder.Uint64(timeBytes)
timeNs := time * 1000000000
byteOrder.PutUint64(timeBytes, timeNs)
// Write back new format creation info.
err := bucket.Put(paymentCreationInfoKey, newCreationInfo)
if err != nil {
return err
}
// No migration needed if there is no attempt stored.
attemptInfo := bucket.Get(paymentAttemptInfoKey)
if attemptInfo == nil {
continue
}
// Delete attempt info on the payment level.
if err := bucket.Delete(paymentAttemptInfoKey); err != nil {
return err
}
// Save attempt id for later use.
attemptID := attemptInfo[:8]
// Discard attempt id. It will become a bucket key in the new
// structure.
attemptInfo = attemptInfo[8:]
// Append unknown (zero) attempt time.
var zero [8]byte
attemptInfo = append(attemptInfo, zero[:]...)
// Create bucket that contains all htlcs.
htlcsBucket, err := bucket.CreateBucket(paymentHtlcsBucket)
if err != nil {
return err
}
// Create an htlc for this attempt.
htlcBucket, err := htlcsBucket.CreateBucket(attemptID)
if err != nil {
return err
}
// Save migrated attempt info.
err = htlcBucket.Put(htlcAttemptInfoKey, attemptInfo)
if err != nil {
return err
}
// Migrate settle info.
settleInfo := bucket.Get(paymentSettleInfoKey)
if settleInfo != nil {
// Payment-level settle info can be deleted.
err := bucket.Delete(paymentSettleInfoKey)
if err != nil {
return err
}
// Append unknown (zero) settle time.
settleInfo = append(settleInfo, zero[:]...)
// Save settle info.
err = htlcBucket.Put(htlcSettleInfoKey, settleInfo)
if err != nil {
return err
}
// Migration for settled htlc completed.
continue
}
// If there is no payment-level failure reason, the payment is
// still in flight and nothing else needs to be migrated.
// Otherwise the payment-level failure reason can remain
// unchanged.
inFlight := bucket.Get(paymentFailInfoKey) == nil
if inFlight {
continue
}
// The htlc failed. Add htlc fail info with reason unknown. We
// don't have access to the original failure reason anymore.
failInfo := []byte{
// Fail time unknown.
0, 0, 0, 0, 0, 0, 0, 0,
// Zero length wire message.
0,
// Failure reason unknown.
0,
// Failure source index zero.
0, 0, 0, 0,
}
// Save fail info.
err = htlcBucket.Put(htlcFailInfoKey, failInfo)
if err != nil {
return err
}
}
log.Infof("Migration of payments to mpp structure complete!")
return nil
}
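// secondsToNanosSketch is a hedged restatement of the timestamp rewrite in
// MigrateMPP and not part of the original file; the helper name and layout
// comment are assumptions based on the offsets used above. The creation info
// stores a big-endian Unix timestamp in seconds at offset 40 (after the
// 32-byte hash and 8-byte value), which the migration scales to nanoseconds
// in place.
func secondsToNanosSketch(creationInfo []byte) {
timeBytes := creationInfo[32+8 : 32+8+8]
byteOrder.PutUint64(timeBytes, byteOrder.Uint64(timeBytes)*1000000000)
}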
package migration16
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration16
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
paymentsRootBucket = []byte("payments-root-bucket")
paymentSequenceKey = []byte("payment-sequence-key")
duplicatePaymentsBucket = []byte("payment-duplicate-bucket")
paymentsIndexBucket = []byte("payments-index-bucket")
byteOrder = binary.BigEndian
)
// paymentIndexType indicates the type of index we have recorded in the payment
// indexes bucket.
type paymentIndexType uint8
// paymentIndexTypeHash is a payment index type which indicates that we have
// created an index of payment sequence number to payment hash.
const paymentIndexTypeHash paymentIndexType = 0
// paymentIndex stores all the information we require to create an index by
// sequence number for a payment.
type paymentIndex struct {
// paymentHash is the hash of the payment, which is its key in the
// payment root bucket.
paymentHash []byte
// sequenceNumbers is the set of sequence numbers associated with this
// payment hash. There will be more than one sequence number in the
// case where duplicate payments are present.
sequenceNumbers [][]byte
}
// MigrateSequenceIndex migrates the payments db to contain a new bucket which
// provides an index from sequence number to payment hash. This is required
// for more efficient sequential lookup of payments, which are keyed by payment
// hash before this migration.
func MigrateSequenceIndex(tx kvdb.RwTx) error {
log.Infof("Migrating payments to add sequence number index")
// Get a list of indices we need to write.
indexList, err := getPaymentIndexList(tx)
if err != nil {
return err
}
// Create the top level bucket that we will use to index payments in.
bucket, err := tx.CreateTopLevelBucket(paymentsIndexBucket)
if err != nil {
return err
}
// Write an index for each of our payments.
for _, index := range indexList {
// Write indexes for each of our sequence numbers.
for _, seqNr := range index.sequenceNumbers {
err := putIndex(bucket, seqNr, index.paymentHash)
if err != nil {
return err
}
}
}
return nil
}
// putIndex performs a sanity check that ensures we are not writing duplicate
// indexes to disk then creates the index provided.
func putIndex(bucket kvdb.RwBucket, sequenceNr, paymentHash []byte) error {
// Add a sanity check that we do not already have an entry with
// this sequence number.
existingEntry := bucket.Get(sequenceNr)
if existingEntry != nil {
return fmt.Errorf("sequence number: %x duplicated",
sequenceNr)
}
bytes, err := serializePaymentIndexEntry(paymentHash)
if err != nil {
return err
}
return bucket.Put(sequenceNr, bytes)
}
// serializePaymentIndexEntry serializes a payment hash typed index. The value
// produced contains a payment index type (which can be used in future to
// signal different payment index types) and the payment hash.
func serializePaymentIndexEntry(hash []byte) ([]byte, error) {
var b bytes.Buffer
err := binary.Write(&b, byteOrder, paymentIndexTypeHash)
if err != nil {
return nil, err
}
if err := wire.WriteVarBytes(&b, 0, hash); err != nil {
return nil, err
}
return b.Bytes(), nil
}
// getPaymentIndexList gets a list of indices we need to write for our current
// set of payments.
func getPaymentIndexList(tx kvdb.RTx) ([]paymentIndex, error) {
// Iterate over all payments and store their indexing keys. This is
// needed because no modifications are allowed inside a Bucket.ForEach
// loop.
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
if paymentsBucket == nil {
return nil, nil
}
var indexList []paymentIndex
err := paymentsBucket.ForEach(func(k, v []byte) error {
// Get the bucket which contains the payment, fail if the key
// does not have a bucket.
bucket := paymentsBucket.NestedReadBucket(k)
if bucket == nil {
return fmt.Errorf("non bucket element in " +
"payments bucket")
}
seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes == nil {
return fmt.Errorf("nil sequence number bytes")
}
seqNrs, err := fetchSequenceNumbers(bucket)
if err != nil {
return err
}
// Create an index object with our payment hash and sequence
// numbers and append it to our set of indexes.
index := paymentIndex{
paymentHash: k,
sequenceNumbers: seqNrs,
}
indexList = append(indexList, index)
return nil
})
if err != nil {
return nil, err
}
return indexList, nil
}
// fetchSequenceNumbers fetches all the sequence numbers associated with a
// payment, including those belonging to any duplicate payments.
func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) {
seqNum := paymentBucket.Get(paymentSequenceKey)
if seqNum == nil {
return nil, errors.New("expected sequence number")
}
sequenceNumbers := [][]byte{seqNum}
// Get the duplicate payments bucket; if there are no duplicates, just
// return early with the payment sequence number.
duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket)
if duplicates == nil {
return sequenceNumbers, nil
}
// If we do have duplicates, they are keyed by sequence number, so we
// iterate through the duplicates bucket and add them to our set of
// sequence numbers.
if err := duplicates.ForEach(func(k, v []byte) error {
sequenceNumbers = append(sequenceNumbers, k)
return nil
}); err != nil {
return nil, err
}
return sequenceNumbers, nil
}
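// deserializePaymentIndexSketch is a hedged counterpart to
// serializePaymentIndexEntry and not part of the original file; the helper
// name is an assumption. It reads back the one-byte index type followed by
// the var-bytes payment hash.
func deserializePaymentIndexSketch(b []byte) (paymentIndexType, []byte, error) {
r := bytes.NewReader(b)
var indexType paymentIndexType
if err := binary.Read(r, byteOrder, &indexType); err != nil {
return 0, nil, err
}
hash, err := wire.ReadVarBytes(r, 0, 32, "paymentHash")
if err != nil {
return 0, nil, err
}
return indexType, hash, nil
}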
package migration20
import (
"encoding/binary"
"io"
"github.com/btcsuite/btcd/wire"
)
var (
byteOrder = binary.BigEndian
)
// writeOutpoint writes an outpoint to the passed writer.
func writeOutpoint(w io.Writer, o *wire.OutPoint) error {
if _, err := w.Write(o.Hash[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, o.Index); err != nil {
return err
}
return nil
}
// readOutpoint reads an outpoint from the passed reader.
func readOutpoint(r io.Reader, o *wire.OutPoint) error {
if _, err := io.ReadFull(r, o.Hash[:]); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &o.Index); err != nil {
return err
}
return nil
}
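// outpointRoundTripSketch is a hedged illustration and not part of the
// original file; the helper name is an assumption. An outpoint serializes
// as its 32-byte txid followed by a big-endian 4-byte output index, so a
// write followed by a read through the same buffer (e.g. a bytes.Buffer
// passed as rw) recovers the original value.
func outpointRoundTripSketch(rw io.ReadWriter, op *wire.OutPoint) (wire.OutPoint, error) {
if err := writeOutpoint(rw, op); err != nil {
return wire.OutPoint{}, err
}
var decoded wire.OutPoint
err := readOutpoint(rw, &decoded)
return decoded, err
}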
package migration20
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package
// will not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration20
import (
"bytes"
"fmt"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// openChanBucket stores all the open channel information.
openChanBucket = []byte("open-chan-bucket")
// closedChannelBucket stores all the closed channel information.
closedChannelBucket = []byte("closed-chan-bucket")
// outpointBucket is an index mapping outpoints to a tlv
// stream of channel data.
outpointBucket = []byte("outpoint-bucket")
)
const (
// A tlv type definition used to serialize an outpoint's indexStatus for
// use in the outpoint index.
indexStatusType tlv.Type = 0
)
// indexStatus is an enum-like type that describes what state the
// outpoint is in. Currently there are only two possible values.
type indexStatus uint8
const (
// outpointOpen represents an outpoint that is open in the outpoint index.
outpointOpen indexStatus = 0
// outpointClosed represents an outpoint that is closed in the outpoint
// index.
outpointClosed indexStatus = 1
)
// MigrateOutpointIndex populates the outpoint index with outpoints that
// the node already has. This takes every outpoint in the open channel
// bucket and every outpoint in the closed channel bucket and stores them
// in this index.
func MigrateOutpointIndex(tx kvdb.RwTx) error {
log.Info("Migrating to the outpoint index")
// First get the set of open outpoints.
openList, err := getOpenOutpoints(tx)
if err != nil {
return err
}
// Then get the set of closed outpoints.
closedList, err := getClosedOutpoints(tx)
if err != nil {
return err
}
// Get the outpoint bucket which was created in migration 19.
bucket := tx.ReadWriteBucket(outpointBucket)
// Store the set of open outpoints in the outpoint bucket.
if err := putOutpoints(bucket, openList, false); err != nil {
return err
}
// Store the set of closed outpoints in the outpoint bucket.
return putOutpoints(bucket, closedList, true)
}
// getOpenOutpoints traverses through the openChanBucket and returns the
// list of these channels' outpoints.
func getOpenOutpoints(tx kvdb.RwTx) ([]*wire.OutPoint, error) {
var ops []*wire.OutPoint
openBucket := tx.ReadBucket(openChanBucket)
if openBucket == nil {
return ops, nil
}
// Iterate through every node and chain bucket to get every open
// outpoint.
//
// The bucket tree:
// openChanBucket -> nodePub -> chainHash -> chanPoint
err := openBucket.ForEach(func(k, v []byte) error {
// Ensure that the key is the same size as a pubkey and the
// value is nil.
if len(k) != 33 || v != nil {
return nil
}
nodeBucket := openBucket.NestedReadBucket(k)
if nodeBucket == nil {
return nil
}
return nodeBucket.ForEach(func(k, v []byte) error {
// If there's a value it's not a bucket.
if v != nil {
return nil
}
chainBucket := nodeBucket.NestedReadBucket(k)
if chainBucket == nil {
return fmt.Errorf("unable to read "+
"bucket for chain: %x", k)
}
return chainBucket.ForEach(func(k, v []byte) error {
// If there's a value it's not a bucket.
if v != nil {
return nil
}
var op wire.OutPoint
r := bytes.NewReader(k)
if err := readOutpoint(r, &op); err != nil {
return err
}
ops = append(ops, &op)
return nil
})
})
})
if err != nil {
return nil, err
}
return ops, nil
}
// getClosedOutpoints traverses through the closedChanBucket and returns
// a list of closed outpoints.
func getClosedOutpoints(tx kvdb.RwTx) ([]*wire.OutPoint, error) {
var ops []*wire.OutPoint
closedBucket := tx.ReadBucket(closedChannelBucket)
if closedBucket == nil {
return ops, nil
}
// Iterate through every key-value pair to gather all outpoints.
err := closedBucket.ForEach(func(k, v []byte) error {
var op wire.OutPoint
r := bytes.NewReader(k)
if err := readOutpoint(r, &op); err != nil {
return err
}
ops = append(ops, &op)
return nil
})
if err != nil {
return nil, err
}
return ops, nil
}
// putOutpoints puts the set of outpoints into the outpoint bucket.
func putOutpoints(bucket kvdb.RwBucket, ops []*wire.OutPoint, isClosed bool) error {
status := uint8(outpointOpen)
if isClosed {
status = uint8(outpointClosed)
}
record := tlv.MakePrimitiveRecord(indexStatusType, &status)
stream, err := tlv.NewStream(record)
if err != nil {
return err
}
var b bytes.Buffer
if err := stream.Encode(&b); err != nil {
return err
}
// Store the set of outpoints with the encoded tlv stream.
for _, op := range ops {
var opBuf bytes.Buffer
if err := writeOutpoint(&opBuf, op); err != nil {
return err
}
if err := bucket.Put(opBuf.Bytes(), b.Bytes()); err != nil {
return err
}
}
return nil
}
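// getOutpointStatusSketch is a hedged counterpart to putOutpoints and not
// part of the original file; the helper name is an assumption. It looks up
// an outpoint in the index bucket and decodes the stored tlv stream back
// into an indexStatus.
func getOutpointStatusSketch(bucket kvdb.RBucket, op *wire.OutPoint) (indexStatus, error) {
var opBuf bytes.Buffer
if err := writeOutpoint(&opBuf, op); err != nil {
return 0, err
}
value := bucket.Get(opBuf.Bytes())
if value == nil {
return 0, fmt.Errorf("outpoint %v not found in index", op)
}
var status uint8
record := tlv.MakePrimitiveRecord(indexStatusType, &status)
stream, err := tlv.NewStream(record)
if err != nil {
return 0, err
}
if err := stream.Decode(bytes.NewReader(value)); err != nil {
return 0, err
}
return indexStatus(status), nil
}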
package common
import (
"bytes"
"encoding/binary"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/keychain"
)
// CircuitKey is used by a channel to uniquely identify the HTLCs it receives
// from the switch, and is used to purge our in-memory state of HTLCs that have
// already been processed by a link. Two lists of CircuitKeys are included in
// each CommitDiff to allow a link to determine which in-memory htlcs directed
// the opening and closing of circuits in the switch's circuit map.
type CircuitKey struct {
// ChanID is the short chanid indicating the HTLC's origin.
//
// NOTE: It is fine for this value to be blank, as this indicates a
// locally-sourced payment.
ChanID lnwire.ShortChannelID
// HtlcID is the unique htlc index predominantly assigned by links,
// though it can also be assigned by the switch in the case of
// locally-sourced payments.
HtlcID uint64
}
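// circuitKeySetSketch is a hedged illustration and not part of the original
// file; the helper name is an assumption. CircuitKey is comparable, so a
// link can key its in-memory HTLC state directly on (ChanID, HtlcID) pairs,
// which is how the CommitDiff lists below are consumed after a restart.
func circuitKeySetSketch(keys []CircuitKey) map[CircuitKey]struct{} {
set := make(map[CircuitKey]struct{}, len(keys))
for _, k := range keys {
set[k] = struct{}{}
}
return set
}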
// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are
// contained within ChannelDeltas which encode the current state of the
// commitment between state updates.
//
// TODO(roasbeef): save space by using smaller ints at tail end?
type HTLC struct {
// Signature is the signature for the second level covenant transaction
// for this HTLC. The second level transaction is a timeout tx in the
// case that this is an outgoing HTLC, and a success tx in the case
// that this is an incoming HTLC.
//
// TODO(roasbeef): make [64]byte instead?
Signature []byte
// RHash is the payment hash of the HTLC.
RHash [32]byte
// Amt is the amount of milli-satoshis this HTLC escrows.
Amt lnwire.MilliSatoshi
// RefundTimeout is the absolute timeout on the HTLC that the sender
// must wait before reclaiming the funds in limbo.
RefundTimeout uint32
// OutputIndex is the output index for this particular HTLC output
// within the commitment transaction.
OutputIndex int32
// Incoming denotes whether we're the receiver or the sender of this
// HTLC.
Incoming bool
// OnionBlob is an opaque blob which is used to complete multi-hop
// routing.
OnionBlob []byte
// HtlcIndex is the HTLC counter index of this active, outstanding
// HTLC. This differs from the LogIndex, as the HtlcIndex is only
// incremented for each offered HTLC, while the LogIndex is
// incremented for each update (includes settle+fail).
HtlcIndex uint64
// LogIndex is the cumulative log index of this HTLC. This differs
// from the HtlcIndex as this will be incremented for each new log
// update added.
LogIndex uint64
}
// ChannelCommitment is a snapshot of the commitment state at a particular
// point in the commitment chain. With each state transition, a snapshot of the
// current state along with all non-settled HTLCs are recorded. These snapshots
// detail the state of the _remote_ party's commitment at a particular state
// number. For ourselves (the local node) we ONLY store our most recent
// (unrevoked) state for safety purposes.
type ChannelCommitment struct {
// CommitHeight is the update number that this ChannelDelta represents,
// i.e. the total number of commitment updates up to this point. This
// can be viewed as sort of a "commitment height" as this number is
// monotonically increasing.
CommitHeight uint64
// LocalLogIndex is the cumulative log index of the local node at
// this point in the commitment chain. This value will be incremented
// for each _update_ added to the local update log.
LocalLogIndex uint64
// LocalHtlcIndex is the current local running HTLC index. This value
// will be incremented for each outgoing HTLC the local node offers.
LocalHtlcIndex uint64
// RemoteLogIndex is the cumulative log index of the remote node
// at this point in the commitment chain. This value will be
// incremented for each _update_ added to the remote update log.
RemoteLogIndex uint64
// RemoteHtlcIndex is the current remote running HTLC index. This value
// will be incremented for each outgoing HTLC the remote node offers.
RemoteHtlcIndex uint64
// LocalBalance is the current available settled balance within the
// channel directly spendable by us.
//
// NOTE: This is the balance *after* subtracting any commitment fee,
// AND anchor output values.
LocalBalance lnwire.MilliSatoshi
// RemoteBalance is the current available settled balance within the
// channel directly spendable by the remote node.
//
// NOTE: This is the balance *after* subtracting any commitment fee,
// AND anchor output values.
RemoteBalance lnwire.MilliSatoshi
// CommitFee is the amount calculated to be paid in fees for the
// current set of commitment transactions. The fee amount is persisted
// with the channel in order to allow the fee amount to be removed and
// recalculated with each channel state update, including updates that
// happen after a system restart.
CommitFee btcutil.Amount
// FeePerKw is the min satoshis/kilo-weight that should be paid within
// the commitment transaction for the entire duration of the channel's
// lifetime. This field may be updated during normal operation of the
// channel as on-chain conditions change.
//
// TODO(halseth): make this SatPerKWeight. Cannot be done atm because
// this will cause the import cycle lnwallet<->channeldb. Fee
// estimation stuff should be in its own package.
FeePerKw btcutil.Amount
// CommitTx is the latest version of the commitment state, broadcastable
// by us.
CommitTx *wire.MsgTx
// CommitSig is one half of the signature required to fully complete
// the script for the commitment transaction above. This is the
// signature signed by the remote party for our version of the
// commitment transactions.
CommitSig []byte
// Htlcs is the set of HTLC's that are pending at this particular
// commitment height.
Htlcs []HTLC
// TODO(roasbeef): pending commit pointer?
// * lets just walk through
}
// LogUpdate represents a pending update to the remote commitment chain. The
// log update may be an add, fail, or settle entry. We maintain this data in
// order to be able to properly retransmit our proposed
// state if necessary.
type LogUpdate struct {
// LogIndex is the log index of this proposed commitment update entry.
LogIndex uint64
// UpdateMsg is the update message that was included within our
// local update log. The LogIndex value denotes the log index of this
// update which will be used when restoring our local update log if
// we're left with a dangling update on restart.
UpdateMsg lnwire.Message
}
// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID
// is assumed to be that of the packager.
type AddRef struct {
// Height is the remote commitment height that locked in the Add.
Height uint64
// Index is the index of the Add within the fwd pkg's Adds.
//
// NOTE: This index is static over the lifetime of a forwarding package.
Index uint16
}
// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A
// channel does not remove its own Settle/Fail htlcs, so the source is provided
// to locate a db bucket belonging to another channel.
type SettleFailRef struct {
// Source identifies the outgoing link that locked in the settle or
// fail. This is then used by the *incoming* link to find the settle
// fail in another link's forwarding packages.
Source lnwire.ShortChannelID
// Height is the remote commitment height that locked in this
// Settle/Fail.
Height uint64
// Index is the index of the Settle/Fail within the fwd pkg's SettleFails.
//
// NOTE: This index is static over the lifetime of a forwarding package.
Index uint16
}
// CommitDiff represents the delta needed to apply the state transition between
// two subsequent commitment states. Given state N and state N+1, one is able
// to apply the set of messages contained within the CommitDiff to N to arrive
// at state N+1. Each time a new commitment is extended, we'll write a new
// commitment (along with the full commitment state) to disk so we can
// re-transmit the state in the case of a connection loss or message drop.
type CommitDiff struct {
// ChannelCommitment is the full commitment state that one would arrive
// at by applying the set of messages contained in LogUpdates to
// the prior accepted commitment.
Commitment ChannelCommitment
// LogUpdates is the set of messages sent prior to the commitment state
// transition in question. Upon reconnection, if we detect that they
// don't have the commitment, then we re-send this along with the
// proper signature.
LogUpdates []LogUpdate
// CommitSig is the exact CommitSig message that should be sent after
// the set of LogUpdates above has been retransmitted. The signatures
// within this message should properly cover the new commitment state
// and also the HTLC's within the new commitment state.
CommitSig *lnwire.CommitSig
// OpenedCircuitKeys is a set of unique identifiers for any downstream
// Add packets included in this commitment txn. After a restart, this
// set of htlcs is acked from the link's incoming mailbox to ensure
// there isn't an attempt to re-add them to this commitment txn.
OpenedCircuitKeys []CircuitKey
// ClosedCircuitKeys records the unique identifiers for any settle/fail
// packets that were resolved by this commitment txn. After a restart,
// this is used to ensure those circuits are removed from the circuit
// map, and the downstream packets in the link's mailbox are removed.
ClosedCircuitKeys []CircuitKey
// AddAcks specifies the locations (commit height, pkg index) of any
// Adds that were failed/settled in this commit diff. This will ack
// entries in *this* channel's forwarding packages.
//
// NOTE: This value is not serialized, it is used to atomically mark the
// resolution of adds, such that they will not be reprocessed after a
// restart.
AddAcks []AddRef
// SettleFailAcks specifies the locations (chan id, commit height, pkg
// index) of any Settles or Fails that were locked into this commit
// diff, and originate from *another* channel, i.e. the outgoing link.
//
// NOTE: This value is not serialized, it is used to atomically ack
// settles and fails from the forwarding packages of other channels,
// such that they will not be reforwarded internally after a restart.
SettleFailAcks []SettleFailRef
}
// NetworkResult is the raw result received from the network after a payment
// attempt has been made. Since the switch doesn't always have the necessary
// data to decode the raw message, we store it together with some
// metadata, and decode it when the router queries for the final result.
type NetworkResult struct {
// Msg is the received result. This should be of type UpdateFulfillHTLC
// or UpdateFailHTLC.
Msg lnwire.Message
// Unencrypted indicates whether the failure encoded in the message is
// unencrypted, and hence doesn't need to be decrypted.
Unencrypted bool
// IsResolution indicates whether this is a resolution message, in
// which the failure reason might not be included.
IsResolution bool
}
// ClosureType is an enum like structure that details exactly _how_ a channel
// was closed. Five closure types are currently possible: none, cooperative,
// local force close, remote force close, and (remote) breach.
type ClosureType uint8
// ChannelConstraints represents a set of constraints meant to allow a node to
// limit their exposure, enact flow control and ensure that all HTLCs are
// economically relevant. This struct will be mirrored for both sides of the
// channel, as each side will enforce various constraints that MUST be adhered
// to for the lifetime of the channel. The parameters for each of these
// constraints are static for the duration of the channel, meaning the channel
// must be torn down for them to change.
type ChannelConstraints struct {
// DustLimit is the threshold (in satoshis) below which any outputs
// should be trimmed. When an output is trimmed, it isn't materialized
// as an actual output, but is instead burned to miner's fees.
DustLimit btcutil.Amount
// ChanReserve is an absolute reservation on the channel for the
// owner of this set of constraints. This means that the current
// settled balance for this node CANNOT dip below the reservation
// amount. This acts as a defense against costless attacks when
// either side no longer has any skin in the game.
ChanReserve btcutil.Amount
// MaxPendingAmount is the maximum pending HTLC value that the
// owner of these constraints can offer the remote node at a
// particular time.
MaxPendingAmount lnwire.MilliSatoshi
// MinHTLC is the minimum HTLC value that the owner of these
// constraints can offer the remote node. If any HTLCs below this
// amount are offered, then the HTLC will be rejected. This, in
// tandem with the dust limit allows a node to regulate the
// smallest HTLC that it deems economically relevant.
MinHTLC lnwire.MilliSatoshi
// MaxAcceptedHtlcs is the maximum number of HTLCs that the owner of
// this set of constraints can offer the remote node. This allows each
// node to limit their overall exposure to HTLCs that may need to be
// acted upon in the case of a unilateral channel closure or a contract
// breach.
MaxAcceptedHtlcs uint16
// CsvDelay is the relative time lock delay expressed in blocks. Any
// settled outputs that pay to the owner of this channel configuration
// MUST ensure that the delay branch uses this value as the relative
// time lock. Similarly, any HTLC's offered by this node should use
// this value as well.
CsvDelay uint16
}
// ChannelConfig is a struct that houses the various configuration options for
// channels. Each side maintains an instance of this configuration as it
// governs: how the funding and commitment transactions are to be created, the
// nature of HTLC's allotted, the keys to be used for delivery, and relative
// time lock parameters.
type ChannelConfig struct {
// ChannelConstraints is the set of constraints that must be upheld for
// the duration of the channel for the owner of this channel
// configuration. Constraints govern a number of flow control related
// parameters, also including the smallest HTLC that will be accepted
// by a participant.
ChannelConstraints
// MultiSigKey is the key to be used within the 2-of-2 output script
// for the owner of this channel config.
MultiSigKey keychain.KeyDescriptor
// RevocationBasePoint is the base public key to be used when deriving
// revocation keys for the remote node's commitment transaction. This
// will be combined along with a per commitment secret to derive a
// unique revocation key for each state.
RevocationBasePoint keychain.KeyDescriptor
// PaymentBasePoint is the base public key to be used when deriving
// the key used within the non-delayed pay-to-self output on the
// commitment transaction for a node. This will be combined with a
// tweak derived from the per-commitment point to ensure unique keys
// for each commitment transaction.
PaymentBasePoint keychain.KeyDescriptor
// DelayBasePoint is the base public key to be used when deriving the
// key used within the delayed pay-to-self output on the commitment
// transaction for a node. This will be combined with a tweak derived
// from the per-commitment point to ensure unique keys for each
// commitment transaction.
DelayBasePoint keychain.KeyDescriptor
// HtlcBasePoint is the base public key to be used when deriving the
// local HTLC key. The derived key (combined with the tweak derived
// from the per-commitment point) is used within the "to self" clause
// within any HTLC output scripts.
HtlcBasePoint keychain.KeyDescriptor
}
// ChannelCloseSummary contains the final state of a channel at the point it
// was closed. Once a channel is closed, all the information pertaining to that
// channel within the openChannelBucket is deleted, and a compact summary is
// put in place instead.
type ChannelCloseSummary struct {
// ChanPoint is the outpoint for this channel's funding transaction,
// and is used as a unique identifier for the channel.
ChanPoint wire.OutPoint
// ShortChanID encodes the exact location in the chain in which the
// channel was initially confirmed. This includes: the block height,
// transaction index, and the output within the target transaction.
ShortChanID lnwire.ShortChannelID
// ChainHash is the hash of the genesis block that this channel resides
// within.
ChainHash chainhash.Hash
// ClosingTXID is the txid of the transaction which ultimately closed
// this channel.
ClosingTXID chainhash.Hash
// RemotePub is the public key of the remote peer that we formerly had
// a channel with.
RemotePub *btcec.PublicKey
// Capacity was the total capacity of the channel.
Capacity btcutil.Amount
// CloseHeight is the height at which the funding transaction was
// spent.
CloseHeight uint32
// SettledBalance is our total settled balance at the time of
// channel closure. This _does not_ include the sum of any outputs that
// have been time-locked as a result of the unilateral channel closure.
SettledBalance btcutil.Amount
// TimeLockedBalance is the sum of all the time-locked outputs at the
// time of channel closure. If we triggered the force closure of this
// channel, then this value will be non-zero if our settled output is
// above the dust limit. If we were on the receiving side of a channel
// force closure, then this value will be non-zero if we had any
// outstanding outgoing HTLC's at the time of channel closure.
TimeLockedBalance btcutil.Amount
// CloseType details exactly _how_ the channel was closed. Five closure
// types are possible: cooperative, local force, remote force, breach
// and funding canceled.
CloseType ClosureType
// IsPending indicates whether this channel is in the 'pending close'
// state, which means the channel closing transaction has been
// confirmed, but not yet fully resolved. In the case of a channel
// that has been cooperatively closed, it will go straight into the
// fully resolved state as soon as the closing transaction has been
// confirmed. However, for channels that have been force closed, they'll
// stay marked as "pending" until _all_ the pending funds have been
// swept.
IsPending bool
// RemoteCurrentRevocation is the current revocation for their
// commitment transaction. However, since this is the derived public key,
// we don't yet have the private key so we aren't yet able to verify
// that it's actually in the hash chain.
RemoteCurrentRevocation *btcec.PublicKey
// RemoteNextRevocation is the revocation key to be used for the *next*
// commitment transaction we create for the local node. Within the
// specification, this value is referred to as the
// per-commitment-point.
RemoteNextRevocation *btcec.PublicKey
// LocalChanConfig is the channel configuration for the local node.
LocalChanConfig ChannelConfig
// LastChanSyncMsg is the ChannelReestablish message for this channel
// for the state at the point where it was closed.
LastChanSyncMsg *lnwire.ChannelReestablish
}
// FwdState is an enum used to describe the lifecycle of a FwdPkg.
type FwdState byte
const (
// FwdStateLockedIn is the starting state for all forwarding packages.
// Packages in this state have not yet committed to the exact set of
// Adds to forward to the switch.
FwdStateLockedIn FwdState = iota
// FwdStateProcessed marks the state in which all Adds have been
// locally processed and the forwarding decision to the switch has been
// persisted.
FwdStateProcessed
// FwdStateCompleted signals that all Adds have been acked, and that all
// settles and fails have been delivered to their sources. Packages in
// this state can be removed permanently.
FwdStateCompleted
)
// PkgFilter is used to compactly represent a particular subset of the Adds in a
// forwarding package. Each filter is represented as a simple, statically-sized
// bitvector, where the elements are intended to be the indices of the Adds as
// they are written in the FwdPkg.
type PkgFilter struct {
count uint16
filter []byte
}
// NewPkgFilter initializes an empty PkgFilter supporting `count` elements.
func NewPkgFilter(count uint16) *PkgFilter {
// We add 7 to ensure that the integer division yields properly rounded
// values.
filterLen := (count + 7) / 8
return &PkgFilter{
count: count,
filter: make([]byte, filterLen),
}
}
// Count returns the number of elements represented by this PkgFilter.
func (f *PkgFilter) Count() uint16 {
return f.count
}
// Set marks the `i`-th element as included by this filter.
// NOTE: It is assumed that i is always less than count.
func (f *PkgFilter) Set(i uint16) {
byt := i / 8
bit := i % 8
// Set the i-th bit in the filter.
// TODO(conner): ignore if > count to prevent panic?
f.filter[byt] |= byte(1 << (7 - bit))
}
// Contains queries the filter for membership of index `i`.
// NOTE: It is assumed that i is always less than count.
func (f *PkgFilter) Contains(i uint16) bool {
byt := i / 8
bit := i % 8
// Read the i-th bit in the filter.
// TODO(conner): ignore if > count to prevent panic?
return f.filter[byt]&(1<<(7-bit)) != 0
}
// Equal checks two PkgFilters for equality.
func (f *PkgFilter) Equal(f2 *PkgFilter) bool {
if f == f2 {
return true
}
if f.count != f2.count {
return false
}
return bytes.Equal(f.filter, f2.filter)
}
// IsFull returns true if every element in the filter has been Set, and false
// otherwise.
func (f *PkgFilter) IsFull() bool {
// Batch validate bytes that are fully used.
for i := uint16(0); i < f.count/8; i++ {
if f.filter[i] != 0xFF {
return false
}
}
// If the count is not a multiple of 8, check that the filter contains
// all remaining bits.
rem := f.count % 8
for idx := f.count - rem; idx < f.count; idx++ {
if !f.Contains(idx) {
return false
}
}
return true
}
// Size returns the number of bytes produced when the PkgFilter is serialized.
func (f *PkgFilter) Size() uint16 {
// 2 bytes for uint16 `count`, then round up number of bytes required to
// represent `count` bits.
return 2 + (f.count+7)/8
}
// Encode writes the filter to the provided io.Writer.
func (f *PkgFilter) Encode(w io.Writer) error {
if err := binary.Write(w, binary.BigEndian, f.count); err != nil {
return err
}
_, err := w.Write(f.filter)
return err
}
// Decode reads the filter from the provided io.Reader.
func (f *PkgFilter) Decode(r io.Reader) error {
if err := binary.Read(r, binary.BigEndian, &f.count); err != nil {
return err
}
f.filter = make([]byte, f.Size()-2)
_, err := io.ReadFull(r, f.filter)
return err
}
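// The following function is an illustrative sketch (not part of the original
// code) showing how a PkgFilter is typically used: one bit per Add in a
// forwarding package, set as each Add is processed, with IsFull gating when
// the package may be removed. It relies only on the PkgFilter methods defined
// above.
func pkgFilterUsageSketch() bool {
	// Track 10 Adds; mark indices 0 and 3 as processed.
	f := NewPkgFilter(10)
	f.Set(0)
	f.Set(3)

	// Size is 2 bytes for the count plus ceil(10/8) = 2 filter bytes.
	_ = f.Size() // == 4

	// The filter is only full once all 10 indices have been Set.
	return f.Contains(3) && !f.IsFull()
}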
// FwdPkg records all adds, settles, and fails that were locked in as a result
// of the remote peer sending us a revocation. Each package is identified by
// the short chanid and remote commitment height corresponding to the revocation
// that locked in the HTLCs. For everything except a locally initiated payment,
// settles and fails in a forwarding package must have a corresponding Add in
// another package, and can be removed individually once the source link has
// received the fail/settle.
//
// Adds cannot be removed, as we need to present the same batch of Adds to
// properly handle replay protection. Instead, we use a PkgFilter to mark that
// we have finished processing a particular Add. A FwdPkg should only be deleted
// after the AckFilter is full and all settles and fails have been persistently
// removed.
type FwdPkg struct {
// Source identifies the channel that wrote this forwarding package.
Source lnwire.ShortChannelID
// Height is the height of the remote commitment chain that locked in
// this forwarding package.
Height uint64
// State signals the persistent condition of the package and directs how
// to reprocess the package in the event of failures.
State FwdState
// Adds contains all add messages which need to be processed and
// forwarded to the switch. Adds does not change over the life of a
// forwarding package.
Adds []LogUpdate
// FwdFilter is a filter containing the indices of all Adds that were
// forwarded to the switch.
FwdFilter *PkgFilter
// AckFilter is a filter containing the indices of all Adds for which
// the source has received a settle or fail that is reflected in the
// next commitment txn. A package should not be removed until IsFull()
// returns true.
AckFilter *PkgFilter
// SettleFails contains all settle and fail messages that should be
// forwarded to the switch.
SettleFails []LogUpdate
// SettleFailFilter is a filter containing the indices of all Settle or
// Fails originating in this package that have been received and locked
// into the incoming link's commitment state.
SettleFailFilter *PkgFilter
}
// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This
// should be used to create a package at the time we receive a revocation.
func NewFwdPkg(source lnwire.ShortChannelID, height uint64,
addUpdates, settleFailUpdates []LogUpdate) *FwdPkg {
nAddUpdates := uint16(len(addUpdates))
nSettleFailUpdates := uint16(len(settleFailUpdates))
return &FwdPkg{
Source: source,
Height: height,
State: FwdStateLockedIn,
Adds: addUpdates,
FwdFilter: NewPkgFilter(nAddUpdates),
AckFilter: NewPkgFilter(nAddUpdates),
SettleFails: settleFailUpdates,
SettleFailFilter: NewPkgFilter(nSettleFailUpdates),
}
}
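// The following function is an illustrative sketch (not part of the original
// code) of the forwarding-package lifecycle described above: a package is
// created at revocation time in FwdStateLockedIn, its FwdFilter records which
// Adds were handed to the switch, and the package becomes removable only once
// its filters are full. The adds/settleFails arguments are assumed to come
// from the remote revocation being processed.
func fwdPkgLifecycleSketch(source lnwire.ShortChannelID, height uint64,
	adds, settleFails []LogUpdate) bool {

	pkg := NewFwdPkg(source, height, adds, settleFails)

	// Suppose every Add passed validation and was forwarded.
	for i := uint16(0); i < uint16(len(adds)); i++ {
		pkg.FwdFilter.Set(i)
	}

	// The package can only be deleted once all Adds have been acked and
	// all settles/fails have been delivered back to their sources.
	return pkg.AckFilter.IsFull() && pkg.SettleFailFilter.IsFull()
}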
package current
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/channeldb/migration21/common"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/shachain"
)
var (
// Big endian is the preferred byte order, due to cursor scans over
// integer keys iterating in order.
byteOrder = binary.BigEndian
)
// writeOutpoint writes an outpoint to the passed writer using the minimal
// amount of bytes possible.
func writeOutpoint(w io.Writer, o *wire.OutPoint) error {
if _, err := w.Write(o.Hash[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, o.Index); err != nil {
return err
}
return nil
}
// readOutpoint reads an outpoint from the passed reader that was previously
// written using the writeOutpoint function.
func readOutpoint(r io.Reader, o *wire.OutPoint) error {
if _, err := io.ReadFull(r, o.Hash[:]); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &o.Index); err != nil {
return err
}
return nil
}
// UnknownElementType is an error returned when the codec is unable to encode or
// decode a particular type.
type UnknownElementType struct {
method string
element interface{}
}
// NewUnknownElementType creates a new UnknownElementType error from the passed
// method name and element.
func NewUnknownElementType(method string, el interface{}) UnknownElementType {
return UnknownElementType{method: method, element: el}
}
// Error returns the name of the method that encountered the error, as well as
// the type that was unsupported.
func (e UnknownElementType) Error() string {
return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element)
}
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for storage on disk. The passed
// io.Writer should be backed by an appropriately sized byte slice, or be able
// to dynamically expand to accommodate additional data.
func WriteElement(w io.Writer, element interface{}) error {
switch e := element.(type) {
case keychain.KeyDescriptor:
if err := binary.Write(w, byteOrder, e.Family); err != nil {
return err
}
if err := binary.Write(w, byteOrder, e.Index); err != nil {
return err
}
if e.PubKey != nil {
if err := binary.Write(w, byteOrder, true); err != nil {
return fmt.Errorf("error writing serialized "+
"element: %w", err)
}
return WriteElement(w, e.PubKey)
}
return binary.Write(w, byteOrder, false)
case chainhash.Hash:
if _, err := w.Write(e[:]); err != nil {
return err
}
case wire.OutPoint:
return writeOutpoint(w, &e)
case lnwire.ShortChannelID:
if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil {
return err
}
case lnwire.ChannelID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case int64, uint64:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case int32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint16:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint8:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case bool:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case btcutil.Amount:
if err := binary.Write(w, byteOrder, uint64(e)); err != nil {
return err
}
case lnwire.MilliSatoshi:
if err := binary.Write(w, byteOrder, uint64(e)); err != nil {
return err
}
case *btcec.PrivateKey:
b := e.Serialize()
if _, err := w.Write(b); err != nil {
return err
}
case *btcec.PublicKey:
b := e.SerializeCompressed()
if _, err := w.Write(b); err != nil {
return err
}
case shachain.Producer:
return e.Encode(w)
case shachain.Store:
return e.Encode(w)
case *wire.MsgTx:
return e.Serialize(w)
case [32]byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case []byte:
if err := wire.WriteVarBytes(w, 0, e); err != nil {
return err
}
case lnwire.Message:
var msgBuf bytes.Buffer
if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil {
return err
}
msgLen := uint16(len(msgBuf.Bytes()))
if err := WriteElements(w, msgLen); err != nil {
return err
}
if _, err := w.Write(msgBuf.Bytes()); err != nil {
return err
}
case common.ClosureType:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case lnwire.FundingFlag:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
default:
return UnknownElementType{"WriteElement", e}
}
return nil
}
// WriteElements writes each element in the elements slice to the passed
// io.Writer using WriteElement.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
err := WriteElement(w, element)
if err != nil {
return err
}
}
return nil
}
// ReadElement is a one-stop utility function to deserialize any data
// structure encoded using the serialization format of the database.
func ReadElement(r io.Reader, element interface{}) error {
switch e := element.(type) {
case *keychain.KeyDescriptor:
if err := binary.Read(r, byteOrder, &e.Family); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &e.Index); err != nil {
return err
}
var hasPubKey bool
if err := binary.Read(r, byteOrder, &hasPubKey); err != nil {
return err
}
if hasPubKey {
return ReadElement(r, &e.PubKey)
}
case *chainhash.Hash:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *wire.OutPoint:
return readOutpoint(r, e)
case *lnwire.ShortChannelID:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = lnwire.NewShortChanIDFromInt(a)
case *lnwire.ChannelID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *int64, *uint64:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *int32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint16:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint8:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *bool:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *btcutil.Amount:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = btcutil.Amount(a)
case *lnwire.MilliSatoshi:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = lnwire.MilliSatoshi(a)
case **btcec.PrivateKey:
var b [btcec.PrivKeyBytesLen]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
priv, _ := btcec.PrivKeyFromBytes(b[:])
*e = priv
case **btcec.PublicKey:
var b [btcec.PubKeyBytesLenCompressed]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
pubKey, err := btcec.ParsePubKey(b[:])
if err != nil {
return err
}
*e = pubKey
case *shachain.Producer:
var root [32]byte
if _, err := io.ReadFull(r, root[:]); err != nil {
return err
}
// TODO(roasbeef): remove
producer, err := shachain.NewRevocationProducerFromBytes(root[:])
if err != nil {
return err
}
*e = producer
case *shachain.Store:
store, err := shachain.NewRevocationStoreFromBytes(r)
if err != nil {
return err
}
*e = store
case **wire.MsgTx:
tx := wire.NewMsgTx(2)
if err := tx.Deserialize(r); err != nil {
return err
}
*e = tx
case *[32]byte:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *[]byte:
bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte")
if err != nil {
return err
}
*e = bytes
case *lnwire.Message:
var msgLen uint16
if err := ReadElement(r, &msgLen); err != nil {
return err
}
msgReader := io.LimitReader(r, int64(msgLen))
msg, err := lnwire.ReadMessage(msgReader, 0)
if err != nil {
return err
}
*e = msg
case *common.ClosureType:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *lnwire.FundingFlag:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
default:
return UnknownElementType{"ReadElement", e}
}
return nil
}
// ReadElements deserializes a variable number of elements from the passed
// io.Reader, with each element being deserialized according to the ReadElement
// function.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
err := ReadElement(r, element)
if err != nil {
return err
}
}
return nil
}
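// The following function is an illustrative sketch (not part of the original
// code) showing a round trip through the codec above: elements written with
// WriteElements are read back, in the same order, with ReadElements.
func codecRoundTripSketch() error {
	var b bytes.Buffer

	// Write a few supported element types in big endian form.
	amt := btcutil.Amount(100_000)
	if err := WriteElements(&b, uint64(7), amt, true); err != nil {
		return err
	}

	// Read them back into pointers of the matching types, in order.
	var (
		height uint64
		amtOut btcutil.Amount
		flag   bool
	)
	return ReadElements(&b, &height, &amtOut, &flag)
}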
package current
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/channeldb/migration21/common"
"github.com/lightningnetwork/lnd/kvdb"
)
func serializeChanCommit(w io.Writer, c *common.ChannelCommitment) error { // nolint: dupl
if err := WriteElements(w,
c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex,
c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance,
c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx,
c.CommitSig,
); err != nil {
return err
}
return serializeHtlcs(w, c.Htlcs...)
}
func SerializeLogUpdates(w io.Writer, logUpdates []common.LogUpdate) error { // nolint: dupl
numUpdates := uint16(len(logUpdates))
if err := binary.Write(w, byteOrder, numUpdates); err != nil {
return err
}
for _, diff := range logUpdates {
err := WriteElements(w, diff.LogIndex, diff.UpdateMsg)
if err != nil {
return err
}
}
return nil
}
func serializeHtlcs(b io.Writer, htlcs ...common.HTLC) error { // nolint: dupl
numHtlcs := uint16(len(htlcs))
if err := WriteElement(b, numHtlcs); err != nil {
return err
}
for _, htlc := range htlcs {
if err := WriteElements(b,
htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout,
htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob,
htlc.HtlcIndex, htlc.LogIndex,
); err != nil {
return err
}
}
return nil
}
func SerializeCommitDiff(w io.Writer, diff *common.CommitDiff) error { // nolint: dupl
if err := serializeChanCommit(w, &diff.Commitment); err != nil {
return err
}
if err := WriteElements(w, diff.CommitSig); err != nil {
return err
}
if err := SerializeLogUpdates(w, diff.LogUpdates); err != nil {
return err
}
numOpenRefs := uint16(len(diff.OpenedCircuitKeys))
if err := binary.Write(w, byteOrder, numOpenRefs); err != nil {
return err
}
for _, openRef := range diff.OpenedCircuitKeys {
err := WriteElements(w, openRef.ChanID, openRef.HtlcID)
if err != nil {
return err
}
}
numClosedRefs := uint16(len(diff.ClosedCircuitKeys))
if err := binary.Write(w, byteOrder, numClosedRefs); err != nil {
return err
}
for _, closedRef := range diff.ClosedCircuitKeys {
err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID)
if err != nil {
return err
}
}
return nil
}
func deserializeHtlcs(r io.Reader) ([]common.HTLC, error) { // nolint: dupl
var numHtlcs uint16
if err := ReadElement(r, &numHtlcs); err != nil {
return nil, err
}
var htlcs []common.HTLC
if numHtlcs == 0 {
return htlcs, nil
}
htlcs = make([]common.HTLC, numHtlcs)
for i := uint16(0); i < numHtlcs; i++ {
if err := ReadElements(r,
&htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt,
&htlcs[i].RefundTimeout, &htlcs[i].OutputIndex,
&htlcs[i].Incoming, &htlcs[i].OnionBlob,
&htlcs[i].HtlcIndex, &htlcs[i].LogIndex,
); err != nil {
return htlcs, err
}
}
return htlcs, nil
}
func deserializeChanCommit(r io.Reader) (common.ChannelCommitment, error) { // nolint: dupl
var c common.ChannelCommitment
err := ReadElements(r,
&c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex,
&c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance,
&c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig,
)
if err != nil {
return c, err
}
c.Htlcs, err = deserializeHtlcs(r)
if err != nil {
return c, err
}
return c, nil
}
func DeserializeLogUpdates(r io.Reader) ([]common.LogUpdate, error) { // nolint: dupl
var numUpdates uint16
if err := binary.Read(r, byteOrder, &numUpdates); err != nil {
return nil, err
}
logUpdates := make([]common.LogUpdate, numUpdates)
for i := 0; i < int(numUpdates); i++ {
err := ReadElements(r,
&logUpdates[i].LogIndex, &logUpdates[i].UpdateMsg,
)
if err != nil {
return nil, err
}
}
return logUpdates, nil
}
func DeserializeCommitDiff(r io.Reader) (*common.CommitDiff, error) { // nolint: dupl
var (
d common.CommitDiff
err error
)
d.Commitment, err = deserializeChanCommit(r)
if err != nil {
return nil, err
}
var msg lnwire.Message
if err := ReadElements(r, &msg); err != nil {
return nil, err
}
commitSig, ok := msg.(*lnwire.CommitSig)
if !ok {
return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+
"read: %T", msg)
}
d.CommitSig = commitSig
d.LogUpdates, err = DeserializeLogUpdates(r)
if err != nil {
return nil, err
}
var numOpenRefs uint16
if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil {
return nil, err
}
d.OpenedCircuitKeys = make([]common.CircuitKey, numOpenRefs)
for i := 0; i < int(numOpenRefs); i++ {
err := ReadElements(r,
&d.OpenedCircuitKeys[i].ChanID,
&d.OpenedCircuitKeys[i].HtlcID)
if err != nil {
return nil, err
}
}
var numClosedRefs uint16
if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil {
return nil, err
}
d.ClosedCircuitKeys = make([]common.CircuitKey, numClosedRefs)
for i := 0; i < int(numClosedRefs); i++ {
err := ReadElements(r,
&d.ClosedCircuitKeys[i].ChanID,
&d.ClosedCircuitKeys[i].HtlcID)
if err != nil {
return nil, err
}
}
return &d, nil
}
func SerializeNetworkResult(w io.Writer, n *common.NetworkResult) error { // nolint: dupl
return WriteElements(w, n.Msg, n.Unencrypted, n.IsResolution)
}
func DeserializeNetworkResult(r io.Reader) (*common.NetworkResult, error) { // nolint: dupl
n := &common.NetworkResult{}
if err := ReadElements(r,
&n.Msg, &n.Unencrypted, &n.IsResolution,
); err != nil {
return nil, err
}
return n, nil
}
func writeChanConfig(b io.Writer, c *common.ChannelConfig) error { // nolint: dupl
return WriteElements(b,
c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC,
c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey,
c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint,
c.HtlcBasePoint,
)
}
func SerializeChannelCloseSummary(w io.Writer, cs *common.ChannelCloseSummary) error { // nolint: dupl
err := WriteElements(w,
cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID,
cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance,
cs.TimeLockedBalance, cs.CloseType, cs.IsPending,
)
if err != nil {
return err
}
// If this is a channel close summary created before the addition of
// the new fields, then we can exit here.
if cs.RemoteCurrentRevocation == nil {
return WriteElements(w, false)
}
// If fields are present, write boolean to indicate this, and continue.
if err := WriteElements(w, true); err != nil {
return err
}
if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil {
return err
}
if err := writeChanConfig(w, &cs.LocalChanConfig); err != nil {
return err
}
// The RemoteNextRevocation field is optional, as it's possible for a
// channel to be closed before we learn of the next unrevoked
// revocation point for the remote party. Write a boolean indicating
// whether this field is present or not.
if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil {
return err
}
// Write the field, if present.
if cs.RemoteNextRevocation != nil {
if err = WriteElements(w, cs.RemoteNextRevocation); err != nil {
return err
}
}
// Write whether the channel sync message is present.
if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil {
return err
}
// Write the channel sync message, if present.
if cs.LastChanSyncMsg != nil {
if err := WriteElements(w, cs.LastChanSyncMsg); err != nil {
return err
}
}
return nil
}
func readChanConfig(b io.Reader, c *common.ChannelConfig) error {
return ReadElements(b,
&c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve,
&c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay,
&c.MultiSigKey, &c.RevocationBasePoint,
&c.PaymentBasePoint, &c.DelayBasePoint,
&c.HtlcBasePoint,
)
}
func DeserializeCloseChannelSummary(r io.Reader) (*common.ChannelCloseSummary, error) { // nolint: dupl
c := &common.ChannelCloseSummary{}
err := ReadElements(r,
&c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID,
&c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance,
&c.TimeLockedBalance, &c.CloseType, &c.IsPending,
)
if err != nil {
return nil, err
}
// We'll now check to see if the channel close summary was encoded with
// any of the additional optional fields.
var hasNewFields bool
err = ReadElements(r, &hasNewFields)
if err != nil {
return nil, err
}
// If fields are not present, we can return.
if !hasNewFields {
return c, nil
}
// Otherwise read the new fields.
if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil {
return nil, err
}
if err := readChanConfig(r, &c.LocalChanConfig); err != nil {
return nil, err
}
// Finally, we'll attempt to read the next unrevoked commitment point
// for the remote party. If we closed the channel before receiving a
// funding locked message then this might not be present. A boolean
// indicating whether the field is present will come first.
var hasRemoteNextRevocation bool
err = ReadElements(r, &hasRemoteNextRevocation)
if err != nil {
return nil, err
}
// If this field was written, read it.
if hasRemoteNextRevocation {
err = ReadElements(r, &c.RemoteNextRevocation)
if err != nil {
return nil, err
}
}
// Check if we have a channel sync message to read.
var hasChanSyncMsg bool
err = ReadElements(r, &hasChanSyncMsg)
if err == io.EOF {
return c, nil
} else if err != nil {
return nil, err
}
// If a chan sync message is present, read it.
if hasChanSyncMsg {
// We must pass in reference to a lnwire.Message for the codec
// to support it.
var msg lnwire.Message
if err := ReadElements(r, &msg); err != nil {
return nil, err
}
chanSync, ok := msg.(*lnwire.ChannelReestablish)
if !ok {
return nil, errors.New("unable cast db Message to " +
"ChannelReestablish")
}
c.LastChanSyncMsg = chanSync
}
return c, nil
}
// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding
// package has potentially been mangled.
var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted")
var (
// fwdPackagesKey is the root-level bucket to which all forwarding
// packages are written. This bucket is further subdivided based on the
// short channel ID of each channel.
fwdPackagesKey = []byte("fwd-packages")
// addBucketKey is the bucket to which all Add log updates are written.
addBucketKey = []byte("add-updates")
// failSettleBucketKey is the bucket to which all Settle/Fail log
// updates are written.
failSettleBucketKey = []byte("fail-settle-updates")
// fwdFilterKey is a key used to write the set of Adds that passed
// validation and are to be forwarded to the switch.
// NOTE: The presence of this key within a forwarding package indicates
// that the package has reached FwdStateProcessed.
fwdFilterKey = []byte("fwd-filter-key")
// ackFilterKey is a key used to access the PkgFilter indicating which
// Adds have received a Settle/Fail. This response may come from a
// number of sources, including: exitHop settle/fails, switch failures,
// chain arbiter interjections, as well as settle/fails from the
// next hop in the route.
ackFilterKey = []byte("ack-filter-key")
// settleFailFilterKey is a key used to access the PkgFilter indicating
// which Settles/Fails have been received and processed by the link
// that originally received the Add.
settleFailFilterKey = []byte("settle-fail-filter-key")
)
func makeLogKey(updateNum uint64) [8]byte {
var key [8]byte
byteOrder.PutUint64(key[:], updateNum)
return key
}
// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice.
func uint16Key(i uint16) []byte {
key := make([]byte, 2)
byteOrder.PutUint16(key, i)
return key
}
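// The following function is an illustrative sketch (not part of the original
// code) of why big endian keys are used here: encoded heights compare
// lexicographically in the same order as the integers themselves, so bucket
// cursors iterate forwarding packages in ascending height order.
func logKeyOrderingSketch() bool {
	lo := makeLogKey(255)
	hi := makeLogKey(256)

	// With big endian encoding, bytes.Compare agrees with integer order.
	return bytes.Compare(lo[:], hi[:]) < 0
}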
// ChannelPackager is used by a channel to manage the lifecycle of its forwarding
// packages. The packager is tied to a particular source channel ID, allowing it
// to create and edit its own packages. Each packager also has the ability to
// remove fail/settle htlcs that correspond to an add contained in one of
// source's packages.
type ChannelPackager struct {
source lnwire.ShortChannelID
}
// NewChannelPackager creates a new packager for a single channel.
func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager {
return &ChannelPackager{
source: source,
}
}
// AddFwdPkg writes a newly locked in forwarding package to disk.
func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *common.FwdPkg) error { // nolint: dupl
fwdPkgBkt, err := tx.CreateTopLevelBucket(fwdPackagesKey)
if err != nil {
return err
}
source := makeLogKey(fwdPkg.Source.ToUint64())
sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:])
if err != nil {
return err
}
heightKey := makeLogKey(fwdPkg.Height)
heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:])
if err != nil {
return err
}
// Write ADD updates we received at this commit height.
addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey)
if err != nil {
return err
}
// Write SETTLE/FAIL updates we received at this commit height.
failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey)
if err != nil {
return err
}
for i := range fwdPkg.Adds {
err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i])
if err != nil {
return err
}
}
// Persist the initialized pkg filter, which will be used to determine
// when we can remove this forwarding package from disk.
var ackFilterBuf bytes.Buffer
if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil {
return err
}
if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil {
return err
}
for i := range fwdPkg.SettleFails {
err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i])
if err != nil {
return err
}
}
var settleFailFilterBuf bytes.Buffer
err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf)
if err != nil {
return err
}
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
}
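// The following function is an illustrative sketch (not part of the original
// code) tying the pieces above together: at revocation time, a packager for
// the source channel persists a freshly locked-in forwarding package within
// an already-open read/write transaction. It assumes the common.NewFwdPkg
// constructor shown earlier in this migration.
func addFwdPkgSketch(tx kvdb.RwTx, source lnwire.ShortChannelID,
	height uint64, adds, settleFails []common.LogUpdate) error {

	packager := NewChannelPackager(source)
	fwdPkg := common.NewFwdPkg(source, height, adds, settleFails)

	return packager.AddFwdPkg(tx, fwdPkg)
}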
// putLogUpdate writes an htlc to the provided `bkt`, using `idx` as the key.
func putLogUpdate(bkt kvdb.RwBucket, idx uint16, htlc *common.LogUpdate) error {
var b bytes.Buffer
if err := serializeLogUpdate(&b, htlc); err != nil {
return err
}
return bkt.Put(uint16Key(idx), b.Bytes())
}
// LoadFwdPkgs scans the forwarding log for any packages that haven't been
// processed, and returns their deserialized log updates, one package per
// remote commitment height at which the updates were locked in.
func (p *ChannelPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*common.FwdPkg, error) {
return loadChannelFwdPkgs(tx, p.source)
}
// loadChannelFwdPkgs loads all forwarding packages owned by `source`.
func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*common.FwdPkg, error) { // nolint: dupl
fwdPkgBkt := tx.ReadBucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return nil, nil
}
sourceKey := makeLogKey(source.ToUint64())
sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:])
if sourceBkt == nil {
return nil, nil
}
var heights []uint64
if err := sourceBkt.ForEach(func(k, _ []byte) error {
if len(k) != 8 {
return ErrCorruptedFwdPkg
}
heights = append(heights, byteOrder.Uint64(k))
return nil
}); err != nil {
return nil, err
}
// Load the forwarding package for each retrieved height.
fwdPkgs := make([]*common.FwdPkg, 0, len(heights))
for _, height := range heights {
fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height)
if err != nil {
return nil, err
}
fwdPkgs = append(fwdPkgs, fwdPkg)
}
return fwdPkgs, nil
}
// loadFwdPkg reads the packager's fwd pkg at a given height, and determines the
// appropriate FwdState.
func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID,
height uint64) (*common.FwdPkg, error) {
sourceKey := makeLogKey(source.ToUint64())
sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:])
if sourceBkt == nil {
return nil, ErrCorruptedFwdPkg
}
heightKey := makeLogKey(height)
heightBkt := sourceBkt.NestedReadBucket(heightKey[:])
if heightBkt == nil {
return nil, ErrCorruptedFwdPkg
}
// Load ADDs from disk.
addBkt := heightBkt.NestedReadBucket(addBucketKey)
if addBkt == nil {
return nil, ErrCorruptedFwdPkg
}
adds, err := loadHtlcs(addBkt)
if err != nil {
return nil, err
}
// Load ack filter from disk.
ackFilterBytes := heightBkt.Get(ackFilterKey)
if ackFilterBytes == nil {
return nil, ErrCorruptedFwdPkg
}
ackFilterReader := bytes.NewReader(ackFilterBytes)
ackFilter := &common.PkgFilter{}
if err := ackFilter.Decode(ackFilterReader); err != nil {
return nil, err
}
// Load SETTLE/FAILs from disk.
failSettleBkt := heightBkt.NestedReadBucket(failSettleBucketKey)
if failSettleBkt == nil {
return nil, ErrCorruptedFwdPkg
}
failSettles, err := loadHtlcs(failSettleBkt)
if err != nil {
return nil, err
}
// Load settle fail filter from disk.
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
if settleFailFilterBytes == nil {
return nil, ErrCorruptedFwdPkg
}
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
settleFailFilter := &common.PkgFilter{}
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
return nil, err
}
// Initialize the fwding package, which always starts in the
// FwdStateLockedIn. We can determine what state the package was left in
// by examining constraints on the information loaded from disk.
fwdPkg := &common.FwdPkg{
Source: source,
State: common.FwdStateLockedIn,
Height: height,
Adds: adds,
AckFilter: ackFilter,
SettleFails: failSettles,
SettleFailFilter: settleFailFilter,
}
// Check to see if the fwd filter, i.e. the set of Adds exported to the
// switch, has been written to disk. If it hasn't, processing of this
// package was never started, or failed during the last attempt.
fwdFilterBytes := heightBkt.Get(fwdFilterKey)
if fwdFilterBytes == nil {
nAdds := uint16(len(adds))
fwdPkg.FwdFilter = common.NewPkgFilter(nAdds)
return fwdPkg, nil
}
fwdFilterReader := bytes.NewReader(fwdFilterBytes)
fwdPkg.FwdFilter = &common.PkgFilter{}
if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil {
return nil, err
}
// Otherwise, a full round of processing was completed, and we advance
// the package to FwdStateProcessed.
fwdPkg.State = common.FwdStateProcessed
// If every add, settle, and fail has been fully acknowledged, we can
// safely set the package's state to FwdStateCompleted, signalling that
// it can be garbage collected.
if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() {
fwdPkg.State = common.FwdStateCompleted
}
return fwdPkg, nil
}
// loadHtlcs retrieves all serialized htlcs in a bucket, returning
// them in order of the indexes they were written under.
func loadHtlcs(bkt kvdb.RBucket) ([]common.LogUpdate, error) {
var htlcs []common.LogUpdate
if err := bkt.ForEach(func(_, v []byte) error {
htlc, err := deserializeLogUpdate(bytes.NewReader(v))
if err != nil {
return err
}
htlcs = append(htlcs, *htlc)
return nil
}); err != nil {
return nil, err
}
return htlcs, nil
}
// serializeLogUpdate writes a log update to the provided io.Writer.
func serializeLogUpdate(w io.Writer, l *common.LogUpdate) error {
return WriteElements(w, l.LogIndex, l.UpdateMsg)
}
// deserializeLogUpdate reads a log update from the provided io.Reader.
func deserializeLogUpdate(r io.Reader) (*common.LogUpdate, error) {
l := &common.LogUpdate{}
if err := ReadElements(r, &l.LogIndex, &l.UpdateMsg); err != nil {
return nil, err
}
return l, nil
}
package legacy
import (
"encoding/binary"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/channeldb/migration21/common"
"github.com/lightningnetwork/lnd/keychain"
)
var (
// Big endian is the preferred byte order, due to cursor scans over
// integer keys iterating in order.
byteOrder = binary.BigEndian
)
// writeOutpoint writes an outpoint to the passed writer using the minimal
// amount of bytes possible.
func writeOutpoint(w io.Writer, o *wire.OutPoint) error {
if _, err := w.Write(o.Hash[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, o.Index); err != nil {
return err
}
return nil
}
// readOutpoint reads an outpoint from the passed reader that was previously
// written using the writeOutpoint function.
func readOutpoint(r io.Reader, o *wire.OutPoint) error {
if _, err := io.ReadFull(r, o.Hash[:]); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &o.Index); err != nil {
return err
}
return nil
}
// UnknownElementType is an error returned when the codec is unable to encode or
// decode a particular type.
type UnknownElementType struct {
method string
element interface{}
}
// NewUnknownElementType creates a new UnknownElementType error from the passed
// method name and element.
func NewUnknownElementType(method string, el interface{}) UnknownElementType {
return UnknownElementType{method: method, element: el}
}
// Error returns the name of the method that encountered the error, as well as
// the type that was unsupported.
func (e UnknownElementType) Error() string {
return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element)
}
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for storage on disk. The passed
// io.Writer should be backed by an appropriately sized byte slice, or be able
// to dynamically expand to accommodate additional data.
func WriteElement(w io.Writer, element interface{}) error {
switch e := element.(type) {
case keychain.KeyDescriptor:
if err := binary.Write(w, byteOrder, e.Family); err != nil {
return err
}
if err := binary.Write(w, byteOrder, e.Index); err != nil {
return err
}
if e.PubKey != nil {
if err := binary.Write(w, byteOrder, true); err != nil {
return fmt.Errorf("error writing serialized "+
"element: %w", err)
}
return WriteElement(w, e.PubKey)
}
return binary.Write(w, byteOrder, false)
case chainhash.Hash:
if _, err := w.Write(e[:]); err != nil {
return err
}
case common.ClosureType:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case wire.OutPoint:
return writeOutpoint(w, &e)
case lnwire.ShortChannelID:
if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil {
return err
}
case lnwire.ChannelID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case int64, uint64:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case int32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint16:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint8:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case bool:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case btcutil.Amount:
if err := binary.Write(w, byteOrder, uint64(e)); err != nil {
return err
}
case lnwire.MilliSatoshi:
if err := binary.Write(w, byteOrder, uint64(e)); err != nil {
return err
}
case *btcec.PublicKey:
b := e.SerializeCompressed()
if _, err := w.Write(b); err != nil {
return err
}
case *wire.MsgTx:
return e.Serialize(w)
case [32]byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case []byte:
if err := wire.WriteVarBytes(w, 0, e); err != nil {
return err
}
case lnwire.Message:
if _, err := lnwire.WriteMessage(w, e, 0); err != nil {
return err
}
case lnwire.FundingFlag:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
default:
return UnknownElementType{"WriteElement", e}
}
return nil
}
// WriteElements writes each element in the elements slice to the passed
// io.Writer using WriteElement.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
err := WriteElement(w, element)
if err != nil {
return err
}
}
return nil
}
// ReadElement is a one-stop utility function to deserialize any data
// structure encoded using the serialization format of the database.
func ReadElement(r io.Reader, element interface{}) error {
switch e := element.(type) {
case *chainhash.Hash:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *wire.OutPoint:
return readOutpoint(r, e)
case *lnwire.ShortChannelID:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = lnwire.NewShortChanIDFromInt(a)
case *lnwire.ChannelID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *int64, *uint64:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *int32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint16:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint8:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *bool:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *btcutil.Amount:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = btcutil.Amount(a)
case *lnwire.MilliSatoshi:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = lnwire.MilliSatoshi(a)
case **btcec.PublicKey:
var b [btcec.PubKeyBytesLenCompressed]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
pubKey, err := btcec.ParsePubKey(b[:])
if err != nil {
return err
}
*e = pubKey
case **wire.MsgTx:
tx := wire.NewMsgTx(2)
if err := tx.Deserialize(r); err != nil {
return err
}
*e = tx
case *[32]byte:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *[]byte:
bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte")
if err != nil {
return err
}
*e = bytes
case *lnwire.Message:
msg, err := lnwire.ReadMessage(r, 0)
if err != nil {
return err
}
*e = msg
case *lnwire.FundingFlag:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *common.ClosureType:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *keychain.KeyDescriptor:
if err := binary.Read(r, byteOrder, &e.Family); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &e.Index); err != nil {
return err
}
var hasPubKey bool
if err := binary.Read(r, byteOrder, &hasPubKey); err != nil {
return err
}
if hasPubKey {
return ReadElement(r, &e.PubKey)
}
default:
return UnknownElementType{"ReadElement", e}
}
return nil
}
// ReadElements deserializes a variable number of elements from the passed
// io.Reader, with each element being deserialized according to the ReadElement
// function.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
err := ReadElement(r, element)
if err != nil {
return err
}
}
return nil
}
package legacy
import (
"bytes"
"encoding/binary"
"errors"
"io"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/channeldb/migration21/common"
"github.com/lightningnetwork/lnd/kvdb"
)
func deserializeHtlcs(r io.Reader) ([]common.HTLC, error) {
var numHtlcs uint16
if err := ReadElement(r, &numHtlcs); err != nil {
return nil, err
}
var htlcs []common.HTLC
if numHtlcs == 0 {
return htlcs, nil
}
htlcs = make([]common.HTLC, numHtlcs)
for i := uint16(0); i < numHtlcs; i++ {
if err := ReadElements(r,
&htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt,
&htlcs[i].RefundTimeout, &htlcs[i].OutputIndex,
&htlcs[i].Incoming, &htlcs[i].OnionBlob,
&htlcs[i].HtlcIndex, &htlcs[i].LogIndex,
); err != nil {
return htlcs, err
}
}
return htlcs, nil
}
func DeserializeLogUpdates(r io.Reader) ([]common.LogUpdate, error) {
var numUpdates uint16
if err := binary.Read(r, byteOrder, &numUpdates); err != nil {
return nil, err
}
logUpdates := make([]common.LogUpdate, numUpdates)
for i := 0; i < int(numUpdates); i++ {
err := ReadElements(r,
&logUpdates[i].LogIndex, &logUpdates[i].UpdateMsg,
)
if err != nil {
return nil, err
}
}
return logUpdates, nil
}
func deserializeChanCommit(r io.Reader) (common.ChannelCommitment, error) {
var c common.ChannelCommitment
err := ReadElements(r,
&c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex,
&c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance,
&c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig,
)
if err != nil {
return c, err
}
c.Htlcs, err = deserializeHtlcs(r)
if err != nil {
return c, err
}
return c, nil
}
func DeserializeCommitDiff(r io.Reader) (*common.CommitDiff, error) {
var (
d common.CommitDiff
err error
)
d.Commitment, err = deserializeChanCommit(r)
if err != nil {
return nil, err
}
d.CommitSig = &lnwire.CommitSig{}
if err := d.CommitSig.Decode(r, 0); err != nil {
return nil, err
}
d.LogUpdates, err = DeserializeLogUpdates(r)
if err != nil {
return nil, err
}
var numOpenRefs uint16
if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil {
return nil, err
}
d.OpenedCircuitKeys = make([]common.CircuitKey, numOpenRefs)
for i := 0; i < int(numOpenRefs); i++ {
err := ReadElements(r,
&d.OpenedCircuitKeys[i].ChanID,
&d.OpenedCircuitKeys[i].HtlcID)
if err != nil {
return nil, err
}
}
var numClosedRefs uint16
if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil {
return nil, err
}
d.ClosedCircuitKeys = make([]common.CircuitKey, numClosedRefs)
for i := 0; i < int(numClosedRefs); i++ {
err := ReadElements(r,
&d.ClosedCircuitKeys[i].ChanID,
&d.ClosedCircuitKeys[i].HtlcID)
if err != nil {
return nil, err
}
}
return &d, nil
}
func serializeHtlcs(b io.Writer, htlcs ...common.HTLC) error {
numHtlcs := uint16(len(htlcs))
if err := WriteElement(b, numHtlcs); err != nil {
return err
}
for _, htlc := range htlcs {
if err := WriteElements(b,
htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout,
htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob,
htlc.HtlcIndex, htlc.LogIndex,
); err != nil {
return err
}
}
return nil
}
func serializeChanCommit(w io.Writer, c *common.ChannelCommitment) error {
if err := WriteElements(w,
c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex,
c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance,
c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx,
c.CommitSig,
); err != nil {
return err
}
return serializeHtlcs(w, c.Htlcs...)
}
func SerializeLogUpdates(w io.Writer, logUpdates []common.LogUpdate) error {
numUpdates := uint16(len(logUpdates))
if err := binary.Write(w, byteOrder, numUpdates); err != nil {
return err
}
for _, diff := range logUpdates {
err := WriteElements(w, diff.LogIndex, diff.UpdateMsg)
if err != nil {
return err
}
}
return nil
}
func SerializeCommitDiff(w io.Writer, diff *common.CommitDiff) error { // nolint: dupl
if err := serializeChanCommit(w, &diff.Commitment); err != nil {
return err
}
if err := diff.CommitSig.Encode(w, 0); err != nil {
return err
}
if err := SerializeLogUpdates(w, diff.LogUpdates); err != nil {
return err
}
numOpenRefs := uint16(len(diff.OpenedCircuitKeys))
if err := binary.Write(w, byteOrder, numOpenRefs); err != nil {
return err
}
for _, openRef := range diff.OpenedCircuitKeys {
err := WriteElements(w, openRef.ChanID, openRef.HtlcID)
if err != nil {
return err
}
}
numClosedRefs := uint16(len(diff.ClosedCircuitKeys))
if err := binary.Write(w, byteOrder, numClosedRefs); err != nil {
return err
}
for _, closedRef := range diff.ClosedCircuitKeys {
err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID)
if err != nil {
return err
}
}
return nil
}
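// Round-trip sketch (illustrative only, assuming a populated CommitDiff
// named diff): the encoders and decoders above are symmetric, so a diff
// written with SerializeCommitDiff can be read back verbatim:
//
//	var b bytes.Buffer
//	if err := SerializeCommitDiff(&b, diff); err != nil {
//		// handle err
//	}
//	restored, err := DeserializeCommitDiff(&b)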
// DeserializeNetworkResult reads a NetworkResult from the provided
// io.Reader.
func DeserializeNetworkResult(r io.Reader) (*common.NetworkResult, error) {
var err error
n := &common.NetworkResult{}
n.Msg, err = lnwire.ReadMessage(r, 0)
if err != nil {
return nil, err
}
if err := ReadElements(r,
&n.Unencrypted, &n.IsResolution,
); err != nil {
return nil, err
}
return n, nil
}
// SerializeNetworkResult writes a NetworkResult to the provided
// io.Writer.
func SerializeNetworkResult(w io.Writer, n *common.NetworkResult) error {
if _, err := lnwire.WriteMessage(w, n.Msg, 0); err != nil {
return err
}
return WriteElements(w, n.Unencrypted, n.IsResolution)
}
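// A matching round-trip sketch (illustrative only, assuming a populated
// NetworkResult named n): the wire message is framed first via
// lnwire.WriteMessage/ReadMessage, followed by the two boolean flags:
//
//	var b bytes.Buffer
//	if err := SerializeNetworkResult(&b, n); err != nil {
//		// handle err
//	}
//	restored, err := DeserializeNetworkResult(&b)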
// readChanConfig reads a ChannelConfig from the provided io.Reader.
func readChanConfig(b io.Reader, c *common.ChannelConfig) error { // nolint: dupl
return ReadElements(b,
&c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve,
&c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay,
&c.MultiSigKey, &c.RevocationBasePoint,
&c.PaymentBasePoint, &c.DelayBasePoint,
&c.HtlcBasePoint,
)
}
// DeserializeCloseChannelSummary reads a ChannelCloseSummary from the
// provided io.Reader, including any optional fields that were written.
func DeserializeCloseChannelSummary(
r io.Reader) (*common.ChannelCloseSummary, error) { // nolint: dupl
c := &common.ChannelCloseSummary{}
err := ReadElements(r,
&c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID,
&c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance,
&c.TimeLockedBalance, &c.CloseType, &c.IsPending,
)
if err != nil {
return nil, err
}
// We'll now check to see if the channel close summary was encoded with
// any of the additional optional fields.
var hasNewFields bool
err = ReadElements(r, &hasNewFields)
if err != nil {
return nil, err
}
// If fields are not present, we can return.
if !hasNewFields {
return c, nil
}
// Otherwise read the new fields.
if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil {
return nil, err
}
if err := readChanConfig(r, &c.LocalChanConfig); err != nil {
return nil, err
}
// Finally, we'll attempt to read the next unrevoked commitment point
// for the remote party. If we closed the channel before receiving a
// funding locked message then this might not be present. A boolean
// indicating whether the field is present will come first.
var hasRemoteNextRevocation bool
err = ReadElements(r, &hasRemoteNextRevocation)
if err != nil {
return nil, err
}
// If this field was written, read it.
if hasRemoteNextRevocation {
err = ReadElements(r, &c.RemoteNextRevocation)
if err != nil {
return nil, err
}
}
// Check if we have a channel sync message to read.
var hasChanSyncMsg bool
err = ReadElements(r, &hasChanSyncMsg)
if err == io.EOF {
return c, nil
} else if err != nil {
return nil, err
}
// If a chan sync message is present, read it.
if hasChanSyncMsg {
// We must pass in reference to a lnwire.Message for the codec
// to support it.
msg, err := lnwire.ReadMessage(r, 0)
if err != nil {
return nil, err
}
chanSync, ok := msg.(*lnwire.ChannelReestablish)
if !ok {
return nil, errors.New("unable cast db Message to " +
"ChannelReestablish")
}
c.LastChanSyncMsg = chanSync
}
return c, nil
}
// writeChanConfig writes a ChannelConfig to the provided io.Writer.
func writeChanConfig(b io.Writer, c *common.ChannelConfig) error { // nolint: dupl
return WriteElements(b,
c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC,
c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey,
c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint,
c.HtlcBasePoint,
)
}
// SerializeChannelCloseSummary writes a ChannelCloseSummary to the
// provided io.Writer, including any optional fields that are present.
func SerializeChannelCloseSummary(w io.Writer, cs *common.ChannelCloseSummary) error {
err := WriteElements(w,
cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID,
cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance,
cs.TimeLockedBalance, cs.CloseType, cs.IsPending,
)
if err != nil {
return err
}
// If this is a close channel summary created before the addition of
// the new fields, then we can exit here.
if cs.RemoteCurrentRevocation == nil {
return WriteElements(w, false)
}
// If fields are present, write boolean to indicate this, and continue.
if err := WriteElements(w, true); err != nil {
return err
}
if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil {
return err
}
if err := writeChanConfig(w, &cs.LocalChanConfig); err != nil {
return err
}
// The RemoteNextRevocation field is optional, as it's possible for a
// channel to be closed before we learn of the next unrevoked
// revocation point for the remote party. Write a boolean indicating
// whether this field is present or not.
if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil {
return err
}
// Write the field, if present.
if cs.RemoteNextRevocation != nil {
if err = WriteElements(w, cs.RemoteNextRevocation); err != nil {
return err
}
}
// Write whether the channel sync message is present.
if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil {
return err
}
// Write the channel sync message, if present.
if cs.LastChanSyncMsg != nil {
_, err = lnwire.WriteMessage(w, cs.LastChanSyncMsg, 0)
if err != nil {
return err
}
}
return nil
}
// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding
// package has potentially been mangled.
var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted")
var (
// fwdPackagesKey is the root-level bucket to which all forwarding
// packages are written. This bucket is further subdivided based on the
// short channel ID of each channel.
fwdPackagesKey = []byte("fwd-packages")
// addBucketKey is the bucket to which all Add log updates are written.
addBucketKey = []byte("add-updates")
// failSettleBucketKey is the bucket to which all Settle/Fail log
// updates are written.
failSettleBucketKey = []byte("fail-settle-updates")
// fwdFilterKey is a key used to write the set of Adds that passed
// validation and are to be forwarded to the switch.
// NOTE: The presence of this key within a forwarding package indicates
// that the package has reached FwdStateProcessed.
fwdFilterKey = []byte("fwd-filter-key")
// ackFilterKey is a key used to access the PkgFilter indicating which
// Adds have received a Settle/Fail. This response may come from a
// number of sources, including: exitHop settle/fails, switch failures,
// chain arbiter interjections, as well as settle/fails from the
// next hop in the route.
ackFilterKey = []byte("ack-filter-key")
// settleFailFilterKey is a key used to access the PkgFilter indicating
// which Settles/Fails have been received and processed by the link that
// originally received the Add.
settleFailFilterKey = []byte("settle-fail-filter-key")
)
// makeLogKey converts a uint64 into an 8-byte big-endian key.
func makeLogKey(updateNum uint64) [8]byte {
var key [8]byte
byteOrder.PutUint64(key[:], updateNum)
return key
}
// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice.
func uint16Key(i uint16) []byte {
key := make([]byte, 2)
byteOrder.PutUint16(key, i)
return key
}
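// To illustrate the key encodings (values are examples only): both
// helpers emit fixed-width big-endian keys, so bbolt's lexicographic
// ordering matches numeric ordering:
//
//	makeLogKey(257) // -> 0x0000000000000101 (8 bytes)
//	uint16Key(257)  // -> 0x0101 (2 bytes)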
// ChannelPackager is used by a channel to manage the lifecycle of its
// forwarding packages. The packager is tied to a particular source
// channel ID, allowing it to create and edit its own packages. Each
// packager also has the ability to remove fail/settle htlcs that
// correspond to an add contained in one of the source's packages.
type ChannelPackager struct {
source lnwire.ShortChannelID
}
// NewChannelPackager creates a new packager for a single channel.
func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager {
return &ChannelPackager{
source: source,
}
}
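// A minimal usage sketch (hypothetical names; assumes a kvdb.Backend
// named db and a locked-in package fwdPkg):
//
//	packager := NewChannelPackager(fwdPkg.Source)
//	err := kvdb.Update(db, func(tx kvdb.RwTx) error {
//		return packager.AddFwdPkg(tx, fwdPkg)
//	}, func() {})
//
//	// Later, reload any packages that still need processing:
//	var fwdPkgs []*common.FwdPkg
//	err = kvdb.View(db, func(tx kvdb.RTx) error {
//		var err error
//		fwdPkgs, err = packager.LoadFwdPkgs(tx)
//		return err
//	}, func() {})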
// AddFwdPkg writes a newly locked-in forwarding package to disk.
func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *common.FwdPkg) error { // nolint: dupl
fwdPkgBkt, err := tx.CreateTopLevelBucket(fwdPackagesKey)
if err != nil {
return err
}
source := makeLogKey(fwdPkg.Source.ToUint64())
sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:])
if err != nil {
return err
}
heightKey := makeLogKey(fwdPkg.Height)
heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:])
if err != nil {
return err
}
// Write ADD updates we received at this commit height.
addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey)
if err != nil {
return err
}
// Write SETTLE/FAIL updates we received at this commit height.
failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey)
if err != nil {
return err
}
for i := range fwdPkg.Adds {
err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i])
if err != nil {
return err
}
}
// Persist the initialized pkg filter, which will be used to determine
// when we can remove this forwarding package from disk.
var ackFilterBuf bytes.Buffer
if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil {
return err
}
if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil {
return err
}
for i := range fwdPkg.SettleFails {
err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i])
if err != nil {
return err
}
}
var settleFailFilterBuf bytes.Buffer
err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf)
if err != nil {
return err
}
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
}
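// For reference, the resulting on-disk layout for a package locked in at
// commit height <h> of channel <source> looks like this (sketch):
//
//	fwd-packages/<source>/<h>/add-updates/<idx>
//	fwd-packages/<source>/<h>/fail-settle-updates/<idx>
//	fwd-packages/<source>/<h>/ack-filter-key
//	fwd-packages/<source>/<h>/settle-fail-filter-key
//
// where <source> and <h> are the 8-byte big-endian keys produced by
// makeLogKey, and <idx> is the 2-byte key produced by uint16Key.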
// putLogUpdate writes an htlc to the provided bucket `bkt`, using `idx`
// as the key.
func putLogUpdate(bkt kvdb.RwBucket, idx uint16, htlc *common.LogUpdate) error {
var b bytes.Buffer
if err := serializeLogUpdate(&b, htlc); err != nil {
return err
}
return bkt.Put(uint16Key(idx), b.Bytes())
}
// LoadFwdPkgs scans the forwarding log for any packages that haven't been
// processed, and returns the deserialized forwarding packages, ordered by
// the remote commitment height at which their updates were locked in.
func (p *ChannelPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*common.FwdPkg, error) {
return loadChannelFwdPkgs(tx, p.source)
}
// loadChannelFwdPkgs loads all forwarding packages owned by `source`.
func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*common.FwdPkg, error) { // nolint: dupl
fwdPkgBkt := tx.ReadBucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return nil, nil
}
sourceKey := makeLogKey(source.ToUint64())
sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:])
if sourceBkt == nil {
return nil, nil
}
var heights []uint64
if err := sourceBkt.ForEach(func(k, _ []byte) error {
if len(k) != 8 {
return ErrCorruptedFwdPkg
}
heights = append(heights, byteOrder.Uint64(k))
return nil
}); err != nil {
return nil, err
}
// Load the forwarding package for each retrieved height.
fwdPkgs := make([]*common.FwdPkg, 0, len(heights))
for _, height := range heights {
fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height)
if err != nil {
return nil, err
}
fwdPkgs = append(fwdPkgs, fwdPkg)
}
return fwdPkgs, nil
}
// loadFwdPkg reads the packager's fwd pkg at a given height, and determines the
// appropriate FwdState.
func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID,
height uint64) (*common.FwdPkg, error) {
sourceKey := makeLogKey(source.ToUint64())
sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:])
if sourceBkt == nil {
return nil, ErrCorruptedFwdPkg
}
heightKey := makeLogKey(height)
heightBkt := sourceBkt.NestedReadBucket(heightKey[:])
if heightBkt == nil {
return nil, ErrCorruptedFwdPkg
}
// Load ADDs from disk.
addBkt := heightBkt.NestedReadBucket(addBucketKey)
if addBkt == nil {
return nil, ErrCorruptedFwdPkg
}
adds, err := loadHtlcs(addBkt)
if err != nil {
return nil, err
}
// Load ack filter from disk.
ackFilterBytes := heightBkt.Get(ackFilterKey)
if ackFilterBytes == nil {
return nil, ErrCorruptedFwdPkg
}
ackFilterReader := bytes.NewReader(ackFilterBytes)
ackFilter := &common.PkgFilter{}
if err := ackFilter.Decode(ackFilterReader); err != nil {
return nil, err
}
// Load SETTLE/FAILs from disk.
failSettleBkt := heightBkt.NestedReadBucket(failSettleBucketKey)
if failSettleBkt == nil {
return nil, ErrCorruptedFwdPkg
}
failSettles, err := loadHtlcs(failSettleBkt)
if err != nil {
return nil, err
}
// Load settle fail filter from disk.
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
if settleFailFilterBytes == nil {
return nil, ErrCorruptedFwdPkg
}
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
settleFailFilter := &common.PkgFilter{}
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
return nil, err
}
// Initialize the fwding package, which always starts in the
// FwdStateLockedIn state. We can determine what state the package was
// left in by examining constraints on the information loaded from disk.
fwdPkg := &common.FwdPkg{
Source: source,
State: common.FwdStateLockedIn,
Height: height,
Adds: adds,
AckFilter: ackFilter,
SettleFails: failSettles,
SettleFailFilter: settleFailFilter,
}
// Check to see if we have written the set of exported filter adds to
// disk. If we haven't, processing of this package was never started, or
// failed during the last attempt.
fwdFilterBytes := heightBkt.Get(fwdFilterKey)
if fwdFilterBytes == nil {
nAdds := uint16(len(adds))
fwdPkg.FwdFilter = common.NewPkgFilter(nAdds)
return fwdPkg, nil
}
fwdFilterReader := bytes.NewReader(fwdFilterBytes)
fwdPkg.FwdFilter = &common.PkgFilter{}
if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil {
return nil, err
}
// Otherwise, a complete round of processing was completed, and we
// advance the package to FwdStateProcessed.
fwdPkg.State = common.FwdStateProcessed
// If every add, settle, and fail has been fully acknowledged, we can
// safely set the package's state to FwdStateCompleted, signalling that
// it can be garbage collected.
if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() {
fwdPkg.State = common.FwdStateCompleted
}
return fwdPkg, nil
}
// loadHtlcs retrieves all serialized htlcs in a bucket, returning
// them in order of the indexes they were written under.
func loadHtlcs(bkt kvdb.RBucket) ([]common.LogUpdate, error) {
var htlcs []common.LogUpdate
if err := bkt.ForEach(func(_, v []byte) error {
htlc, err := deserializeLogUpdate(bytes.NewReader(v))
if err != nil {
return err
}
htlcs = append(htlcs, *htlc)
return nil
}); err != nil {
return nil, err
}
return htlcs, nil
}
// serializeLogUpdate writes a log update to the provided io.Writer.
func serializeLogUpdate(w io.Writer, l *common.LogUpdate) error {
return WriteElements(w, l.LogIndex, l.UpdateMsg)
}
// deserializeLogUpdate reads a log update from the provided io.Reader.
func deserializeLogUpdate(r io.Reader) (*common.LogUpdate, error) {
l := &common.LogUpdate{}
if err := ReadElements(r, &l.LogIndex, &l.UpdateMsg); err != nil {
return nil, err
}
return l, nil
}
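// Round-trip sketch (illustrative only, assuming a populated LogUpdate
// named update):
//
//	var b bytes.Buffer
//	if err := serializeLogUpdate(&b, &update); err != nil {
//		// handle err
//	}
//	restored, err := deserializeLogUpdate(&b)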
package migration21
import (
"bytes"
"encoding/binary"
"fmt"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/channeldb/migration21/common"
"github.com/lightningnetwork/lnd/channeldb/migration21/current"
"github.com/lightningnetwork/lnd/channeldb/migration21/legacy"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
byteOrder = binary.BigEndian
// openChanBucket stores all the currently open channels. This bucket
// has a second, nested bucket which is keyed by a node's ID. Within
// that node ID bucket, all attributes required to track, update, and
// close a channel are stored.
//
// openChan -> nodeID -> chanPoint
//
// TODO(roasbeef): flesh out comment
openChannelBucket = []byte("open-chan-bucket")
// commitDiffKey stores the current pending commitment state we've
// extended to the remote party (if any). Each time we propose a new
// state, we store the information necessary to reconstruct this state
// from the prior commitment. This allows us to resync the remote party
// to their expected state in the case of message loss.
//
// TODO(roasbeef): rename to commit chain?
commitDiffKey = []byte("commit-diff-key")
// unsignedAckedUpdatesKey is an entry in the channel bucket that
// contains the remote updates that we have acked, but not yet signed
// for in one of our remote commits.
unsignedAckedUpdatesKey = []byte("unsigned-acked-updates-key")
// remoteUnsignedLocalUpdatesKey is an entry in the channel bucket that
// contains the local updates that the remote party has acked, but
// has not yet signed for in one of their local commits.
remoteUnsignedLocalUpdatesKey = []byte("remote-unsigned-local-updates-key")
// networkResultStoreBucketKey is used for the root level bucket that
// stores the network result for each payment ID.
networkResultStoreBucketKey = []byte("network-result-store-bucket")
// closedChannelBucket stores summarization information concerning
// previously open, but now closed channels.
closedChannelBucket = []byte("closed-chan-bucket")
// fwdPackagesKey is the root-level bucket to which all forwarding
// packages are written. This bucket is further subdivided based on the
// short channel ID of each channel.
fwdPackagesKey = []byte("fwd-packages")
)
// MigrateDatabaseWireMessages performs a migration in all areas where we
// currently store wire messages without length prefixes. This includes
// the CommitDiff struct, ChannelCloseSummary, LogUpdates, and the
// networkResult struct.
func MigrateDatabaseWireMessages(tx kvdb.RwTx) error {
// The migration proceeds in four phases: we'll update any pending
// commit diffs and unsigned acked updates for all open channels, then
// the close channel summaries, then the forwarding packages, and
// finally the stored network results for payments in the switch.
//
// In this phase, we'll migrate the open channel data.
if err := migrateOpenChanBucket(tx); err != nil {
return err
}
// Next, we'll update all the present close channel summaries as well.
if err := migrateCloseChanSummaries(tx); err != nil {
return err
}
// We'll migrate forwarding packages, which have log updates as part of
// their serialized data.
if err := migrateForwardingPackages(tx); err != nil {
return err
}
// Finally, we'll update the pending network results as well.
return migrateNetworkResults(tx)
}
// migrateOpenChanBucket re-encodes, for every open channel, any pending
// commit diff, unsigned acked updates, and remote unsigned local updates.
func migrateOpenChanBucket(tx kvdb.RwTx) error {
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
// If no bucket is found, we can exit early.
if openChanBucket == nil {
return nil
}
type channelPath struct {
nodePub []byte
chainHash []byte
chanPoint []byte
}
var channelPaths []channelPath
err := openChanBucket.ForEach(func(nodePub, v []byte) error {
// Ensure that this is a key the same size as a pubkey, and
// also that it leads directly to a bucket.
if len(nodePub) != 33 || v != nil {
return nil
}
nodeChanBucket := openChanBucket.NestedReadBucket(nodePub)
if nodeChanBucket == nil {
return fmt.Errorf("no bucket for node %x", nodePub)
}
// The next layer down is all the chains that this node
// has channels on with us.
return nodeChanBucket.ForEach(func(chainHash, v []byte) error {
// If there's a value, it's not a bucket so
// ignore it.
if v != nil {
return nil
}
chainBucket := nodeChanBucket.NestedReadBucket(
chainHash,
)
if chainBucket == nil {
return fmt.Errorf("unable to read "+
"bucket for chain=%x", chainHash)
}
return chainBucket.ForEach(func(chanPoint, v []byte) error {
// If there's a value, it's not a bucket so
// ignore it.
if v != nil {
return nil
}
channelPaths = append(channelPaths, channelPath{
nodePub: nodePub,
chainHash: chainHash,
chanPoint: chanPoint,
})
return nil
})
})
})
if err != nil {
return err
}
// Now that we have all the paths of the channel we need to migrate,
// we'll update all the state in a distinct step to avoid weird
// behavior from modifying buckets in a ForEach statement.
for _, channelPath := range channelPaths {
// First, we'll extract it from the node's chain bucket.
nodeChanBucket := openChanBucket.NestedReadWriteBucket(
channelPath.nodePub,
)
chainBucket := nodeChanBucket.NestedReadWriteBucket(
channelPath.chainHash,
)
chanBucket := chainBucket.NestedReadWriteBucket(
channelPath.chanPoint,
)
// At this point, we have the channel bucket now, so we'll
// check to see if this channel has a pending commitment or
// not.
commitDiffBytes := chanBucket.Get(commitDiffKey)
if commitDiffBytes != nil {
// Now that we have the commit diff in the _old_
// encoding, we'll write it back to disk using the new
// encoding which has a length prefix in front of the
// CommitSig.
commitDiff, err := legacy.DeserializeCommitDiff(
bytes.NewReader(commitDiffBytes),
)
if err != nil {
return err
}
var b bytes.Buffer
err = current.SerializeCommitDiff(&b, commitDiff)
if err != nil {
return err
}
err = chanBucket.Put(commitDiffKey, b.Bytes())
if err != nil {
return err
}
}
// With the commit diff migrated, we'll now check to see if there are
// any un-acked updates we need to migrate as well.
updateBytes := chanBucket.Get(unsignedAckedUpdatesKey)
if updateBytes != nil {
// We have un-acked updates we need to migrate so we'll
// decode then re-encode them here using the new
// format.
legacyUnackedUpdates, err := legacy.DeserializeLogUpdates(
bytes.NewReader(updateBytes),
)
if err != nil {
return err
}
var b bytes.Buffer
err = current.SerializeLogUpdates(&b, legacyUnackedUpdates)
if err != nil {
return err
}
err = chanBucket.Put(unsignedAckedUpdatesKey, b.Bytes())
if err != nil {
return err
}
}
// Remote unsigned updates as well.
updateBytes = chanBucket.Get(remoteUnsignedLocalUpdatesKey)
if updateBytes != nil {
legacyUnsignedUpdates, err := legacy.DeserializeLogUpdates(
bytes.NewReader(updateBytes),
)
if err != nil {
return err
}
var b bytes.Buffer
err = current.SerializeLogUpdates(&b, legacyUnsignedUpdates)
if err != nil {
return err
}
err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b.Bytes())
if err != nil {
return err
}
}
}
return nil
}
// migrateCloseChanSummaries re-encodes all closed channel summaries
// using the current encoding.
func migrateCloseChanSummaries(tx kvdb.RwTx) error {
closedChanBucket := tx.ReadWriteBucket(closedChannelBucket)
// Exit early if the bucket is not found. Note that we must check the
// bucket handle here, not the bucket key, which is never nil.
if closedChanBucket == nil {
return nil
}
type closedChan struct {
chanKey []byte
summaryBytes []byte
}
var closedChans []closedChan
err := closedChanBucket.ForEach(func(k, v []byte) error {
closedChans = append(closedChans, closedChan{
chanKey: k,
summaryBytes: v,
})
return nil
})
if err != nil {
return err
}
for _, closedChan := range closedChans {
oldSummary, err := legacy.DeserializeCloseChannelSummary(
bytes.NewReader(closedChan.summaryBytes),
)
if err != nil {
return err
}
var newSummaryBytes bytes.Buffer
err = current.SerializeChannelCloseSummary(
&newSummaryBytes, oldSummary,
)
if err != nil {
return err
}
err = closedChanBucket.Put(
closedChan.chanKey, newSummaryBytes.Bytes(),
)
if err != nil {
return err
}
}
return nil
}
// migrateForwardingPackages re-encodes all forwarding packages using the
// current encoding.
func migrateForwardingPackages(tx kvdb.RwTx) error {
fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey)
// Exit early if bucket is not found.
if fwdPkgBkt == nil {
return nil
}
// Go through the bucket and fetch all short channel IDs.
var sources []lnwire.ShortChannelID
err := fwdPkgBkt.ForEach(func(k, v []byte) error {
source := lnwire.NewShortChanIDFromInt(byteOrder.Uint64(k))
sources = append(sources, source)
return nil
})
if err != nil {
return err
}
// Now load all forwarding packages using the legacy encoding.
var pkgsToMigrate []*common.FwdPkg
for _, source := range sources {
packager := legacy.NewChannelPackager(source)
fwdPkgs, err := packager.LoadFwdPkgs(tx)
if err != nil {
return err
}
pkgsToMigrate = append(pkgsToMigrate, fwdPkgs...)
}
// Add back the packages using the current encoding.
for _, pkg := range pkgsToMigrate {
packager := current.NewChannelPackager(pkg.Source)
err := packager.AddFwdPkg(tx, pkg)
if err != nil {
return err
}
}
return nil
}
// migrateNetworkResults re-encodes all stored network results using the
// current encoding.
func migrateNetworkResults(tx kvdb.RwTx) error {
networkResults := tx.ReadWriteBucket(networkResultStoreBucketKey)
// Exit early if bucket is not found.
if networkResults == nil {
return nil
}
// Similar to the prior migrations, we'll do this one in two phases:
// we'll first grab all the keys we need to migrate in one loop, then
// update them all in another loop.
var netResultsToMigrate [][2][]byte
err := networkResults.ForEach(func(k, v []byte) error {
netResultsToMigrate = append(netResultsToMigrate, [2][]byte{
k, v,
})
return nil
})
if err != nil {
return err
}
for _, netResult := range netResultsToMigrate {
resKey := netResult[0]
resBytes := netResult[1]
oldResult, err := legacy.DeserializeNetworkResult(
bytes.NewReader(resBytes),
)
if err != nil {
return err
}
var newResultBuf bytes.Buffer
err = current.SerializeNetworkResult(&newResultBuf, oldResult)
if err != nil {
return err
}
err = networkResults.Put(resKey, newResultBuf.Bytes())
if err != nil {
return err
}
}
return nil
}
package migration23
import (
"fmt"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// paymentsRootBucket is the name of the top-level bucket within the
// database that stores all data related to payments.
paymentsRootBucket = []byte("payments-root-bucket")
// paymentHtlcsBucket is a bucket where we'll store the information
// about the HTLCs that were attempted for a payment.
paymentHtlcsBucket = []byte("payment-htlcs-bucket")
// oldAttemptInfoKey is a key used in a HTLC's sub-bucket to store the
// info about the attempt that was done for the HTLC in question.
oldAttemptInfoKey = []byte("htlc-attempt-info")
// oldSettleInfoKey is a key used in a HTLC's sub-bucket to store the
// settle info, if any.
oldSettleInfoKey = []byte("htlc-settle-info")
// oldFailInfoKey is a key used in a HTLC's sub-bucket to store
// failure information, if any.
oldFailInfoKey = []byte("htlc-fail-info")
// htlcAttemptInfoKey is the key used as the prefix of an HTLC attempt
// to store the info about the attempt that was done for the HTLC in
// question. The HTLC attempt ID is concatenated at the end.
htlcAttemptInfoKey = []byte("ai")
// htlcSettleInfoKey is the key used as the prefix of an HTLC attempt
// settle info, if any. The HTLC attempt ID is concatenated at the end.
htlcSettleInfoKey = []byte("si")
// htlcFailInfoKey is the key used as the prefix of an HTLC attempt
// failure information, if any. The HTLC attempt ID is concatenated at
// the end.
htlcFailInfoKey = []byte("fi")
)
// htlcBucketKey creates a composite key from prefix and id where the result is
// simply the two concatenated. This is the exact copy from payments.go.
func htlcBucketKey(prefix, id []byte) []byte {
key := make([]byte, len(prefix)+len(id))
copy(key, prefix)
copy(key[len(prefix):], id)
return key
}
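// For example (illustrative only, assuming an 8-byte big-endian attempt
// ID named aid with value 1), the flattened attempt-info key is 10 bytes:
//
//	htlcBucketKey(htlcAttemptInfoKey, aid)
//	// -> "ai" || 0x0000000000000001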
// MigrateHtlcAttempts will gather all htlc-attempt-info's,
// htlc-settle-info's and htlc-fail-info's from the attempt ID buckets and
// re-store them using the flattened keys in each payment's
// payment-htlcs-bucket.
func MigrateHtlcAttempts(tx kvdb.RwTx) error {
payments := tx.ReadWriteBucket(paymentsRootBucket)
if payments == nil {
return nil
}
// Collect all payment hashes so we can migrate payments one-by-one to
// avoid any bugs bbolt might have when invalidating cursors.
// For 100 million payments, this would need about 3 GiB memory so we
// should hopefully be fine for very large nodes too.
var paymentHashes []string
if err := payments.ForEach(func(hash, v []byte) error {
// Get the bucket which contains the payment, fail if the key
// does not have a bucket.
bucket := payments.NestedReadBucket(hash)
if bucket == nil {
return fmt.Errorf("key must be a bucket: '%v'",
string(paymentsRootBucket))
}
paymentHashes = append(paymentHashes, string(hash))
return nil
}); err != nil {
return err
}
for _, paymentHash := range paymentHashes {
payment := payments.NestedReadWriteBucket([]byte(paymentHash))
if payment.Get(paymentHtlcsBucket) != nil {
return fmt.Errorf("key must be a bucket: '%v'",
string(paymentHtlcsBucket))
}
htlcs := payment.NestedReadWriteBucket(paymentHtlcsBucket)
if htlcs == nil {
// Nothing to migrate for this payment.
continue
}
if err := migrateHtlcsBucket(htlcs); err != nil {
return err
}
}
return nil
}
// migrateHtlcsBucket is a helper to gather, transform and re-store htlc attempt
// key/values.
func migrateHtlcsBucket(htlcs kvdb.RwBucket) error {
// Collect attempt ids so that we can migrate attempts one-by-one
// to avoid any bugs bbolt might have when invalidating cursors.
var aids []string
// First we collect all htlc attempt ids.
if err := htlcs.ForEach(func(aid, v []byte) error {
aids = append(aids, string(aid))
return nil
}); err != nil {
return err
}
// Next we go over these attempts, fetch all data and migrate.
for _, aid := range aids {
aidKey := []byte(aid)
attempt := htlcs.NestedReadWriteBucket(aidKey)
if attempt == nil {
return fmt.Errorf("non bucket element '%v' in '%v' "+
"bucket", aidKey, string(paymentHtlcsBucket))
}
// Collect attempt/settle/fail infos.
attemptInfo := attempt.Get(oldAttemptInfoKey)
if len(attemptInfo) > 0 {
newKey := htlcBucketKey(htlcAttemptInfoKey, aidKey)
if err := htlcs.Put(newKey, attemptInfo); err != nil {
return err
}
}
settleInfo := attempt.Get(oldSettleInfoKey)
if len(settleInfo) > 0 {
newKey := htlcBucketKey(htlcSettleInfoKey, aidKey)
if err := htlcs.Put(newKey, settleInfo); err != nil {
return err
}
}
failInfo := attempt.Get(oldFailInfoKey)
if len(failInfo) > 0 {
newKey := htlcBucketKey(htlcFailInfoKey, aidKey)
if err := htlcs.Put(newKey, failInfo); err != nil {
return err
}
}
}
// Finally we delete old attempt buckets.
for _, aid := range aids {
if err := htlcs.DeleteNestedBucket([]byte(aid)); err != nil {
return err
}
}
return nil
}
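// To illustrate the transformation for a hypothetical attempt ID <aid>,
// the old nested layout:
//
//	payment-htlcs-bucket/<aid>/htlc-attempt-info
//	payment-htlcs-bucket/<aid>/htlc-settle-info
//	payment-htlcs-bucket/<aid>/htlc-fail-info
//
// becomes the flattened layout:
//
//	payment-htlcs-bucket/ai<aid>
//	payment-htlcs-bucket/si<aid>
//	payment-htlcs-bucket/fi<aid>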
package migration24
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration24
import (
"bytes"
"encoding/binary"
"io"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// closedChannelBucket stores summarization information concerning
// previously open, but now closed channels.
closedChannelBucket = []byte("closed-chan-bucket")
// fwdPackagesKey is the root-level bucket that all forwarding packages
// are written. This bucket is further subdivided based on the short
// channel ID of each channel.
fwdPackagesKey = []byte("fwd-packages")
)
// MigrateFwdPkgCleanup deletes all the forwarding packages of closed channels.
// It determines the closed channels by iterating closedChannelBucket. The
// ShortChanID found in the ChannelCloseSummary is then used as a key to query
// the forwarding packages bucket. If a match is found, it will be deleted.
func MigrateFwdPkgCleanup(tx kvdb.RwTx) error {
log.Infof("Deleting forwarding packages for closed channels")
// Find all closed channel summaries, which are stored in the
// closeBucket.
closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil {
return nil
}
var chanSummaries []*mig.ChannelCloseSummary
// appendSummary is a function closure to help put deserialized close
// summaries into chanSummaries.
appendSummary := func(_ []byte, summaryBytes []byte) error {
summaryReader := bytes.NewReader(summaryBytes)
chanSummary, err := deserializeCloseChannelSummary(
summaryReader,
)
if err != nil {
return err
}
// Skip pending channels
if chanSummary.IsPending {
return nil
}
chanSummaries = append(chanSummaries, chanSummary)
return nil
}
if err := closeBucket.ForEach(appendSummary); err != nil {
return err
}
// Now we will load the forwarding packages bucket and delete all the
// nested buckets whose source matches the ShortChanID found in the
// closed channel summaries.
fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey)
if fwdPkgBkt == nil {
return nil
}
// Iterate over all closed channels and remove their forwarding packages.
for _, summary := range chanSummaries {
sourceBytes := MakeLogKey(summary.ShortChanID.ToUint64())
// First, we will try to find the corresponding bucket. If there
// is not a nested bucket matching the ShortChanID, we will skip
// it.
if fwdPkgBkt.NestedReadBucket(sourceBytes[:]) == nil {
continue
}
// Otherwise, wipe all the forwarding packages.
if err := fwdPkgBkt.DeleteNestedBucket(
sourceBytes[:],
); err != nil {
return err
}
}
log.Infof("Deletion of forwarding packages of closed channels " +
"complete! DB compaction is recommended to free up the" +
"disk space.")
return nil
}
// deserializeCloseChannelSummary will decode a CloseChannelSummary with no
// optional fields.
func deserializeCloseChannelSummary(
r io.Reader) (*mig.ChannelCloseSummary, error) {
c := &mig.ChannelCloseSummary{}
err := mig.ReadElements(
r,
&c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID,
&c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance,
&c.TimeLockedBalance, &c.CloseType, &c.IsPending,
)
if err != nil {
return nil, err
}
return c, nil
}
// MakeLogKey converts a uint64 into an 8-byte big-endian array.
func MakeLogKey(updateNum uint64) [8]byte {
var key [8]byte
binary.BigEndian.PutUint64(key[:], updateNum)
return key
}
package migration25
import (
"bytes"
"fmt"
"io"
"strconv"
"strings"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
mig24 "github.com/lightningnetwork/lnd/channeldb/migration24"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// A tlv type definition used to serialize and deserialize a KeyLocator
// from the database.
keyLocType tlv.Type = 1
)
var (
// chanCommitmentKey can be accessed within the sub-bucket for a
// particular channel. This key stores the up to date commitment state
// for a particular channel party. Appending a 0 to the end of this key
// indicates it's the commitment for the local party, and appending a 1
// to the end of this key indicates it's the commitment for the remote
// party.
chanCommitmentKey = []byte("chan-commitment-key")
// revocationLogBucketLegacy is the legacy bucket where we store the
// revocation log in old format.
revocationLogBucketLegacy = []byte("revocation-log-key")
// localUpfrontShutdownKey can be accessed within the bucket for a
// channel (identified by its chanPoint). This key stores an optional
// upfront shutdown script for the local peer.
localUpfrontShutdownKey = []byte("local-upfront-shutdown-key")
// remoteUpfrontShutdownKey can be accessed within the bucket for a
// channel (identified by its chanPoint). This key stores an optional
// upfront shutdown script for the remote peer.
remoteUpfrontShutdownKey = []byte("remote-upfront-shutdown-key")
// lastWasRevokeKey is a key that stores true when the last update we
// sent was a revocation and false when it was a commitment signature.
// This is nil in the case of new channels with no updates exchanged.
lastWasRevokeKey = []byte("last-was-revoke")
// ErrNoChanInfoFound is returned when a particular channel does not
// have any channel state.
ErrNoChanInfoFound = fmt.Errorf("no chan info found")
// ErrNoPastDeltas is returned when the channel delta bucket hasn't been
// created.
ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas")
// ErrLogEntryNotFound is returned when we cannot find a log entry at
// the height requested in the revocation log.
ErrLogEntryNotFound = fmt.Errorf("log entry not found")
// ErrNoCommitmentsFound is returned when a channel has not set
// commitment states.
ErrNoCommitmentsFound = fmt.Errorf("no commitments found")
)
// ChannelType is an enum-like type that describes one of several possible
// channel types. Each open channel is associated with a particular type as the
// channel type may determine how higher level operations are conducted such as
// fee negotiation, channel closing, the format of HTLCs, etc. Structure-wise,
// a ChannelType is a bit field, with each bit denoting a modification from the
// base channel type of single funder.
type ChannelType uint8
const (
// NOTE: iota isn't used here because this enum needs to be stable
// long-term, as it will be persisted to the database.
// SingleFunderBit represents a channel wherein one party solely funds
// the entire capacity of the channel.
SingleFunderBit ChannelType = 0
// DualFunderBit represents a channel wherein both parties contribute
// funds towards the total capacity of the channel. The channel may be
// funded symmetrically or asymmetrically.
DualFunderBit ChannelType = 1 << 0
// SingleFunderTweaklessBit is similar to the basic SingleFunder channel
// type, but it omits the tweak for one's key in the commitment
// transaction of the remote party.
SingleFunderTweaklessBit ChannelType = 1 << 1
// NoFundingTxBit denotes if we have the funding transaction locally on
// disk. This bit may be on if the funding transaction was crafted by a
// wallet external to the primary daemon.
NoFundingTxBit ChannelType = 1 << 2
// AnchorOutputsBit indicates that the channel makes use of anchor
// outputs to bump the commitment transaction's effective feerate. This
// channel type also uses a delayed to_remote output script.
AnchorOutputsBit ChannelType = 1 << 3
// FrozenBit indicates that the channel is a frozen channel, meaning
// that only the responder can decide to cooperatively close the
// channel.
FrozenBit ChannelType = 1 << 4
// ZeroHtlcTxFeeBit indicates that the channel should use zero-fee
// second-level HTLC transactions.
ZeroHtlcTxFeeBit ChannelType = 1 << 5
// LeaseExpirationBit indicates that the channel has been leased for a
// period of time, constraining every output that pays to the channel
// initiator with an additional CLTV of the lease maturity.
LeaseExpirationBit ChannelType = 1 << 6
)
// IsSingleFunder returns true if the channel type is one of the known
// single funder variants.
func (c ChannelType) IsSingleFunder() bool {
return c&DualFunderBit == 0
}
// IsDualFunder returns true if the ChannelType has the DualFunderBit set.
func (c ChannelType) IsDualFunder() bool {
return c&DualFunderBit == DualFunderBit
}
// IsTweakless returns true if the target channel uses a commitment that
// doesn't tweak the key for the remote party.
func (c ChannelType) IsTweakless() bool {
return c&SingleFunderTweaklessBit == SingleFunderTweaklessBit
}
// HasFundingTx returns true if this channel type is one that has a funding
// transaction stored locally.
func (c ChannelType) HasFundingTx() bool {
return c&NoFundingTxBit == 0
}
// HasAnchors returns true if this channel type has anchor outputs on its
// commitment.
func (c ChannelType) HasAnchors() bool {
return c&AnchorOutputsBit == AnchorOutputsBit
}
// ZeroHtlcTxFee returns true if this channel type uses second-level HTLC
// transactions signed with zero-fee.
func (c ChannelType) ZeroHtlcTxFee() bool {
return c&ZeroHtlcTxFeeBit == ZeroHtlcTxFeeBit
}
// IsFrozen returns true if the channel is considered to be "frozen". A frozen
// channel means that only the responder can initiate a cooperative channel
// closure.
func (c ChannelType) IsFrozen() bool {
return c&FrozenBit == FrozenBit
}
// HasLeaseExpiration returns true if the channel originated from a lease.
func (c ChannelType) HasLeaseExpiration() bool {
return c&LeaseExpirationBit == LeaseExpirationBit
}
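// A sketch of how the bit field composes (illustrative values only):
//
//	chanType := SingleFunderTweaklessBit | AnchorOutputsBit
//	chanType.IsSingleFunder() // true: DualFunderBit is unset
//	chanType.HasAnchors()     // true: AnchorOutputsBit is set
//	chanType.HasFundingTx()   // true: NoFundingTxBit is unset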
// ChannelStatus is a bit vector used to indicate whether an OpenChannel is in
// the default usable state, or a state where it shouldn't be used.
type ChannelStatus uint8
var (
// ChanStatusDefault is the normal state of an open channel.
ChanStatusDefault ChannelStatus
// ChanStatusBorked indicates that the channel has entered an
// irreconcilable state, triggered by a state desynchronization or
// channel breach. Channels in this state should never be added to the
// htlc switch.
ChanStatusBorked ChannelStatus = 1
// ChanStatusCommitBroadcasted indicates that a commitment for this
// channel has been broadcasted.
ChanStatusCommitBroadcasted ChannelStatus = 1 << 1
// ChanStatusLocalDataLoss indicates that we have lost channel state
// for this channel, and broadcasting our latest commitment might be
// considered a breach.
//
// TODO(halseth): actually enforce that we are not force closing such a
// channel.
ChanStatusLocalDataLoss ChannelStatus = 1 << 2
// ChanStatusRestored is a status flag that signals that the channel
// has been restored, and doesn't have all the fields a typical channel
// will have.
ChanStatusRestored ChannelStatus = 1 << 3
// ChanStatusCoopBroadcasted indicates that a cooperative close for
// this channel has been broadcasted. Older cooperatively closed
// channels will only have this status set. Newer ones will also have
// close initiator information stored using the local/remote initiator
// status. This status is set in conjunction with the initiator status
// so that we do not need to check multiple channel statues for
// cooperative closes.
ChanStatusCoopBroadcasted ChannelStatus = 1 << 4
// ChanStatusLocalCloseInitiator indicates that we initiated closing
// the channel.
ChanStatusLocalCloseInitiator ChannelStatus = 1 << 5
// ChanStatusRemoteCloseInitiator indicates that the remote node
// initiated closing the channel.
ChanStatusRemoteCloseInitiator ChannelStatus = 1 << 6
)
// chanStatusStrings maps a ChannelStatus to a human friendly string that
// describes that status.
var chanStatusStrings = map[ChannelStatus]string{
ChanStatusDefault: "ChanStatusDefault",
ChanStatusBorked: "ChanStatusBorked",
ChanStatusCommitBroadcasted: "ChanStatusCommitBroadcasted",
ChanStatusLocalDataLoss: "ChanStatusLocalDataLoss",
ChanStatusRestored: "ChanStatusRestored",
ChanStatusCoopBroadcasted: "ChanStatusCoopBroadcasted",
ChanStatusLocalCloseInitiator: "ChanStatusLocalCloseInitiator",
ChanStatusRemoteCloseInitiator: "ChanStatusRemoteCloseInitiator",
}
// orderedChanStatusFlags is an in-order list of all the channel status
// flags.
var orderedChanStatusFlags = []ChannelStatus{
ChanStatusBorked,
ChanStatusCommitBroadcasted,
ChanStatusLocalDataLoss,
ChanStatusRestored,
ChanStatusCoopBroadcasted,
ChanStatusLocalCloseInitiator,
ChanStatusRemoteCloseInitiator,
}
// String returns a human-readable representation of the ChannelStatus.
func (c ChannelStatus) String() string {
// If no flags are set, then this is the default case.
if c == ChanStatusDefault {
return chanStatusStrings[ChanStatusDefault]
}
// Add individual bit flags.
statusStr := ""
for _, flag := range orderedChanStatusFlags {
if c&flag == flag {
statusStr += chanStatusStrings[flag] + "|"
c -= flag
}
}
// Remove anything to the right of the final bar, including it as well.
statusStr = strings.TrimRight(statusStr, "|")
// Add any remaining flags which aren't accounted for as hex.
if c != 0 {
statusStr += "|0x" + strconv.FormatUint(uint64(c), 16)
}
// If this was purely an unknown flag, then remove the extra bar at the
// start of the string.
statusStr = strings.TrimLeft(statusStr, "|")
return statusStr
}
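// For example (illustrative), a borked channel whose commitment has also
// been broadcast renders as:
//
//	status := ChanStatusBorked | ChanStatusCommitBroadcasted
//	status.String() // "ChanStatusBorked|ChanStatusCommitBroadcasted"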
// OpenChannel embeds a mig.OpenChannel with the extra up-to-date fields.
//
// NOTE: doesn't have the Packager field as it's not used in the current
// migration.
type OpenChannel struct {
mig.OpenChannel
// ChanType denotes which type of channel this is.
ChanType ChannelType
// ChanStatus is the current status of this channel. If it is not in
// the state Default, it should not be used for forwarding payments.
//
// NOTE: In `channeldb.OpenChannel`, this field is private. We choose
// to export this private field such that following migrations can
// access this field directly.
ChanStatus ChannelStatus
// InitialLocalBalance is the balance we have during the channel
// opening. When we are not the initiator, this value represents the
// push amount.
InitialLocalBalance lnwire.MilliSatoshi
// InitialRemoteBalance is the balance they have during the channel
// opening.
InitialRemoteBalance lnwire.MilliSatoshi
// LocalShutdownScript is set to a pre-set script if the channel was
// opened by the local node with option_upfront_shutdown_script set. If
// the option was not set, the field is empty.
LocalShutdownScript lnwire.DeliveryAddress
// RemoteShutdownScript is set to a pre-set script if the channel was
// opened by the remote node with option_upfront_shutdown_script set.
// If the option was not set, the field is empty.
RemoteShutdownScript lnwire.DeliveryAddress
// ThawHeight is the height when a frozen channel once again becomes a
// normal channel. If this is zero, then there're no restrictions on
// this channel. If the value is lower than 500,000, then it's
// interpreted as a relative height, or an absolute height otherwise.
ThawHeight uint32
// LastWasRevoke is a boolean that determines if the last update we
// sent was a revocation (true) or a commitment signature (false).
LastWasRevoke bool
// RevocationKeyLocator stores the KeyLocator information that we will
// need to derive the shachain root for this channel. This allows us to
// have private key isolation from lnd.
RevocationKeyLocator keychain.KeyLocator
}
// hasChanStatus returns true if the given status flag is set on the
// channel, or, for ChanStatusDefault, if no flags are set at all.
func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool {
// Special case ChanStatusDefault since it isn't actually a flag, but a
// particular combination (or lack thereof) of flags.
if status == ChanStatusDefault {
return c.ChanStatus == ChanStatusDefault
}
return c.ChanStatus&status == status
}
// FundingTxPresent returns true if we expect the funding transaction to
// be found on disk or already populated within the passed open channel
// struct.
func (c *OpenChannel) FundingTxPresent() bool {
chanType := c.ChanType
return chanType.IsSingleFunder() && chanType.HasFundingTx() &&
c.IsInitiator &&
!c.hasChanStatus(ChanStatusRestored)
}
// fetchChanInfo deserializes the channel info based on the legacy boolean.
func fetchChanInfo(chanBucket kvdb.RBucket, c *OpenChannel, legacy bool) error {
infoBytes := chanBucket.Get(chanInfoKey)
if infoBytes == nil {
return ErrNoChanInfoFound
}
r := bytes.NewReader(infoBytes)
var (
chanType mig.ChannelType
chanStatus mig.ChannelStatus
)
if err := mig.ReadElements(r,
&chanType, &c.ChainHash, &c.FundingOutpoint,
&c.ShortChannelID, &c.IsPending, &c.IsInitiator,
&chanStatus, &c.FundingBroadcastHeight,
&c.NumConfsRequired, &c.ChannelFlags,
&c.IdentityPub, &c.Capacity, &c.TotalMSatSent,
&c.TotalMSatReceived,
); err != nil {
return err
}
c.ChanType = ChannelType(chanType)
c.ChanStatus = ChannelStatus(chanStatus)
// If this is not the legacy format, we need to read the extra two new
// fields.
if !legacy {
if err := mig.ReadElements(r,
&c.InitialLocalBalance, &c.InitialRemoteBalance,
); err != nil {
return err
}
}
// For single funder channels that we initiated and for which we have
// the funding transaction, read the funding txn.
if c.FundingTxPresent() {
if err := mig.ReadElement(r, &c.FundingTxn); err != nil {
return err
}
}
if err := mig.ReadChanConfig(r, &c.LocalChanCfg); err != nil {
return err
}
if err := mig.ReadChanConfig(r, &c.RemoteChanCfg); err != nil {
return err
}
// Retrieve the boolean stored under lastWasRevokeKey.
lastWasRevokeBytes := chanBucket.Get(lastWasRevokeKey)
if lastWasRevokeBytes == nil {
// If nothing has been stored under this key, we store false in
// the OpenChannel struct.
c.LastWasRevoke = false
} else {
// Otherwise, read the value into the LastWasRevoke field.
revokeReader := bytes.NewReader(lastWasRevokeBytes)
err := mig.ReadElements(revokeReader, &c.LastWasRevoke)
if err != nil {
return err
}
}
keyLocRecord := MakeKeyLocRecord(keyLocType, &c.RevocationKeyLocator)
tlvStream, err := tlv.NewStream(keyLocRecord)
if err != nil {
return err
}
if err := tlvStream.Decode(r); err != nil {
return err
}
// Finally, read the optional shutdown scripts.
if err := GetOptionalUpfrontShutdownScript(
chanBucket, localUpfrontShutdownKey, &c.LocalShutdownScript,
); err != nil {
return err
}
return GetOptionalUpfrontShutdownScript(
chanBucket, remoteUpfrontShutdownKey, &c.RemoteShutdownScript,
)
}
// putChanInfo serializes the channel info based on the legacy boolean
// and saves it to disk.
func putChanInfo(chanBucket kvdb.RwBucket, c *OpenChannel, legacy bool) error {
var w bytes.Buffer
if err := mig.WriteElements(&w,
mig.ChannelType(c.ChanType), c.ChainHash, c.FundingOutpoint,
c.ShortChannelID, c.IsPending, c.IsInitiator,
mig.ChannelStatus(c.ChanStatus), c.FundingBroadcastHeight,
c.NumConfsRequired, c.ChannelFlags,
c.IdentityPub, c.Capacity, c.TotalMSatSent,
c.TotalMSatReceived,
); err != nil {
return err
}
// If this is not legacy format, we need to write the extra two fields.
if !legacy {
if err := mig.WriteElements(&w,
c.InitialLocalBalance, c.InitialRemoteBalance,
); err != nil {
return err
}
}
// For single funder channels that we initiated and for which we have
// the funding transaction, write the funding txn.
if c.FundingTxPresent() {
if err := mig.WriteElement(&w, c.FundingTxn); err != nil {
return err
}
}
if err := mig.WriteChanConfig(&w, &c.LocalChanCfg); err != nil {
return err
}
if err := mig.WriteChanConfig(&w, &c.RemoteChanCfg); err != nil {
return err
}
// Write the RevocationKeyLocator as the first entry in a tlv stream.
keyLocRecord := MakeKeyLocRecord(
keyLocType, &c.RevocationKeyLocator,
)
tlvStream, err := tlv.NewStream(keyLocRecord)
if err != nil {
return err
}
if err := tlvStream.Encode(&w); err != nil {
return err
}
if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil {
return err
}
// Finally, add optional shutdown scripts for the local and remote peer
// if they are present.
if err := PutOptionalUpfrontShutdownScript(
chanBucket, localUpfrontShutdownKey, c.LocalShutdownScript,
); err != nil {
return err
}
return PutOptionalUpfrontShutdownScript(
chanBucket, remoteUpfrontShutdownKey, c.RemoteShutdownScript,
)
}
// EKeyLocator is an encoder for keychain.KeyLocator.
func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*keychain.KeyLocator); ok {
err := tlv.EUint32T(w, uint32(v.Family), buf)
if err != nil {
return err
}
return tlv.EUint32T(w, v.Index, buf)
}
return tlv.NewTypeForEncodingErr(val, "keychain.KeyLocator")
}
// DKeyLocator is a decoder for keychain.KeyLocator.
func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*keychain.KeyLocator); ok {
var family uint32
err := tlv.DUint32(r, &family, buf, 4)
if err != nil {
return err
}
v.Family = keychain.KeyFamily(family)
return tlv.DUint32(r, &v.Index, buf, 4)
}
return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8)
}
// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed
// Type and the EKeyLocator and DKeyLocator functions. The size will always be
// 8 as KeyFamily is uint32 and the Index is uint32.
func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record {
return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator)
}
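// A minimal encode/decode sketch (assuming a keychain.KeyLocator value
// named keyLoc):
//
//	var b bytes.Buffer
//	stream, err := tlv.NewStream(MakeKeyLocRecord(keyLocType, &keyLoc))
//	if err != nil {
//		// handle err
//	}
//	if err := stream.Encode(&b); err != nil {
//		// handle err
//	}
//
//	// Decoding into a fresh KeyLocator reverses the operation:
//	var decoded keychain.KeyLocator
//	stream, err = tlv.NewStream(MakeKeyLocRecord(keyLocType, &decoded))
//	if err != nil {
//		// handle err
//	}
//	if err := stream.Decode(&b); err != nil {
//		// handle err
//	}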
// PutOptionalUpfrontShutdownScript adds a shutdown script under the key
// provided if it has a non-zero length.
func PutOptionalUpfrontShutdownScript(chanBucket kvdb.RwBucket, key []byte,
script []byte) error {
// If the script is empty, we do not need to add anything.
if len(script) == 0 {
return nil
}
var w bytes.Buffer
if err := mig.WriteElement(&w, script); err != nil {
return err
}
return chanBucket.Put(key, w.Bytes())
}
// GetOptionalUpfrontShutdownScript reads the shutdown script stored under the
// key provided if it is present. Upfront shutdown scripts are optional, so the
// function returns with no error if the key is not present.
func GetOptionalUpfrontShutdownScript(chanBucket kvdb.RBucket, key []byte,
script *lnwire.DeliveryAddress) error {
// Return early if the key is not present, as a shutdown script was not
// set.
bs := chanBucket.Get(key)
if bs == nil {
return nil
}
var tempScript []byte
r := bytes.NewReader(bs)
if err := mig.ReadElement(r, &tempScript); err != nil {
return err
}
*script = tempScript
return nil
}
// FetchChanCommitments fetches both the local and remote commitments. This
// function is exported so it can be used by later migrations.
func FetchChanCommitments(chanBucket kvdb.RBucket, channel *OpenChannel) error {
var err error
// If this is a restored channel, then we don't have any commitments to
// read.
if channel.hasChanStatus(ChanStatusRestored) {
return nil
}
channel.LocalCommitment, err = FetchChanCommitment(chanBucket, true)
if err != nil {
return err
}
channel.RemoteCommitment, err = FetchChanCommitment(chanBucket, false)
if err != nil {
return err
}
return nil
}
// FetchChanCommitment fetches a channel commitment. This function is exported
// so it can be used by later migrations.
func FetchChanCommitment(chanBucket kvdb.RBucket,
local bool) (mig.ChannelCommitment, error) {
commitKey := chanCommitmentKey
if local {
commitKey = append(commitKey, byte(0x00))
} else {
commitKey = append(commitKey, byte(0x01))
}
commitBytes := chanBucket.Get(commitKey)
if commitBytes == nil {
return mig.ChannelCommitment{}, ErrNoCommitmentsFound
}
r := bytes.NewReader(commitBytes)
return mig.DeserializeChanCommit(r)
}
// PutChanCommitment writes a channel commitment to the channel bucket,
// keyed by whether it is the local or remote commitment. This function
// is exported so it can be used by later migrations.
func PutChanCommitment(chanBucket kvdb.RwBucket, c *mig.ChannelCommitment,
local bool) error {
commitKey := chanCommitmentKey
if local {
commitKey = append(commitKey, byte(0x00))
} else {
commitKey = append(commitKey, byte(0x01))
}
var b bytes.Buffer
if err := mig.SerializeChanCommit(&b, c); err != nil {
return err
}
return chanBucket.Put(commitKey, b.Bytes())
}
// PutChanCommitments writes both the local and remote commitments to the
// channel bucket. This function is exported so it can be used by later
// migrations.
func PutChanCommitments(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
// If this is a restored channel, then we don't have any commitments to
// write.
if channel.hasChanStatus(ChanStatusRestored) {
return nil
}
err := PutChanCommitment(
chanBucket, &channel.LocalCommitment, true,
)
if err != nil {
return err
}
return PutChanCommitment(
chanBucket, &channel.RemoteCommitment, false,
)
}
// balancesAtHeight returns the local and remote balances on our commitment
// transactions as of a given height. This function is not exported as it's
// deprecated.
//
// NOTE: these are our balances *after* subtracting the commitment fee and
// anchor outputs.
func (c *OpenChannel) balancesAtHeight(chanBucket kvdb.RBucket,
height uint64) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
// If our current local commitment is at the desired height, we can
// return our current balances.
if c.LocalCommitment.CommitHeight == height {
return c.LocalCommitment.LocalBalance,
c.LocalCommitment.RemoteBalance, nil
}
// If our current remote commit is at the desired height, we can return
// the current balances.
if c.RemoteCommitment.CommitHeight == height {
return c.RemoteCommitment.LocalBalance,
c.RemoteCommitment.RemoteBalance, nil
}
// If we are not currently on the height requested, we need to look up
// the previous height to obtain our balances at the given height.
commit, err := c.FindPreviousStateLegacy(chanBucket, height)
if err != nil {
return 0, 0, err
}
return commit.LocalBalance, commit.RemoteBalance, nil
}
// FindPreviousStateLegacy scans through the append-only log in an attempt to
// recover the previous channel state indicated by the update number. This
// method is intended to be used for obtaining the relevant data needed to
// claim all funds rightfully spendable in the case of an on-chain broadcast of
// the commitment transaction.
func (c *OpenChannel) FindPreviousStateLegacy(chanBucket kvdb.RBucket,
updateNum uint64) (*mig.ChannelCommitment, error) {
c.RLock()
defer c.RUnlock()
logBucket := chanBucket.NestedReadBucket(revocationLogBucketLegacy)
if logBucket == nil {
return nil, ErrNoPastDeltas
}
commit, err := fetchChannelLogEntry(logBucket, updateNum)
if err != nil {
return nil, err
}
return &commit, nil
}
func fetchChannelLogEntry(log kvdb.RBucket,
updateNum uint64) (mig.ChannelCommitment, error) {
logEntrykey := mig24.MakeLogKey(updateNum)
commitBytes := log.Get(logEntrykey[:])
if commitBytes == nil {
return mig.ChannelCommitment{}, ErrLogEntryNotFound
}
commitReader := bytes.NewReader(commitBytes)
return mig.DeserializeChanCommit(commitReader)
}
// CreateChanBucket creates (or fetches, if it already exists) the nested
// bucket for the given channel, keyed by nodePub, chain hash, and channel
// point.
func CreateChanBucket(tx kvdb.RwTx, c *OpenChannel) (kvdb.RwBucket, error) {
// First fetch the top level bucket which stores all data related to
// current, active channels.
openChanBucket, err := tx.CreateTopLevelBucket(openChannelBucket)
if err != nil {
return nil, err
}
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
nodePub := c.IdentityPub.SerializeCompressed()
nodeChanBucket, err := openChanBucket.CreateBucketIfNotExists(nodePub)
if err != nil {
return nil, err
}
// We'll then recurse down an additional layer in order to fetch the
// bucket for this particular chain.
chainBucket, err := nodeChanBucket.CreateBucketIfNotExists(
c.ChainHash[:],
)
if err != nil {
return nil, err
}
var chanPointBuf bytes.Buffer
err = mig.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint)
if err != nil {
return nil, err
}
// With the bucket for the node fetched, we can now go down another
// level, creating the bucket for this channel itself.
return chainBucket.CreateBucketIfNotExists(chanPointBuf.Bytes())
}
package migration25
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration25
import (
"bytes"
"fmt"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// openChanBucket stores all the currently open channels. This bucket
// has a second, nested bucket which is keyed by a node's ID. Within
// that node ID bucket, all attributes required to track, update, and
// close a channel are stored.
openChannelBucket = []byte("open-chan-bucket")
// chanInfoKey can be accessed within the bucket for a channel
// (identified by its chanPoint). This key stores all the static
// information for a channel which is decided at the end of the
// funding flow.
chanInfoKey = []byte("chan-info-key")
// ErrNoChanDBExists is returned when a channel bucket hasn't been
// created.
ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created")
// ErrNoActiveChannels is returned when there is no active (open)
// channels within the database.
ErrNoActiveChannels = fmt.Errorf("no active channels exist")
// ErrChannelNotFound is returned when we attempt to locate a channel
// for a specific chain, but it is not found.
ErrChannelNotFound = fmt.Errorf("channel not found")
)
// MigrateInitialBalances patches the two new fields, InitialLocalBalance and
// InitialRemoteBalance, for all the open channels. It does so by reading the
// revocation log at height 0 to learn the initial balances and then updates
// the channel's info.
// The channel info is saved in the nested bucket which is accessible via
// nodePub:chainHash:chanPoint. If any of the sub-buckets turns out to be nil,
// we will log the error and continue to process the rest.
func MigrateInitialBalances(tx kvdb.RwTx) error {
log.Infof("Migrating initial local and remote balances...")
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
// If no bucket is found, we can exit early.
if openChanBucket == nil {
return nil
}
// Read a list of open channels.
channels, err := findOpenChannels(openChanBucket)
if err != nil {
return err
}
// Migrate the balances.
for _, c := range channels {
if err := migrateBalances(tx, c); err != nil {
return err
}
}
return nil
}
// findOpenChannels finds all open channels.
func findOpenChannels(openChanBucket kvdb.RBucket) ([]*OpenChannel, error) {
channels := []*OpenChannel{}
// readChannel is a helper closure that reads the channel info from the
// channel bucket.
readChannel := func(chainBucket kvdb.RBucket, cp []byte) error {
c := &OpenChannel{}
// Read the sub-bucket level 3.
chanBucket := chainBucket.NestedReadBucket(
cp,
)
if chanBucket == nil {
log.Errorf("unable to read bucket for chanPoint=%x", cp)
return nil
}
// Get the old channel info.
if err := fetchChanInfo(chanBucket, c, true); err != nil {
return fmt.Errorf("unable to fetch chan info: %w", err)
}
// Fetch the channel commitments, which are useful for freshly
// open channels as they don't have any revocation logs and
// their current commitments reflect the initial balances.
if err := FetchChanCommitments(chanBucket, c); err != nil {
return fmt.Errorf("unable to fetch chan commits: %w",
err)
}
channels = append(channels, c)
return nil
}
// Iterate the root bucket.
err := openChanBucket.ForEach(func(nodePub, v []byte) error {
// Ensure that this is a key the same size as a pubkey, and
// also that it leads directly to a bucket.
if len(nodePub) != 33 || v != nil {
return nil
}
// Read the sub-bucket level 1.
nodeChanBucket := openChanBucket.NestedReadBucket(nodePub)
if nodeChanBucket == nil {
log.Errorf("no bucket for node %x", nodePub)
return nil
}
// Iterate the bucket.
return nodeChanBucket.ForEach(func(chainHash, _ []byte) error {
// Read the sub-bucket level 2.
chainBucket := nodeChanBucket.NestedReadBucket(
chainHash,
)
if chainBucket == nil {
log.Errorf("unable to read bucket for chain=%x",
chainHash)
return nil
}
// Iterate the bucket.
return chainBucket.ForEach(func(cp, _ []byte) error {
return readChannel(chainBucket, cp)
})
})
})
if err != nil {
return nil, err
}
return channels, nil
}
// migrateBalances queries the revocation log at height 0 to find the initial
// balances and save them to the channel info.
func migrateBalances(tx kvdb.RwTx, c *OpenChannel) error {
// Get the bucket.
chanBucket, err := FetchChanBucket(tx, c)
if err != nil {
return err
}
// Get the initial balances.
localAmt, remoteAmt, err := c.balancesAtHeight(chanBucket, 0)
if err != nil {
return fmt.Errorf("unable to get initial balances: %w", err)
}
c.InitialLocalBalance = localAmt
c.InitialRemoteBalance = remoteAmt
// Update the channel info.
if err := putChanInfo(chanBucket, c, false); err != nil {
return fmt.Errorf("unable to put chan info: %w", err)
}
return nil
}
// FetchChanBucket is a helper function that returns the bucket where a
// channel's data resides in given: the public key for the node, the outpoint,
// and the chainhash that the channel resides on.
func FetchChanBucket(tx kvdb.RwTx, c *OpenChannel) (kvdb.RwBucket, error) {
// First fetch the top level bucket which stores all data related to
// current, active channels.
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
if openChanBucket == nil {
return nil, ErrNoChanDBExists
}
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
nodePub := c.IdentityPub.SerializeCompressed()
nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub)
if nodeChanBucket == nil {
return nil, ErrNoActiveChannels
}
// We'll then recurse down an additional layer in order to fetch the
// bucket for this particular chain.
chainBucket := nodeChanBucket.NestedReadWriteBucket(c.ChainHash[:])
if chainBucket == nil {
return nil, ErrNoActiveChannels
}
// With the bucket for the node and chain fetched, we can now go down
// another level, for this channel itself.
var chanPointBuf bytes.Buffer
err := mig.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint)
if err != nil {
return nil, err
}
chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes())
if chanBucket == nil {
return nil, ErrChannelNotFound
}
return chanBucket, nil
}
package migration26
import (
"bytes"
"fmt"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
mig25 "github.com/lightningnetwork/lnd/channeldb/migration25"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// A tlv type definition used to serialize and deserialize a KeyLocator
// from the database.
keyLocType tlv.Type = 1
// A tlv type used to serialize and deserialize the
// `InitialLocalBalance` field.
initialLocalBalanceType tlv.Type = 2
// A tlv type used to serialize and deserialize the
// `InitialRemoteBalance` field.
initialRemoteBalanceType tlv.Type = 3
)
var (
// chanInfoKey can be accessed within the bucket for a channel
// (identified by its chanPoint). This key stores all the static
// information for a channel which is decided at the end of the
// funding flow.
chanInfoKey = []byte("chan-info-key")
// localUpfrontShutdownKey can be accessed within the bucket for a
// channel (identified by its chanPoint). This key stores an optional
// upfront shutdown script for the local peer.
localUpfrontShutdownKey = []byte("local-upfront-shutdown-key")
// remoteUpfrontShutdownKey can be accessed within the bucket for a
// channel (identified by its chanPoint). This key stores an optional
// upfront shutdown script for the remote peer.
remoteUpfrontShutdownKey = []byte("remote-upfront-shutdown-key")
// lastWasRevokeKey is a key that stores true when the last update we
// sent was a revocation and false when it was a commitment signature.
// This is nil in the case of new channels with no updates exchanged.
lastWasRevokeKey = []byte("last-was-revoke")
// ErrNoChanInfoFound is returned when a particular channel does not
// have any channel state.
ErrNoChanInfoFound = fmt.Errorf("no chan info found")
// ErrNoPastDeltas is returned when the channel delta bucket hasn't been
// created.
ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas")
// ErrLogEntryNotFound is returned when we cannot find a log entry at
// the height requested in the revocation log.
ErrLogEntryNotFound = fmt.Errorf("log entry not found")
// ErrNoCommitmentsFound is returned when a channel has not set
// commitment states.
ErrNoCommitmentsFound = fmt.Errorf("no commitments found")
)
// OpenChannel embeds a mig25.OpenChannel with the extra up-to-date
// serialization and deserialization methods.
//
// NOTE: doesn't have the Packager field as it's not used in the current
// migration.
type OpenChannel struct {
mig25.OpenChannel
}
// FetchChanInfo deserializes the channel info based on the legacy boolean.
// In the legacy format (as written by migration25), the fields
// `InitialLocalBalance` and `InitialRemoteBalance` are directly encoded as
// bytes. In the new format, they are put inside a tlv stream.
func FetchChanInfo(chanBucket kvdb.RBucket, c *OpenChannel, legacy bool) error {
infoBytes := chanBucket.Get(chanInfoKey)
if infoBytes == nil {
return ErrNoChanInfoFound
}
r := bytes.NewReader(infoBytes)
var (
chanType mig.ChannelType
chanStatus mig.ChannelStatus
)
if err := mig.ReadElements(r,
&chanType, &c.ChainHash, &c.FundingOutpoint,
&c.ShortChannelID, &c.IsPending, &c.IsInitiator,
&chanStatus, &c.FundingBroadcastHeight,
&c.NumConfsRequired, &c.ChannelFlags,
&c.IdentityPub, &c.Capacity, &c.TotalMSatSent,
&c.TotalMSatReceived,
); err != nil {
return err
}
c.ChanType = mig25.ChannelType(chanType)
c.ChanStatus = mig25.ChannelStatus(chanStatus)
// If this is the legacy format, we need to read the extra two new
// fields.
if legacy {
if err := mig.ReadElements(r,
&c.InitialLocalBalance, &c.InitialRemoteBalance,
); err != nil {
return err
}
}
// For single funder channels that we initiated and have the funding
// transaction for, read the funding txn.
if c.FundingTxPresent() {
if err := mig.ReadElement(r, &c.FundingTxn); err != nil {
return err
}
}
if err := mig.ReadChanConfig(r, &c.LocalChanCfg); err != nil {
return err
}
if err := mig.ReadChanConfig(r, &c.RemoteChanCfg); err != nil {
return err
}
// Retrieve the boolean stored under lastWasRevokeKey.
lastWasRevokeBytes := chanBucket.Get(lastWasRevokeKey)
if lastWasRevokeBytes == nil {
// If nothing has been stored under this key, we store false in
// the OpenChannel struct.
c.LastWasRevoke = false
} else {
// Otherwise, read the value into the LastWasRevoke field.
revokeReader := bytes.NewReader(lastWasRevokeBytes)
err := mig.ReadElements(revokeReader, &c.LastWasRevoke)
if err != nil {
return err
}
}
// Make the tlv stream based on the legacy param.
var (
ts *tlv.Stream
err error
localBalance uint64
remoteBalance uint64
)
keyLocRecord := mig25.MakeKeyLocRecord(
keyLocType, &c.RevocationKeyLocator,
)
// If it's legacy, create the stream with a single tlv record.
if legacy {
ts, err = tlv.NewStream(keyLocRecord)
} else {
// Otherwise, for the new format, we will encode the balance
// fields in the tlv stream too.
ts, err = tlv.NewStream(
keyLocRecord,
tlv.MakePrimitiveRecord(
initialLocalBalanceType, &localBalance,
),
tlv.MakePrimitiveRecord(
initialRemoteBalanceType, &remoteBalance,
),
)
}
if err != nil {
return err
}
if err := ts.Decode(r); err != nil {
return err
}
// For the new format, attach the balance fields.
if !legacy {
c.InitialLocalBalance = lnwire.MilliSatoshi(localBalance)
c.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance)
}
// Finally, read the optional shutdown scripts.
if err := mig25.GetOptionalUpfrontShutdownScript(
chanBucket, localUpfrontShutdownKey, &c.LocalShutdownScript,
); err != nil {
return err
}
return mig25.GetOptionalUpfrontShutdownScript(
chanBucket, remoteUpfrontShutdownKey, &c.RemoteShutdownScript,
)
}
// MakeTlvStream creates a tlv stream based on whether we are dealing with the
// legacy format or not. For the legacy format, we have a single record in the
// stream. For the new format, we have the extra balance records.
func MakeTlvStream(c *OpenChannel, legacy bool) (*tlv.Stream, error) {
keyLocRecord := mig25.MakeKeyLocRecord(
keyLocType, &c.RevocationKeyLocator,
)
// If it's legacy, return the stream with a single tlv record.
if legacy {
return tlv.NewStream(keyLocRecord)
}
// Otherwise, for the new format, we will encode the balance fields in
// the tlv stream too.
localBalance := uint64(c.InitialLocalBalance)
remoteBalance := uint64(c.InitialRemoteBalance)
// Create the tlv stream.
return tlv.NewStream(
keyLocRecord,
tlv.MakePrimitiveRecord(
initialLocalBalanceType, &localBalance,
),
tlv.MakePrimitiveRecord(
initialRemoteBalanceType, &remoteBalance,
),
)
}
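// A minimal sketch, not part of the original code, showing how a stream
// built by MakeTlvStream round-trips: Encode writes the records to a writer
// and Decode reads them back into the stream's record pointers. An
// equivalent record set must be used on both sides so all records are
// recognized.
func exampleTlvRoundTrip(c *OpenChannel) error {
ts, err := MakeTlvStream(c, false)
if err != nil {
return err
}
var b bytes.Buffer
if err := ts.Encode(&b); err != nil {
return err
}
return ts.Decode(&b)
}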
// PutChanInfo serializes the channel info based on the legacy boolean. In the
// legacy format (as written by migration25), the fields `InitialLocalBalance`
// and `InitialRemoteBalance` are directly encoded as bytes. In the new format,
// they are put inside a tlv stream.
func PutChanInfo(chanBucket kvdb.RwBucket, c *OpenChannel, legacy bool) error {
var w bytes.Buffer
if err := mig.WriteElements(&w,
mig.ChannelType(c.ChanType), c.ChainHash, c.FundingOutpoint,
c.ShortChannelID, c.IsPending, c.IsInitiator,
mig.ChannelStatus(c.ChanStatus), c.FundingBroadcastHeight,
c.NumConfsRequired, c.ChannelFlags,
c.IdentityPub, c.Capacity, c.TotalMSatSent,
c.TotalMSatReceived,
); err != nil {
return err
}
// If this is the legacy format, we need to write the extra two fields.
if legacy {
if err := mig.WriteElements(&w,
c.InitialLocalBalance, c.InitialRemoteBalance,
); err != nil {
return err
}
}
// For single funder channels that we initiated and have the funding
// transaction for, write the funding txn.
if c.FundingTxPresent() {
if err := mig.WriteElement(&w, c.FundingTxn); err != nil {
return err
}
}
if err := mig.WriteChanConfig(&w, &c.LocalChanCfg); err != nil {
return err
}
if err := mig.WriteChanConfig(&w, &c.RemoteChanCfg); err != nil {
return err
}
// Make the tlv stream based on the legacy param.
tlvStream, err := MakeTlvStream(c, legacy)
if err != nil {
return err
}
if err := tlvStream.Encode(&w); err != nil {
return err
}
if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil {
return err
}
// Finally, add optional shutdown scripts for the local and remote peer
// if they are present.
if err := mig25.PutOptionalUpfrontShutdownScript(
chanBucket, localUpfrontShutdownKey, c.LocalShutdownScript,
); err != nil {
return err
}
return mig25.PutOptionalUpfrontShutdownScript(
chanBucket, remoteUpfrontShutdownKey, c.RemoteShutdownScript,
)
}
package migration26
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration26
import (
"fmt"
mig25 "github.com/lightningnetwork/lnd/channeldb/migration25"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// openChanBucket stores all the currently open channels. This bucket
// has a second, nested bucket which is keyed by a node's ID. Within
// that node ID bucket, all attributes required to track, update, and
// close a channel are stored.
openChannelBucket = []byte("open-chan-bucket")
// ErrNoChanDBExists is returned when a channel bucket hasn't been
// created.
ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created")
// ErrNoActiveChannels is returned when there is no active (open)
// channels within the database.
ErrNoActiveChannels = fmt.Errorf("no active channels exist")
// ErrChannelNotFound is returned when we attempt to locate a channel
// for a specific chain, but it is not found.
ErrChannelNotFound = fmt.Errorf("channel not found")
)
// MigrateBalancesToTlvRecords migrates the balance fields into tlv records. It
// does so by first reading a list of open channels, then rewriting the channel
// info with the updated tlv stream.
func MigrateBalancesToTlvRecords(tx kvdb.RwTx) error {
log.Infof("Migrating local and remote balances into tlv records...")
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
// If no bucket is found, we can exit early.
if openChanBucket == nil {
return nil
}
// Read a list of open channels.
channels, err := findOpenChannels(openChanBucket)
if err != nil {
return err
}
// Migrate the balances.
for _, c := range channels {
if err := migrateBalances(tx, c); err != nil {
return err
}
}
return nil
}
// findOpenChannels finds all open channels.
func findOpenChannels(openChanBucket kvdb.RBucket) ([]*OpenChannel, error) {
channels := []*OpenChannel{}
// readChannel is a helper closure that reads the channel info from the
// channel bucket.
readChannel := func(chainBucket kvdb.RBucket, cp []byte) error {
c := &OpenChannel{}
// Read the sub-bucket level 3.
chanBucket := chainBucket.NestedReadBucket(
cp,
)
if chanBucket == nil {
log.Errorf("unable to read bucket for chanPoint=%x", cp)
return nil
}
// Get the old channel info.
if err := FetchChanInfo(chanBucket, c, true); err != nil {
return fmt.Errorf("unable to fetch chan info: %w", err)
}
channels = append(channels, c)
return nil
}
// Iterate the root bucket.
err := openChanBucket.ForEach(func(nodePub, v []byte) error {
// Ensure that this is a key the same size as a pubkey, and
// also that it leads directly to a bucket.
if len(nodePub) != 33 || v != nil {
return nil
}
// Read the sub-bucket level 1.
nodeChanBucket := openChanBucket.NestedReadBucket(nodePub)
if nodeChanBucket == nil {
log.Errorf("no bucket for node %x", nodePub)
return nil
}
// Iterate the bucket.
return nodeChanBucket.ForEach(func(chainHash, _ []byte) error {
// Read the sub-bucket level 2.
chainBucket := nodeChanBucket.NestedReadBucket(
chainHash,
)
if chainBucket == nil {
log.Errorf("unable to read bucket for chain=%x",
chainHash)
return nil
}
// Iterate the bucket.
return chainBucket.ForEach(func(cp, _ []byte) error {
return readChannel(chainBucket, cp)
})
})
})
if err != nil {
return nil, err
}
return channels, nil
}
// migrateBalances creates a new tlv stream which adds two more records to hold
// the balances info.
func migrateBalances(tx kvdb.RwTx, c *OpenChannel) error {
// Get the bucket.
chanBucket, err := mig25.FetchChanBucket(tx, &c.OpenChannel)
if err != nil {
return err
}
// Update the channel info. There isn't much to do here as the
// `PutChanInfo` will read the values from `c.InitialLocalBalance` and
// `c.InitialRemoteBalance` then create the new tlv stream as
// requested.
if err := PutChanInfo(chanBucket, c, false); err != nil {
return fmt.Errorf("unable to put chan info: %w", err)
}
return nil
}
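// Usage sketch (an illustration, not from the original source): like the
// other channeldb migrations, MigrateBalancesToTlvRecords is meant to be
// driven inside a single read-write transaction.
func exampleRunMigration(db kvdb.Backend) error {
return kvdb.Update(db, MigrateBalancesToTlvRecords, func() {})
}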
package migration27
import (
"bytes"
"fmt"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
mig25 "github.com/lightningnetwork/lnd/channeldb/migration25"
mig26 "github.com/lightningnetwork/lnd/channeldb/migration26"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// A tlv type definition used to serialize and deserialize a KeyLocator
// from the database.
keyLocType tlv.Type = 1
// A tlv type used to serialize and deserialize the
// `InitialLocalBalance` field.
initialLocalBalanceType tlv.Type = 2
// A tlv type used to serialize and deserialize the
// `InitialRemoteBalance` field.
initialRemoteBalanceType tlv.Type = 3
)
var (
// chanInfoKey can be accessed within the bucket for a channel
// (identified by its chanPoint). This key stores all the static
// information for a channel which is decided at the end of the
// funding flow.
chanInfoKey = []byte("chan-info-key")
// localUpfrontShutdownKey can be accessed within the bucket for a
// channel (identified by its chanPoint). This key stores an optional
// upfront shutdown script for the local peer.
localUpfrontShutdownKey = []byte("local-upfront-shutdown-key")
// remoteUpfrontShutdownKey can be accessed within the bucket for a
// channel (identified by its chanPoint). This key stores an optional
// upfront shutdown script for the remote peer.
remoteUpfrontShutdownKey = []byte("remote-upfront-shutdown-key")
// lastWasRevokeKey is a key that stores true when the last update we
// sent was a revocation and false when it was a commitment signature.
// This is nil in the case of new channels with no updates exchanged.
lastWasRevokeKey = []byte("last-was-revoke")
// ErrNoChanInfoFound is returned when a particular channel does not
// have any channel state.
ErrNoChanInfoFound = fmt.Errorf("no chan info found")
)
// OpenChannel embeds a mig26.OpenChannel with the extra up-to-date
// serialization and deserialization methods.
//
// NOTE: doesn't have the Packager field as it's not used in the current
// migration.
type OpenChannel struct {
mig26.OpenChannel
}
// FetchChanInfo deserializes the channel info based on the legacy boolean.
func FetchChanInfo(chanBucket kvdb.RBucket, c *OpenChannel, legacy bool) error {
infoBytes := chanBucket.Get(chanInfoKey)
if infoBytes == nil {
return ErrNoChanInfoFound
}
r := bytes.NewReader(infoBytes)
var (
chanType mig.ChannelType
chanStatus mig.ChannelStatus
)
if err := mig.ReadElements(r,
&chanType, &c.ChainHash, &c.FundingOutpoint,
&c.ShortChannelID, &c.IsPending, &c.IsInitiator,
&chanStatus, &c.FundingBroadcastHeight,
&c.NumConfsRequired, &c.ChannelFlags,
&c.IdentityPub, &c.Capacity, &c.TotalMSatSent,
&c.TotalMSatReceived,
); err != nil {
return fmt.Errorf("ReadElements got: %w", err)
}
c.ChanType = mig25.ChannelType(chanType)
c.ChanStatus = mig25.ChannelStatus(chanStatus)
// For single funder channels that we initiated and have the funding
// transaction for, read the funding txn.
if c.FundingTxPresent() {
if err := mig.ReadElement(r, &c.FundingTxn); err != nil {
return fmt.Errorf("read FundingTxn got: %w", err)
}
}
if err := mig.ReadChanConfig(r, &c.LocalChanCfg); err != nil {
return fmt.Errorf("read LocalChanCfg got: %w", err)
}
if err := mig.ReadChanConfig(r, &c.RemoteChanCfg); err != nil {
return fmt.Errorf("read RemoteChanCfg got: %w", err)
}
// Retrieve the boolean stored under lastWasRevokeKey.
lastWasRevokeBytes := chanBucket.Get(lastWasRevokeKey)
if lastWasRevokeBytes == nil {
// If nothing has been stored under this key, we store false in
// the OpenChannel struct.
c.LastWasRevoke = false
} else {
// Otherwise, read the value into the LastWasRevoke field.
revokeReader := bytes.NewReader(lastWasRevokeBytes)
err := mig.ReadElements(revokeReader, &c.LastWasRevoke)
if err != nil {
return fmt.Errorf("read LastWasRevoke got: %w", err)
}
}
// Make the tlv stream based on the legacy param.
var (
ts *tlv.Stream
err error
localBalance uint64
remoteBalance uint64
)
keyLocRecord := mig25.MakeKeyLocRecord(
keyLocType, &c.RevocationKeyLocator,
)
// If it's legacy, create the stream with a single tlv record.
if legacy {
ts, err = tlv.NewStream(keyLocRecord)
} else {
// Otherwise, for the new format, we will encode the balance
// fields in the tlv stream too.
ts, err = tlv.NewStream(
keyLocRecord,
tlv.MakePrimitiveRecord(
initialLocalBalanceType, &localBalance,
),
tlv.MakePrimitiveRecord(
initialRemoteBalanceType, &remoteBalance,
),
)
}
if err != nil {
return fmt.Errorf("create tlv stream got: %w", err)
}
if err := ts.Decode(r); err != nil {
return fmt.Errorf("decode tlv stream got: %w", err)
}
// For the new format, attach the balance fields.
if !legacy {
c.InitialLocalBalance = lnwire.MilliSatoshi(localBalance)
c.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance)
}
// Finally, read the optional shutdown scripts.
if err := mig25.GetOptionalUpfrontShutdownScript(
chanBucket, localUpfrontShutdownKey, &c.LocalShutdownScript,
); err != nil {
return fmt.Errorf("local shutdown script got: %w", err)
}
return mig25.GetOptionalUpfrontShutdownScript(
chanBucket, remoteUpfrontShutdownKey, &c.RemoteShutdownScript,
)
}
// PutChanInfo serializes the channel info based on the legacy boolean.
func PutChanInfo(chanBucket kvdb.RwBucket, c *OpenChannel, legacy bool) error {
var w bytes.Buffer
if err := mig.WriteElements(&w,
mig.ChannelType(c.ChanType), c.ChainHash, c.FundingOutpoint,
c.ShortChannelID, c.IsPending, c.IsInitiator,
mig.ChannelStatus(c.ChanStatus), c.FundingBroadcastHeight,
c.NumConfsRequired, c.ChannelFlags,
c.IdentityPub, c.Capacity, c.TotalMSatSent,
c.TotalMSatReceived,
); err != nil {
return err
}
// For single funder channels that we initiated and have the funding
// transaction for, write the funding txn.
if c.FundingTxPresent() {
if err := mig.WriteElement(&w, c.FundingTxn); err != nil {
return err
}
}
if err := mig.WriteChanConfig(&w, &c.LocalChanCfg); err != nil {
return err
}
if err := mig.WriteChanConfig(&w, &c.RemoteChanCfg); err != nil {
return err
}
// Make the tlv stream based on the legacy param.
tlvStream, err := mig26.MakeTlvStream(&c.OpenChannel, legacy)
if err != nil {
return err
}
if err := tlvStream.Encode(&w); err != nil {
return err
}
if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil {
return err
}
// Finally, add optional shutdown scripts for the local and remote peer
// if they are present.
if err := mig25.PutOptionalUpfrontShutdownScript(
chanBucket, localUpfrontShutdownKey, c.LocalShutdownScript,
); err != nil {
return err
}
return mig25.PutOptionalUpfrontShutdownScript(
chanBucket, remoteUpfrontShutdownKey, c.RemoteShutdownScript,
)
}
package migration27
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration27
import (
"bytes"
"fmt"
mig26 "github.com/lightningnetwork/lnd/channeldb/migration26"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// historicalChannelBucket stores all channels that have seen their
// commitment tx confirm. All information from their previous open state
// is retained.
historicalChannelBucket = []byte("historical-chan-bucket")
)
// MigrateHistoricalBalances patches the two new fields, `InitialLocalBalance`
// and `InitialRemoteBalance`, for all the channels saved in the historical
// channel bucket. Unlike migration 25, it will only read the old channel info
// and then patch the new tlv records with empty values. For historical
// channels, we previously didn't save the initial balances anywhere, and since
// their corresponding open channel buckets are deleted after closure, we have
// lost that balance info.
func MigrateHistoricalBalances(tx kvdb.RwTx) error {
log.Infof("Migrating historical local and remote balances...")
// First fetch the top level bucket which stores all data related to
// historically stored channels.
rootBucket := tx.ReadWriteBucket(historicalChannelBucket)
// If no bucket is found, we can exit early.
if rootBucket == nil {
return nil
}
// Read a list of historical channels.
channels, err := findHistoricalChannels(rootBucket)
if err != nil {
return err
}
// Migrate the balances.
for _, c := range channels {
if err := migrateBalances(rootBucket, c); err != nil {
return err
}
}
return nil
}
// findHistoricalChannels finds all historical channels.
func findHistoricalChannels(historicalBucket kvdb.RBucket) ([]*OpenChannel,
error) {
channels := []*OpenChannel{}
// readChannel is a helper closure that reads the channel info from the
// historical sub-bucket.
readChannel := func(rootBucket kvdb.RBucket, cp []byte) error {
c := &OpenChannel{}
chanPointBuf := bytes.NewBuffer(cp)
err := mig.ReadOutpoint(chanPointBuf, &c.FundingOutpoint)
if err != nil {
return fmt.Errorf("read funding outpoint got: %w", err)
}
// Read the sub-bucket.
chanBucket := rootBucket.NestedReadBucket(cp)
if chanBucket == nil {
log.Errorf("unable to read bucket for chanPoint=%s",
c.FundingOutpoint)
return nil
}
// Try to fetch channel info in old format.
err = fetchChanInfoCompatible(chanBucket, c, true)
if err != nil {
return fmt.Errorf("%s: fetch chan info got: %w",
c.FundingOutpoint, err)
}
channels = append(channels, c)
return nil
}
// Iterate the root bucket.
err := historicalBucket.ForEach(func(cp, _ []byte) error {
return readChannel(historicalBucket, cp)
})
if err != nil {
return nil, err
}
return channels, nil
}
// fetchChanInfoCompatible tries to fetch the channel info for a historical
// channel. It will first fetch the info assuming `InitialLocalBalance` and
// `InitialRemoteBalance` are not serialized. Upon receiving an error, it will
// then fetch it again assuming the two fields are present in the db.
func fetchChanInfoCompatible(chanBucket kvdb.RBucket, c *OpenChannel,
legacy bool) error {
// Try to fetch the channel info assuming the historical channel is in
// the old format, where the two fields, `InitialLocalBalance` and
// `InitialRemoteBalance`, are not saved to the db.
err := FetchChanInfo(chanBucket, c, legacy)
if err == nil {
return nil
}
// If we got an error above, the historical channel may already have
// the new fields saved. This could happen when a channel is closed
// after applying migration 25. In this case, we'll borrow the
// `FetchChanInfo` method from migration 26, where we assume the
// two fields are saved.
return mig26.FetchChanInfo(chanBucket, &c.OpenChannel, legacy)
}
// migrateBalances serializes the channel info using the new tlv format where
// the two fields, `InitialLocalBalance` and `InitialRemoteBalance` are patched
// with empty values.
func migrateBalances(rootBucket kvdb.RwBucket, c *OpenChannel) error {
var chanPointBuf bytes.Buffer
err := mig.WriteOutpoint(&chanPointBuf, &c.FundingOutpoint)
if err != nil {
return err
}
// Get the channel bucket.
chanBucket := rootBucket.NestedReadWriteBucket(chanPointBuf.Bytes())
if chanBucket == nil {
return fmt.Errorf("empty historical chan bucket")
}
// Update the channel info.
if err := PutChanInfo(chanBucket, c, false); err != nil {
return fmt.Errorf("unable to put chan info: %w", err)
}
return nil
}
package migration29
import (
"encoding/binary"
"encoding/hex"
"io"
"github.com/btcsuite/btcd/wire"
)
var (
byteOrder = binary.BigEndian
)
// ChannelID is a series of 32-bytes that uniquely identifies all channels
// within the network. The ChannelID is computed using the outpoint of the
// funding transaction (the txid, and output index). Given a funding output the
// ChannelID can be calculated by XOR'ing the big-endian serialization of the
// txid and the big-endian serialization of the output index, truncated to
// 2 bytes.
type ChannelID [32]byte
// String returns the string representation of the ChannelID. This is just the
// hex string encoding of the ChannelID itself.
func (c ChannelID) String() string {
return hex.EncodeToString(c[:])
}
// NewChanIDFromOutPoint converts a target OutPoint into a ChannelID that is
// usable within the network. In order to convert the OutPoint into a ChannelID,
// we XOR the lower 2-bytes of the txid within the OutPoint with the big-endian
// serialization of the Index of the OutPoint, truncated to 2-bytes.
func NewChanIDFromOutPoint(op *wire.OutPoint) ChannelID {
// First we'll copy the txid of the outpoint into our channel ID slice.
var cid ChannelID
copy(cid[:], op.Hash[:])
// With the txid copied over, we'll now XOR the lower 2-bytes of the
// partial channelID with big-endian serialization of output index.
xorTxid(&cid, uint16(op.Index))
return cid
}
// xorTxid performs the transformation needed to transform an OutPoint into a
// ChannelID. To do this, we expect the cid parameter to contain the txid
// unaltered and the outputIndex to be the outpoint's output index.
func xorTxid(cid *ChannelID, outputIndex uint16) {
var buf [2]byte
binary.BigEndian.PutUint16(buf[:], outputIndex)
cid[30] ^= buf[0]
cid[31] ^= buf[1]
}
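// A small worked sketch (not part of the original migration): for an
// all-zero txid and output index 1, xorTxid only flips the last byte, so the
// resulting ChannelID is 31 zero bytes followed by 0x01.
func exampleChanID() ChannelID {
op := wire.OutPoint{Index: 1}
return NewChanIDFromOutPoint(&op)
}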
// readOutpoint reads an outpoint from the passed reader.
func readOutpoint(r io.Reader, o *wire.OutPoint) error {
if _, err := io.ReadFull(r, o.Hash[:]); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &o.Index); err != nil {
return err
}
return nil
}
package migration29
import "github.com/btcsuite/btclog/v2"
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specific Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration29
import (
"bytes"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// outpointBucket is the bucket that stores the set of outpoints we
// know about.
outpointBucket = []byte("outpoint-bucket")
// chanIDBucket is the bucket that stores the set of ChannelID's we
// know about.
chanIDBucket = []byte("chan-id-bucket")
)
// MigrateChanID populates the ChannelID index by using the set of outpoints
// retrieved from the outpoint bucket.
func MigrateChanID(tx kvdb.RwTx) error {
log.Info("Populating ChannelID index")
// First we'll retrieve the set of outpoints we know about.
ops, err := fetchOutPoints(tx)
if err != nil {
return err
}
return populateChanIDIndex(tx, ops)
}
// fetchOutPoints loops through the outpointBucket and returns each stored
// outpoint.
func fetchOutPoints(tx kvdb.RwTx) ([]*wire.OutPoint, error) {
var ops []*wire.OutPoint
bucket := tx.ReadBucket(outpointBucket)
err := bucket.ForEach(func(k, _ []byte) error {
var op wire.OutPoint
r := bytes.NewReader(k)
if err := readOutpoint(r, &op); err != nil {
return err
}
ops = append(ops, &op)
return nil
})
if err != nil {
return nil, err
}
return ops, nil
}
// populateChanIDIndex uses the set of retrieved outpoints and populates the
// ChannelID index.
func populateChanIDIndex(tx kvdb.RwTx, ops []*wire.OutPoint) error {
bucket := tx.ReadWriteBucket(chanIDBucket)
for _, op := range ops {
chanID := NewChanIDFromOutPoint(op)
if err := bucket.Put(chanID[:], []byte{}); err != nil {
return err
}
}
return nil
}
package migration30
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
mig25 "github.com/lightningnetwork/lnd/channeldb/migration25"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// openChanBucket stores all the currently open channels. This bucket
// has a second, nested bucket which is keyed by a node's ID. Within
// that node ID bucket, all attributes required to track, update, and
// close a channel are stored.
openChannelBucket = []byte("open-chan-bucket")
// errExit is returned when the callback function used in iterator
// needs to exit the iteration.
errExit = errors.New("exit condition met")
)
// updateLocator defines a locator that can be used to find the next record to
// be migrated. This is useful when an interrupted migration leaves a mix of
// revocation log formats saved in our database; we can then restart the
// migration using the locator to continue migrating the rest.
type updateLocator struct {
// nodePub, chainHash and fundingOutpoint are used to locate the
// channel bucket.
nodePub []byte
chainHash []byte
fundingOutpoint []byte
// nextHeight is used to locate the next old revocation log to be
// migrated. A nil value means we've finished the migration.
nextHeight []byte
}
// locateChanBucket is a helper method that returns the bucket where a
// channel's data resides, given the node's public key, the chain hash, and
// the funding outpoint stored in the locator.
func (ul *updateLocator) locateChanBucket(rootBucket kvdb.RwBucket) (
kvdb.RwBucket, error) {
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
nodeChanBucket := rootBucket.NestedReadWriteBucket(ul.nodePub)
if nodeChanBucket == nil {
return nil, mig25.ErrNoActiveChannels
}
// We'll then recurse down an additional layer in order to fetch the
// bucket for this particular chain.
chainBucket := nodeChanBucket.NestedReadWriteBucket(ul.chainHash)
if chainBucket == nil {
return nil, mig25.ErrNoActiveChannels
}
// With the bucket for the node and chain fetched, we can now go down
// another level, for this channel itself.
chanBucket := chainBucket.NestedReadWriteBucket(ul.fundingOutpoint)
if chanBucket == nil {
return nil, mig25.ErrChannelNotFound
}
return chanBucket, nil
}
// findNextMigrateHeight finds the next commit height that's not migrated. It
// returns the commit height bytes found. A nil return value means the
// migration has been completed for this particular channel bucket.
func findNextMigrateHeight(chanBucket kvdb.RwBucket) []byte {
// Read the old log bucket. If the old bucket doesn't exist, either we
// don't have any old logs for this channel, or the migration has
// finished and the old bucket has been deleted.
oldBucket := chanBucket.NestedReadBucket(
revocationLogBucketDeprecated,
)
if oldBucket == nil {
return nil
}
// Acquire a read cursor for the old bucket.
oldCursor := oldBucket.ReadCursor()
// Read the new log bucket. If the sub-bucket hasn't been created yet,
// we haven't migrated any logs under this channel. In this
// case, we'll return the first commit height found from the old
// revocation log bucket as the next height.
logBucket := chanBucket.NestedReadBucket(revocationLogBucket)
if logBucket == nil {
nextHeight, _ := oldCursor.First()
return nextHeight
}
// Acquire a read cursor for the new bucket.
cursor := logBucket.ReadCursor()
// Read the last migrated record. If the key is nil, we haven't
// migrated any logs yet. In this case we return the first commit
// height found from the old revocation log bucket. For instance,
// - old log: [1, 2]
// - new log: []
// We will return the first key [1].
migratedHeight, _ := cursor.Last()
if migratedHeight == nil {
nextHeight, _ := oldCursor.First()
return nextHeight
}
// Read the last height from the old log bucket.
endHeight, _ := oldCursor.Last()
switch bytes.Compare(migratedHeight, endHeight) {
// If the height of the last old revocation equals the migrated
// height, we're done migrating this channel. For instance,
// - old log: [1, 2]
// - new log: [1, 2]
case 0:
return nil
// If the migrated height is smaller, it means this is a resumed
// migration. In this case we will return the next height found in the
// old bucket. For instance,
// - old log: [1, 2]
// - new log: [1]
// We will return the key [2].
case -1:
// Now point the cursor to the migratedHeight. If we cannot
// find this key from the old log bucket, the database might be
// corrupted. In this case, we would return the first key so
// that we would redo the migration for this chan bucket.
matchedHeight, _ := oldCursor.Seek(migratedHeight)
// NOTE: because Seek will return the next key when the passed
// key cannot be found, we need to compare the `matchedHeight`
// to decide whether `migratedHeight` is found or not.
if !bytes.Equal(matchedHeight, migratedHeight) {
log.Warnf("Old revocation bucket doesn't have "+
"CommitHeight=%v yet it's found in the new "+
"bucket. It's likely the new revocation log "+
"bucket is corrupted. Migrations will be"+
"applied again.",
binary.BigEndian.Uint64(migratedHeight))
// Now return the first height found in the old bucket
// so we can redo the migration.
nextHeight, _ := oldCursor.First()
return nextHeight
}
// Otherwise, find the next height to be migrated.
nextHeight, _ := oldCursor.Next()
return nextHeight
// If the migrated height is greater, it means this node has new logs
// saved after v0.15.0. In this case, we need to further decide whether
// the old logs have been migrated or not.
case 1:
}
// If we ever reach here, it means we have a mix of new and old
// logs saved. Suppose we have old logs as,
// - old log: [1, 2]
// We'd have four possible scenarios,
// - new log: [ 3, 4] <- no migration happened, return [1].
// - new log: [1, 3, 4] <- resumed migration, return [2].
// - new log: [ 2, 3, 4] <- corrupted migration, return [1].
// - new log: [1, 2, 3, 4] <- finished migration, return nil.
// To find the next migration height, we will iterate the old logs to
// grab the heights and query them in the new bucket until a height
// cannot be found, which is our next migration height. Or, if the old
// heights can all be found, it indicates a finished migration.
// Move the cursor to the first record.
oldKey, _ := oldCursor.First()
// NOTE: this action can be time-consuming as we are iterating the
// records and comparing them. However, we would only ever hit here if
// this is a resumed migration with new logs created after v0.15.0.
for {
// Try to locate the old key in the new bucket. If it cannot be
// found, it will be the next migrate height.
newKey, _ := cursor.Seek(oldKey)
// If the old key is not found in the new bucket, return it as
// our next migration height.
//
// NOTE: because Seek will return the next key when the passed
// key cannot be found, we need to compare the keys to decide
// whether the old key is found or not.
if !bytes.Equal(newKey, oldKey) {
return oldKey
}
// Otherwise, keep iterating the old bucket.
oldKey, _ = oldCursor.Next()
// If we've done iterating, yet all the old keys can be found
// in the new bucket, this means the migration has been
// finished.
if oldKey == nil {
return nil
}
}
}
// locateNextUpdateNum returns a locator that's used to start our migration. A
// nil locator means the migration has been finished.
func locateNextUpdateNum(openChanBucket kvdb.RwBucket) (*updateLocator, error) {
locator := &updateLocator{}
// cb is the callback function to be used when iterating the buckets.
cb := func(chanBucket kvdb.RwBucket, l *updateLocator) error {
locator = l
updateNum := findNextMigrateHeight(chanBucket)
// We've found the next commit height and can now exit.
if updateNum != nil {
locator.nextHeight = updateNum
return errExit
}
return nil
}
// Iterate the buckets. If we received an exit signal, return the
// locator.
err := iterateBuckets(openChanBucket, nil, cb)
if err == errExit {
log.Debugf("found locator: nodePub=%x, fundingOutpoint=%x, "+
"nextHeight=%x", locator.nodePub, locator.chainHash,
locator.nextHeight)
return locator, nil
}
// If the err is nil, we've iterated all the sub-buckets and the
// migration is finished.
return nil, err
}
// callback defines a type that's used by the iterator.
type callback func(k, v []byte) error
// iterator is a helper function that iterates a given bucket and performs the
// callback function on each key. If a seeker is specified, it will move the
// cursor to the given position; otherwise it will start from the first item.
func iterator(bucket kvdb.RBucket, seeker []byte, cb callback) error {
c := bucket.ReadCursor()
k, v := c.First()
// Move the cursor to the specified position if seeker is non-nil.
if seeker != nil {
k, v = c.Seek(seeker)
}
// Start the iteration and exit on condition.
for k, v := k, v; k != nil; k, v = c.Next() {
// cb might return errExit to signal exiting the iteration.
if err := cb(k, v); err != nil {
return err
}
}
return nil
}
// step defines the callback type that's used when iterating the buckets.
type step func(bucket kvdb.RwBucket, l *updateLocator) error
// iterateBuckets locates the cursor at a given position specified by the
// updateLocator and starts the iteration. If a nil locator is passed, it will
// start the iteration from the beginning. During each iteration, the callback
// function is called and it may exit the iteration by returning errExit to
// signal an exit condition.
func iterateBuckets(openChanBucket kvdb.RwBucket,
l *updateLocator, cb step) error {
// If the locator is nil, we will initialize an empty one, which is
// further used by the iterator.
if l == nil {
l = &updateLocator{}
}
// iterChanBucket iterates the chain bucket to act on each of the
// channel buckets.
iterChanBucket := func(chain kvdb.RwBucket,
k1, k2, _ []byte, cb step) error {
return iterator(
chain, l.fundingOutpoint,
func(k3, _ []byte) error {
// Read the sub-bucket level 3.
chanBucket := chain.NestedReadWriteBucket(k3)
if chanBucket == nil {
return fmt.Errorf("no bucket for "+
"chanPoint=%x", k3)
}
// Construct a new locator at this position.
locator := &updateLocator{
nodePub: k1,
chainHash: k2,
fundingOutpoint: k3,
}
// Set the seeker to nil so it won't affect
// other buckets.
l.fundingOutpoint = nil
return cb(chanBucket, locator)
})
}
return iterator(openChanBucket, l.nodePub, func(k1, v []byte) error {
// Read the sub-bucket level 1.
node := openChanBucket.NestedReadWriteBucket(k1)
if node == nil {
return fmt.Errorf("no bucket for node %x", k1)
}
return iterator(node, l.chainHash, func(k2, v []byte) error {
// Read the sub-bucket level 2.
chain := node.NestedReadWriteBucket(k2)
if chain == nil {
return fmt.Errorf("no bucket for chain=%x", k2)
}
// Set the seeker to nil so it won't affect other
// buckets.
l.chainHash = nil
return iterChanBucket(chain, k1, k2, v, cb)
})
})
}
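// A minimal sketch (not from the original source) of driving iterateBuckets
// from the start: the callback receives each channel bucket together with a
// locator describing its position, and a nil locator starts the iteration
// from the first node bucket.
func exampleCountChannels(openChanBucket kvdb.RwBucket) (int, error) {
count := 0
cb := func(_ kvdb.RwBucket, _ *updateLocator) error {
count++
return nil
}
if err := iterateBuckets(openChanBucket, nil, cb); err != nil {
return 0, err
}
return count, nil
}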
package migration30
import (
"bytes"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
mig25 "github.com/lightningnetwork/lnd/channeldb/migration25"
mig26 "github.com/lightningnetwork/lnd/channeldb/migration26"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/input"
)
// CommitmentKeyRing holds all derived keys needed to construct commitment and
// HTLC transactions. The keys are derived differently depending on whether the
// commitment transaction is ours or the remote peer's. Private keys associated
// with each key may belong to the commitment owner or the "other party" which
// is referred to in the field comments, regardless of which is local and which
// is remote.
type CommitmentKeyRing struct {
// CommitPoint is the "per commitment point" used to derive the tweak
// for each base point.
CommitPoint *btcec.PublicKey
// LocalCommitKeyTweak is the tweak used to derive the local public key
// from the local payment base point or the local private key from the
// base point secret. This may be included in a SignDescriptor to
// generate signatures for the local payment key.
//
// NOTE: This will always refer to "our" local key, regardless of
// whether this is our commit or not.
LocalCommitKeyTweak []byte
// TODO(roasbeef): need delay tweak as well?
// LocalHtlcKeyTweak is the tweak used to derive the local HTLC key
// from the local HTLC base point. This value is needed in order to
// derive the final key used within the HTLC scripts in the commitment
// transaction.
//
// NOTE: This will always refer to "our" local HTLC key, regardless of
// whether this is our commit or not.
LocalHtlcKeyTweak []byte
// LocalHtlcKey is the key that will be used in any clause paying to
// our node of any HTLC scripts within the commitment transaction for
// this key ring set.
//
// NOTE: This will always refer to "our" local HTLC key, regardless of
// whether this is our commit or not.
LocalHtlcKey *btcec.PublicKey
// RemoteHtlcKey is the key that will be used in clauses within the
// HTLC script that send money to the remote party.
//
// NOTE: This will always refer to "their" remote HTLC key, regardless
// of whether this is our commit or not.
RemoteHtlcKey *btcec.PublicKey
// ToLocalKey is the commitment transaction owner's key which is
// included in HTLC success and timeout transaction scripts. This is
// the public key used for the to_local output of the commitment
// transaction.
//
// NOTE: Whose key this is depends on the current perspective. If this
// is our commitment this will be our key.
ToLocalKey *btcec.PublicKey
// ToRemoteKey is the non-owner's payment key in the commitment tx.
// This is the key used to generate the to_remote output within the
// commitment transaction.
//
// NOTE: Whose key this is depends on the current perspective. If this
// is our commitment this will be their key.
ToRemoteKey *btcec.PublicKey
// RevocationKey is the key that can be used by the other party to
// redeem outputs from a revoked commitment transaction if it were to
// be published.
//
// NOTE: Who can sign for this key depends on the current perspective.
// If this is our commitment, it means the remote node can sign for
// this key in case of a breach.
RevocationKey *btcec.PublicKey
}
// ScriptInfo holds a redeem script and hash.
type ScriptInfo struct {
// PkScript is the output's PkScript.
PkScript []byte
// WitnessScript is the full script required to properly redeem the
// output. This field should be set to the full script if a p2wsh
// output is being signed. For p2wkh it should be set equal to the
// PkScript.
WitnessScript []byte
}
// findOutputIndexesFromRemote finds the index of our and their outputs from
// the remote commitment transaction. It derives the key ring to compute the
// output scripts and compares them against the outputs inside the commitment
// to find the match.
func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash,
chanState *mig26.OpenChannel,
oldLog *mig.ChannelCommitment) (uint32, uint32, error) {
// Init the output indexes as empty.
ourIndex := uint32(OutputIndexEmpty)
theirIndex := uint32(OutputIndexEmpty)
chanCommit := oldLog
_, commitmentPoint := btcec.PrivKeyFromBytes(revocationPreimage[:])
// With the commitment point generated, we can now derive the key ring
// which will be used to generate the output scripts.
keyRing := DeriveCommitmentKeys(
commitmentPoint, false, chanState.ChanType,
&chanState.LocalChanCfg, &chanState.RemoteChanCfg,
)
// Since this is the remote commitment chain, we use the mirrored values.
//
// We use the remote's channel config for the csv delay.
theirDelay := uint32(chanState.RemoteChanCfg.CsvDelay)
// If we are the initiator of this channel, then it will be false from
// the remote's PoV.
isRemoteInitiator := !chanState.IsInitiator
var leaseExpiry uint32
if chanState.ChanType.HasLeaseExpiration() {
leaseExpiry = chanState.ThawHeight
}
// Map the scripts from our PoV. When facing a local commitment, the to
// local output belongs to us and the to remote output belongs to them.
// When facing a remote commitment, the to local output belongs to them
// and the to remote output belongs to us.
// Compute the to local script. From our PoV, when facing a remote
// commitment, the to local output belongs to them.
theirScript, err := CommitScriptToSelf(
chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey,
keyRing.RevocationKey, theirDelay, leaseExpiry,
)
if err != nil {
return ourIndex, theirIndex, err
}
// Compute the to remote script. From our PoV, when facing a remote
// commitment, the to remote output belongs to us.
ourScript, _, err := CommitScriptToRemote(
chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey,
leaseExpiry,
)
if err != nil {
return ourIndex, theirIndex, err
}
// Now compare the scripts to find our/their output index.
for i, txOut := range chanCommit.CommitTx.TxOut {
switch {
case bytes.Equal(txOut.PkScript, ourScript.PkScript):
ourIndex = uint32(i)
case bytes.Equal(txOut.PkScript, theirScript.PkScript):
theirIndex = uint32(i)
}
}
return ourIndex, theirIndex, nil
}
// DeriveCommitmentKeys generates a new commitment key set using the base points
// and commitment point. The keys are derived differently depending on the type
// of channel, and whether the commitment transaction is ours or the remote
// peer's.
func DeriveCommitmentKeys(commitPoint *btcec.PublicKey,
isOurCommit bool, chanType mig25.ChannelType,
localChanCfg, remoteChanCfg *mig.ChannelConfig) *CommitmentKeyRing {
tweaklessCommit := chanType.IsTweakless()
// Depending on if this is our commit or not, we'll choose the correct
// base point.
localBasePoint := localChanCfg.PaymentBasePoint
if isOurCommit {
localBasePoint = localChanCfg.DelayBasePoint
}
// First, we'll derive all the keys that don't depend on the context of
// whose commitment transaction this is.
keyRing := &CommitmentKeyRing{
CommitPoint: commitPoint,
LocalCommitKeyTweak: input.SingleTweakBytes(
commitPoint, localBasePoint.PubKey,
),
LocalHtlcKeyTweak: input.SingleTweakBytes(
commitPoint, localChanCfg.HtlcBasePoint.PubKey,
),
LocalHtlcKey: input.TweakPubKey(
localChanCfg.HtlcBasePoint.PubKey, commitPoint,
),
RemoteHtlcKey: input.TweakPubKey(
remoteChanCfg.HtlcBasePoint.PubKey, commitPoint,
),
}
// We'll now compute the to_local, to_remote, and revocation key based
// on the current commitment point. All keys are tweaked each state in
// order to ensure the keys from each state are unlinkable. To create
// the revocation key, we take the opposite party's revocation base
// point and combine that with the current commitment point.
var (
toLocalBasePoint *btcec.PublicKey
toRemoteBasePoint *btcec.PublicKey
revocationBasePoint *btcec.PublicKey
)
if isOurCommit {
toLocalBasePoint = localChanCfg.DelayBasePoint.PubKey
toRemoteBasePoint = remoteChanCfg.PaymentBasePoint.PubKey
revocationBasePoint = remoteChanCfg.RevocationBasePoint.PubKey
} else {
toLocalBasePoint = remoteChanCfg.DelayBasePoint.PubKey
toRemoteBasePoint = localChanCfg.PaymentBasePoint.PubKey
revocationBasePoint = localChanCfg.RevocationBasePoint.PubKey
}
// With the base points assigned, we can now derive the actual keys
// using the base point, and the current commitment tweak.
keyRing.ToLocalKey = input.TweakPubKey(toLocalBasePoint, commitPoint)
keyRing.RevocationKey = input.DeriveRevocationPubkey(
revocationBasePoint, commitPoint,
)
// If this commitment should omit the tweak for the remote point, then
// we'll use that directly, and ignore the commitPoint tweak.
if tweaklessCommit {
keyRing.ToRemoteKey = toRemoteBasePoint
// If this is not our commitment, the above ToRemoteKey will be
// ours, and we blank out the local commitment tweak to
// indicate that the key should not be tweaked when signing.
if !isOurCommit {
keyRing.LocalCommitKeyTweak = nil
}
} else {
keyRing.ToRemoteKey = input.TweakPubKey(
toRemoteBasePoint, commitPoint,
)
}
return keyRing
}
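// A minimal sketch (an illustration, not from the original source): deriving
// the remote-commitment key ring from a revocation preimage, mirroring the
// first steps of findOutputIndexesFromRemote above.
func exampleDeriveRemoteKeyRing(chanState *mig26.OpenChannel,
revocationPreimage *chainhash.Hash) *CommitmentKeyRing {
// The commitment point is derived from the revocation preimage, and
// isOurCommit is false since this is the remote's commitment.
_, commitPoint := btcec.PrivKeyFromBytes(revocationPreimage[:])
return DeriveCommitmentKeys(
commitPoint, false, chanState.ChanType,
&chanState.LocalChanCfg, &chanState.RemoteChanCfg,
)
}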
// CommitScriptToRemote derives the appropriate to_remote script based on the
// channel's commitment type. The `initiator` argument should correspond to the
// owner of the commitment transaction which we are generating the to_remote
// script for. The second return value is the CSV delay of the output script,
// which must be satisfied in order to spend the output.
func CommitScriptToRemote(chanType mig25.ChannelType, initiator bool,
key *btcec.PublicKey, leaseExpiry uint32) (*ScriptInfo, uint32, error) {
switch {
// If we are not the initiator of a leased channel, then the remote
// party has an additional CLTV requirement in addition to the 1 block
// CSV requirement.
case chanType.HasLeaseExpiration() && !initiator:
script, err := input.LeaseCommitScriptToRemoteConfirmed(
key, leaseExpiry,
)
if err != nil {
return nil, 0, err
}
p2wsh, err := input.WitnessScriptHash(script)
if err != nil {
return nil, 0, err
}
return &ScriptInfo{
PkScript: p2wsh,
WitnessScript: script,
}, 1, nil
// If this channel type has anchors, we derive the delayed to_remote
// script.
case chanType.HasAnchors():
script, err := input.CommitScriptToRemoteConfirmed(key)
if err != nil {
return nil, 0, err
}
p2wsh, err := input.WitnessScriptHash(script)
if err != nil {
return nil, 0, err
}
return &ScriptInfo{
PkScript: p2wsh,
WitnessScript: script,
}, 1, nil
default:
// Otherwise the to_remote will be a simple p2wkh.
p2wkh, err := input.CommitScriptUnencumbered(key)
if err != nil {
return nil, 0, err
}
// Since this is a regular P2WKH, the WitnessScript and PkScript
// should both be set to the same p2wkh script.
return &ScriptInfo{
WitnessScript: p2wkh,
PkScript: p2wkh,
}, 0, nil
}
}
// CommitScriptToSelf constructs the public key script for the output on the
// commitment transaction paying to the "owner" of said commitment transaction.
// The `initiator` argument should correspond to the owner of the commitment
// transaction which we are generating the to_local script for. If the other
// party learns of the preimage to the revocation hash, then they can claim all
// the settled funds in the channel, plus the unsettled funds.
func CommitScriptToSelf(chanType mig25.ChannelType, initiator bool,
selfKey, revokeKey *btcec.PublicKey, csvDelay, leaseExpiry uint32) (
*ScriptInfo, error) {
var (
toLocalRedeemScript []byte
err error
)
switch {
// If we are the initiator of a leased channel, then we have an
// additional CLTV requirement in addition to the usual CSV requirement.
case initiator && chanType.HasLeaseExpiration():
toLocalRedeemScript, err = input.LeaseCommitScriptToSelf(
selfKey, revokeKey, csvDelay, leaseExpiry,
)
default:
toLocalRedeemScript, err = input.CommitScriptToSelf(
csvDelay, selfKey, revokeKey,
)
}
if err != nil {
return nil, err
}
toLocalScriptHash, err := input.WitnessScriptHash(toLocalRedeemScript)
if err != nil {
return nil, err
}
return &ScriptInfo{
PkScript: toLocalScriptHash,
WitnessScript: toLocalRedeemScript,
}, nil
}
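// exampleCommitScriptToSelf is an illustrative sketch (hypothetical):
// deriving the to_local script for a non-leased channel with a 144-block CSV
// delay. Since the channel type has no lease expiration, the leaseExpiry
// argument is ignored on the default path.
func exampleCommitScriptToSelf(selfKey,
	revokeKey *btcec.PublicKey) (*ScriptInfo, error) {

	return CommitScriptToSelf(
		mig25.ChannelType(0), true, selfKey, revokeKey, 144, 0,
	)
}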
package migration30
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration30
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"sync"
mig24 "github.com/lightningnetwork/lnd/channeldb/migration24"
mig26 "github.com/lightningnetwork/lnd/channeldb/migration26"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
)
// recordsPerTx specifies the number of records to be migrated in each
// database transaction. In the worst case, each old revocation log is 28,057
// bytes, so 20,000 records would consume roughly 0.56 GB of RAM, which is
// feasible for a modern machine.
//
// NOTE: we could've used more RAM, but it wouldn't speed up the migration
// since most of the CPU time is spent calculating the output indexes.
const recordsPerTx = 20_000
// MigrateRevLogConfig is an interface that defines the config that should be
// passed to the MigrateRevocationLog function.
type MigrateRevLogConfig interface {
// GetNoAmountData returns true if the amount data of revoked commitment
// transactions should not be stored in the revocation log.
GetNoAmountData() bool
}
// MigrateRevLogConfigImpl implements the MigrateRevLogConfig interface.
type MigrateRevLogConfigImpl struct {
// NoAmountData if set to true will result in the amount data of revoked
// commitment transactions not being stored in the revocation log.
NoAmountData bool
}
// GetNoAmountData returns true if the amount data of revoked commitment
// transactions should not be stored in the revocation log.
func (c *MigrateRevLogConfigImpl) GetNoAmountData() bool {
return c.NoAmountData
}
// MigrateRevocationLog migrates the old revocation logs into the newer format
// and deletes them once finished, with the deletion only happening once ALL
// the old logs have been migrated.
func MigrateRevocationLog(db kvdb.Backend, cfg MigrateRevLogConfig) error {
log.Infof("Migrating revocation logs, might take a while...")
var (
err error
// finished is used to exit the for loop.
finished bool
// total is the number of total records.
total uint64
// migrated is the number of already migrated records.
migrated uint64
)
// First of all, read the stats of the revocation logs.
total, migrated, err = logMigrationStat(db)
if err != nil {
return err
}
log.Infof("Total logs=%d, migrated=%d", total, migrated)
// Exit early if the old logs have already been migrated and deleted.
if total == 0 {
log.Info("Migration already finished!")
return nil
}
for {
if finished {
log.Infof("Migrating old revocation logs finished, " +
"now checking the migration results...")
break
}
// Process the migration.
err = kvdb.Update(db, func(tx kvdb.RwTx) error {
finished, err = processMigration(tx, cfg)
if err != nil {
return err
}
return nil
}, func() {})
if err != nil {
return err
}
		// Each time we finish the above process, we read the stats
		// again to understand the current progress.
total, migrated, err = logMigrationStat(db)
if err != nil {
return err
}
		// Calculate and log the progress if the migration is not yet
		// complete, i.e., the progress is below 100 percent.
progress := float64(migrated) / float64(total) * 100
if progress >= 100 {
continue
}
log.Infof("Migration progress: %.3f%%, still have: %d",
progress, total-migrated)
}
	// Before we can safely delete the old buckets, we perform a check to
	// make sure the logs are migrated as expected.
err = kvdb.Update(db, validateMigration, func() {})
if err != nil {
return fmt.Errorf("validate migration failed: %w", err)
}
log.Info("Migration check passed, now deleting the old logs...")
	// Once the migration completes, we can now safely delete the old
	// revocation logs.
if err := deleteOldBuckets(db); err != nil {
return fmt.Errorf("deleteOldBuckets err: %w", err)
}
log.Info("Old revocation log buckets removed!")
return nil
}
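// runRevLogMigration is an illustrative sketch (hypothetical) of how a
// caller might drive the migration above. It assumes the caller already
// holds an open kvdb.Backend; dropAmounts toggles whether the balance fields
// are omitted from the new logs.
func runRevLogMigration(db kvdb.Backend, dropAmounts bool) error {
	cfg := &MigrateRevLogConfigImpl{NoAmountData: dropAmounts}
	return MigrateRevocationLog(db, cfg)
}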
// processMigration finds the next un-migrated revocation logs, reads a max
// of `recordsPerTx` records, converts them into the new revocation logs and
// saves them to disk.
func processMigration(tx kvdb.RwTx, cfg MigrateRevLogConfig) (bool, error) {
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
// If no bucket is found, we can exit early.
if openChanBucket == nil {
return false, fmt.Errorf("root bucket not found")
}
// Locate the next migration height.
locator, err := locateNextUpdateNum(openChanBucket)
if err != nil {
return false, fmt.Errorf("locator got error: %w", err)
}
	// If the returned locator is nil, we're done migrating the logs.
if locator == nil {
return true, nil
}
// Read a list of old revocation logs.
entryMap, err := readOldRevocationLogs(openChanBucket, locator, cfg)
if err != nil {
return false, fmt.Errorf("read old logs err: %w", err)
}
// Migrate the revocation logs.
return false, writeRevocationLogs(openChanBucket, entryMap)
}
// deleteOldBuckets iterates all the channel buckets and deletes the old
// revocation buckets.
func deleteOldBuckets(db kvdb.Backend) error {
// locators records all the chan buckets found in the database.
var locators []*updateLocator
	// reader is a helper closure that saves the locator found. Each
	// locator is relatively small (33+32+36+8=109 bytes), so assuming
	// 1 GB of RAM we can fit roughly 10 million records. Since each
	// record corresponds to a channel, we should have more than enough
	// memory to read them all.
reader := func(_ kvdb.RwBucket, l *updateLocator) error { // nolint:unparam
locators = append(locators, l)
return nil
}
// remover is a helper closure that removes the old revocation log
// bucket under the specified chan bucket by the given locator.
remover := func(rootBucket kvdb.RwBucket, l *updateLocator) error {
chanBucket, err := l.locateChanBucket(rootBucket)
if err != nil {
return err
}
return chanBucket.DeleteNestedBucket(
revocationLogBucketDeprecated,
)
}
// Perform the deletion in one db transaction. This should not cause
// any memory issue as the deletion doesn't load any data from the
// buckets.
return kvdb.Update(db, func(tx kvdb.RwTx) error {
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
// Exit early if there's no bucket.
if openChanBucket == nil {
return nil
}
// Iterate the buckets to find all the locators.
err := iterateBuckets(openChanBucket, nil, reader)
if err != nil {
return err
}
// Iterate the locators and delete all the old revocation log
// buckets.
for _, l := range locators {
err := remover(openChanBucket, l)
			// If the bucket doesn't exist, we can exit safely.
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
}
return nil
}, func() {})
}
// writeRevocationLogs unwraps the entryMap and writes the new revocation logs.
func writeRevocationLogs(openChanBucket kvdb.RwBucket,
entryMap logEntries) error {
for locator, logs := range entryMap {
// Find the channel bucket.
chanBucket, err := locator.locateChanBucket(openChanBucket)
if err != nil {
return fmt.Errorf("locateChanBucket err: %w", err)
}
// Create the new log bucket.
logBucket, err := chanBucket.CreateBucketIfNotExists(
revocationLogBucket,
)
if err != nil {
return fmt.Errorf("create log bucket err: %w", err)
}
// Write the new logs.
for _, entry := range logs {
var b bytes.Buffer
err := serializeRevocationLog(&b, entry.log)
if err != nil {
return err
}
logEntrykey := mig24.MakeLogKey(entry.commitHeight)
err = logBucket.Put(logEntrykey[:], b.Bytes())
if err != nil {
return fmt.Errorf("putRevocationLog err: %w",
err)
}
}
}
return nil
}
// logMigrationStat reads the buckets to provide stats over current migration
// progress. The returned values are the numbers of total records and already
// migrated records.
func logMigrationStat(db kvdb.Backend) (uint64, uint64, error) {
var (
err error
// total is the number of total records.
total uint64
// unmigrated is the number of unmigrated records.
unmigrated uint64
)
err = kvdb.Update(db, func(tx kvdb.RwTx) error {
total, unmigrated, err = fetchLogStats(tx)
return err
}, func() {})
log.Debugf("Total logs=%d, unmigrated=%d", total, unmigrated)
return total, total - unmigrated, err
}
// fetchLogStats iterates all the chan buckets to provide stats about the logs.
// The returned values are num of total records, and num of un-migrated
// records.
func fetchLogStats(tx kvdb.RwTx) (uint64, uint64, error) {
var (
total uint64
totalUnmigrated uint64
)
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
// If no bucket is found, we can exit early.
if openChanBucket == nil {
return 0, 0, fmt.Errorf("root bucket not found")
}
// counter is a helper closure used to count the number of records
// based on the given bucket.
counter := func(chanBucket kvdb.RwBucket, bucket []byte) uint64 {
// Read the sub-bucket level 4.
logBucket := chanBucket.NestedReadBucket(bucket)
// Exit early if we don't have the bucket.
if logBucket == nil {
return 0
}
		// Jump to the end of the cursor. If the bucket exists but is
		// empty, there are no records to count.
		key, _ := logBucket.ReadCursor().Last()
		if key == nil {
			return 0
		}
		// Since the CommitHeight is a zero-based, monotonically
		// increasing index, its value plus one reflects the total
		// number of records under this chan bucket.
		lastHeight := binary.BigEndian.Uint64(key) + 1
		return lastHeight
}
// countTotal is a callback function used to count the total number of
// records.
countTotal := func(chanBucket kvdb.RwBucket, l *updateLocator) error {
total += counter(chanBucket, revocationLogBucketDeprecated)
return nil
}
// countUnmigrated is a callback function used to count the total
// number of un-migrated records.
countUnmigrated := func(chanBucket kvdb.RwBucket,
l *updateLocator) error {
totalUnmigrated += counter(
chanBucket, revocationLogBucketDeprecated,
)
return nil
}
// Locate the next migration height.
locator, err := locateNextUpdateNum(openChanBucket)
if err != nil {
return 0, 0, fmt.Errorf("locator got error: %w", err)
}
	// If the returned locator is not nil, we still have un-migrated
	// records, so we need to count them. Otherwise we're done migrating
	// the logs.
if locator != nil {
err = iterateBuckets(openChanBucket, locator, countUnmigrated)
if err != nil {
return 0, 0, err
}
}
// Count the total number of records by supplying a nil locator.
err = iterateBuckets(openChanBucket, nil, countTotal)
if err != nil {
return 0, 0, err
}
return total, totalUnmigrated, err
}
// logEntry houses the info needed to write a new revocation log.
type logEntry struct {
log *RevocationLog
commitHeight uint64
ourIndex uint32
theirIndex uint32
locator *updateLocator
}
// logEntries maps a bucket locator to a list of entries under that bucket.
type logEntries map[*updateLocator][]*logEntry
// result is made of two channels that are used to send back the constructed
// new revocation log or an error.
type result struct {
newLog chan *logEntry
errChan chan error
}
// readOldRevocationLogs finds a list of old revocation logs and converts them
// into the new revocation logs.
func readOldRevocationLogs(openChanBucket kvdb.RwBucket,
locator *updateLocator, cfg MigrateRevLogConfig) (logEntries, error) {
entries := make(logEntries)
results := make([]*result, 0)
var wg sync.WaitGroup
// collectLogs is a helper closure that reads all newly created
// revocation logs sent over the result channels.
//
	// NOTE: the order of the logs cannot be guaranteed, which is fine as
	// boltdb will take care of the ordering when saving them.
collectLogs := func() error {
wg.Wait()
for _, r := range results {
select {
case entry := <-r.newLog:
entries[entry.locator] = append(
entries[entry.locator], entry,
)
case err := <-r.errChan:
return err
}
}
return nil
}
// createLog is a helper closure that constructs a new revocation log.
//
// NOTE: used as a goroutine.
createLog := func(chanState *mig26.OpenChannel,
c mig.ChannelCommitment, l *updateLocator, r *result) {
defer wg.Done()
// Find the output indexes.
ourIndex, theirIndex, err := findOutputIndexes(chanState, &c)
		if err != nil {
			r.errChan <- err
			return
		}
		// Convert the old logs into the new logs. We do this early in
		// the read tx so that the old, large revocation log can be
		// set to nil here to save us some memory.
newLog, err := convertRevocationLog(
&c, ourIndex, theirIndex, cfg.GetNoAmountData(),
)
		if err != nil {
			r.errChan <- err
			return
		}
// Create the entry that will be used to create the new log.
entry := &logEntry{
log: newLog,
commitHeight: c.CommitHeight,
ourIndex: ourIndex,
theirIndex: theirIndex,
locator: l,
}
r.newLog <- entry
}
// innerCb is the stepping function used when iterating the old log
// bucket.
innerCb := func(chanState *mig26.OpenChannel, l *updateLocator,
_, v []byte) error {
reader := bytes.NewReader(v)
c, err := mig.DeserializeChanCommit(reader)
if err != nil {
return err
}
r := &result{
newLog: make(chan *logEntry, 1),
errChan: make(chan error, 1),
}
results = append(results, r)
// We perform the log creation in a goroutine as it takes some
// time to compute and find output indexes.
wg.Add(1)
go createLog(chanState, c, l, r)
		// Check the records read so far and signal exit when we've
		// reached our memory cap.
if len(results) >= recordsPerTx {
return errExit
}
return nil
}
// cb is the callback function to be used when iterating the buckets.
cb := func(chanBucket kvdb.RwBucket, l *updateLocator) error {
// Read the open channel.
c := &mig26.OpenChannel{}
err := mig26.FetchChanInfo(chanBucket, c, false)
if err != nil {
return fmt.Errorf("unable to fetch chan info: %w", err)
}
err = fetchChanRevocationState(chanBucket, c)
if err != nil {
return fmt.Errorf("unable to fetch revocation "+
"state: %v", err)
}
// Read the sub-bucket level 4.
logBucket := chanBucket.NestedReadBucket(
revocationLogBucketDeprecated,
)
// Exit early if we don't have the old bucket.
if logBucket == nil {
return nil
}
// Init the map key when needed.
_, ok := entries[l]
if !ok {
entries[l] = make([]*logEntry, 0, recordsPerTx)
}
return iterator(
logBucket, locator.nextHeight,
func(k, v []byte) error {
// Reset the nextHeight for following chan
// buckets.
locator.nextHeight = nil
return innerCb(c, l, k, v)
},
)
}
err := iterateBuckets(openChanBucket, locator, cb)
	// If there's an error and it's not the exit signal, we won't collect
	// the logs from the result channels.
if err != nil && err != errExit {
return nil, err
}
// Otherwise, collect the logs.
err = collectLogs()
return entries, err
}
// convertRevocationLog uses the fields `CommitTx` and `Htlcs` from a
// ChannelCommitment to construct a revocation log entry.
func convertRevocationLog(commit *mig.ChannelCommitment,
ourOutputIndex, theirOutputIndex uint32,
noAmtData bool) (*RevocationLog, error) {
// Sanity check that the output indexes can be safely converted.
if ourOutputIndex > math.MaxUint16 {
return nil, ErrOutputIndexTooBig
}
if theirOutputIndex > math.MaxUint16 {
return nil, ErrOutputIndexTooBig
}
rl := &RevocationLog{
OurOutputIndex: uint16(ourOutputIndex),
TheirOutputIndex: uint16(theirOutputIndex),
CommitTxHash: commit.CommitTx.TxHash(),
HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)),
}
if !noAmtData {
rl.TheirBalance = &commit.RemoteBalance
rl.OurBalance = &commit.LocalBalance
}
for _, htlc := range commit.Htlcs {
// Skip dust HTLCs.
if htlc.OutputIndex < 0 {
continue
}
// Sanity check that the output indexes can be safely
// converted.
if htlc.OutputIndex > math.MaxUint16 {
return nil, ErrOutputIndexTooBig
}
entry := &HTLCEntry{
RHash: htlc.RHash,
RefundTimeout: htlc.RefundTimeout,
Incoming: htlc.Incoming,
OutputIndex: uint16(htlc.OutputIndex),
Amt: htlc.Amt.ToSatoshis(),
}
rl.HTLCEntries = append(rl.HTLCEntries, entry)
}
return rl, nil
}
// validateMigration checks that the data saved in the new buckets matches the
// data saved in the old buckets. It does so by checking that the last keys
// saved in both buckets match, given the assumption that the `CommitHeight`
// is a monotonically increasing value, so the last key represents the total
// number of records saved.
func validateMigration(tx kvdb.RwTx) error {
openChanBucket := tx.ReadWriteBucket(openChannelBucket)
// If no bucket is found, we can exit early.
if openChanBucket == nil {
return nil
}
// exitWithErr is a helper closure that prepends an error message with
// the locator info.
exitWithErr := func(l *updateLocator, msg string) error {
return fmt.Errorf("unmatched records found under <nodePub=%x"+
", chainHash=%x, fundingOutpoint=%x>: %v", l.nodePub,
l.chainHash, l.fundingOutpoint, msg)
}
// cb is the callback function to be used when iterating the buckets.
cb := func(chanBucket kvdb.RwBucket, l *updateLocator) error {
// Read both the old and new revocation log buckets.
oldBucket := chanBucket.NestedReadBucket(
revocationLogBucketDeprecated,
)
newBucket := chanBucket.NestedReadBucket(revocationLogBucket)
// Exit early if the old bucket is nil.
//
// NOTE: the new bucket may not be nil here as new logs might
// have been created using lnd@v0.15.0.
if oldBucket == nil {
return nil
}
// Return an error if the expected new bucket cannot be found.
if newBucket == nil {
return exitWithErr(l, "expected new bucket")
}
// Acquire the cursors.
oldCursor := oldBucket.ReadCursor()
newCursor := newBucket.ReadCursor()
// Jump to the end of the cursors to do a quick check.
		newKey, _ := newCursor.Last()
		oldKey, _ := oldCursor.Last()
		// We expect the CommitHeights to match for nodes prior to
		// v0.15.0.
if bytes.Equal(newKey, oldKey) {
return nil
}
		// If the keys do not match, it's likely the node is running
		// v0.15.0 and has new logs created. In this case, we will
		// validate that every record in the old bucket can be found
		// in the new bucket.
oldKey, _ = oldCursor.First()
for {
// Try to locate the old key in the new bucket and we
// expect it to be found.
newKey, _ := newCursor.Seek(oldKey)
// If the old key is not found in the new bucket,
// return an error.
//
			// NOTE: because Seek will return the next key when the
			// passed key cannot be found, we need to compare the
			// keys to decide whether the old key is found or not.
if !bytes.Equal(newKey, oldKey) {
			errMsg := fmt.Sprintf("old bucket has "+
				"CommitHeight=%v that cannot be "+
				"found in new bucket", oldKey)
return exitWithErr(l, errMsg)
}
// Otherwise, keep iterating the old bucket.
oldKey, _ = oldCursor.Next()
			// Once we're done iterating, all keys have been
			// matched and we can safely exit.
if oldKey == nil {
return nil
}
}
}
return iterateBuckets(openChanBucket, nil, cb)
}
package migration30
import (
"bytes"
"errors"
"io"
"math"
"github.com/btcsuite/btcd/btcutil"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
mig24 "github.com/lightningnetwork/lnd/channeldb/migration24"
mig25 "github.com/lightningnetwork/lnd/channeldb/migration25"
mig26 "github.com/lightningnetwork/lnd/channeldb/migration26"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// OutputIndexEmpty is used when the output index doesn't exist.
OutputIndexEmpty = math.MaxUint16
// A set of tlv type definitions used to serialize the body of
// revocation logs to the database.
//
// NOTE: A migration should be added whenever this list changes.
revLogOurOutputIndexType tlv.Type = 0
revLogTheirOutputIndexType tlv.Type = 1
revLogCommitTxHashType tlv.Type = 2
revLogOurBalanceType tlv.Type = 3
revLogTheirBalanceType tlv.Type = 4
)
var (
// revocationLogBucketDeprecated is dedicated for storing the necessary
// delta state between channel updates required to re-construct a past
// state in order to punish a counterparty attempting a non-cooperative
// channel closure. This key should be accessed from within the
// sub-bucket of a target channel, identified by its channel point.
//
	// Deprecated: This bucket is kept read-only in case the user
	// chooses not to migrate the old data.
revocationLogBucketDeprecated = []byte("revocation-log-key")
// revocationLogBucket is a sub-bucket under openChannelBucket. This
// sub-bucket is dedicated for storing the minimal info required to
// re-construct a past state in order to punish a counterparty
// attempting a non-cooperative channel closure.
revocationLogBucket = []byte("revocation-log")
// revocationStateKey stores their current revocation hash, our
// preimage producer and their preimage store.
revocationStateKey = []byte("revocation-state-key")
// ErrNoRevocationsFound is returned when revocation state for a
// particular channel cannot be found.
ErrNoRevocationsFound = errors.New("no revocations found")
// ErrLogEntryNotFound is returned when we cannot find a log entry at
// the height requested in the revocation log.
ErrLogEntryNotFound = errors.New("log entry not found")
// ErrOutputIndexTooBig is returned when the output index is greater
// than uint16.
ErrOutputIndexTooBig = errors.New("output index is over uint16")
)
// HTLCEntry specifies the minimal info needed to be stored on disk for ALL the
// historical HTLCs, which is useful for constructing RevocationLog when a
// breach is detected.
// The actual size of each HTLCEntry varies based on its RHash and Amt (sat),
// summarized as follows,
//
// | RHash empty | Amt<=252 | Amt<=65,535 | Amt<=4,294,967,295 | otherwise |
// |:-----------:|:--------:|:-----------:|:------------------:|:---------:|
// | true | 19 | 21 | 23 | 26 |
// | false | 51 | 53 | 55 | 58 |
//
// So the size varies from 19 bytes to 58 bytes, where most likely to be 23 or
// 55 bytes.
//
// NOTE: all the fields saved to disk use the primitive go types so they can be
// made into tlv records without further conversion.
type HTLCEntry struct {
// RHash is the payment hash of the HTLC.
RHash [32]byte
// RefundTimeout is the absolute timeout on the HTLC that the sender
// must wait before reclaiming the funds in limbo.
RefundTimeout uint32
// OutputIndex is the output index for this particular HTLC output
// within the commitment transaction.
//
// NOTE: we use uint16 instead of int32 here to save us 2 bytes, which
// gives us a max number of HTLCs of 65K.
OutputIndex uint16
// Incoming denotes whether we're the receiver or the sender of this
// HTLC.
//
// NOTE: this field is the memory representation of the field
// incomingUint.
Incoming bool
// Amt is the amount of satoshis this HTLC escrows.
//
// NOTE: this field is the memory representation of the field amtUint.
Amt btcutil.Amount
// amtTlv is the uint64 format of Amt. This field is created so we can
// easily make it into a tlv record and save it to disk.
//
	// NOTE: we keep this field for accounting purposes only. If the disk
	// space becomes an issue, we could delete this field to save us an
	// extra 8 bytes.
amtTlv uint64
// incomingTlv is the uint8 format of Incoming. This field is created
// so we can easily make it into a tlv record and save it to disk.
incomingTlv uint8
}
// RHashLen is used by MakeDynamicRecord to return the size of the RHash.
//
// NOTE: for zero hash, we return a length 0.
func (h *HTLCEntry) RHashLen() uint64 {
if h.RHash == lntypes.ZeroHash {
return 0
}
return 32
}
// RHashEncoder is the customized encoder which skips encoding the empty hash.
func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
v, ok := val.(*[32]byte)
if !ok {
return tlv.NewTypeForEncodingErr(val, "RHash")
}
// If the value is an empty hash, we will skip encoding it.
if *v == lntypes.ZeroHash {
return nil
}
return tlv.EBytes32(w, v, buf)
}
// RHashDecoder is the customized decoder which skips decoding the empty hash.
func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
v, ok := val.(*[32]byte)
if !ok {
return tlv.NewTypeForEncodingErr(val, "RHash")
}
	// If the length is zero, we will skip decoding the empty hash.
if l == 0 {
return nil
}
return tlv.DBytes32(r, v, buf, 32)
}
// toTlvStream converts an HTLCEntry record into a tlv representation.
func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) {
const (
// A set of tlv type definitions used to serialize htlc entries
		// to the database. We define them here instead of at the
		// head of the file to avoid naming conflicts.
//
// NOTE: A migration should be added whenever this list
// changes.
rHashType tlv.Type = 0
refundTimeoutType tlv.Type = 1
outputIndexType tlv.Type = 2
incomingType tlv.Type = 3
amtType tlv.Type = 4
)
return tlv.NewStream(
tlv.MakeDynamicRecord(
rHashType, &h.RHash, h.RHashLen,
RHashEncoder, RHashDecoder,
),
tlv.MakePrimitiveRecord(
refundTimeoutType, &h.RefundTimeout,
),
tlv.MakePrimitiveRecord(
outputIndexType, &h.OutputIndex,
),
tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv),
		// We will save 3 bytes if the amount is less than or equal to
		// 4,294,967,295 sat, or roughly 42.9 bitcoin.
tlv.MakeBigSizeRecord(amtType, &h.amtTlv),
)
}
// RevocationLog stores the info needed to construct a breach retribution. Its
// fields can be viewed as a subset of a ChannelCommitment's. In the database,
// all historical versions of the RevocationLog are saved using the
// CommitHeight as the key.
type RevocationLog struct {
// OurOutputIndex specifies our output index in this commitment. In a
// remote commitment transaction, this is the to remote output index.
OurOutputIndex uint16
// TheirOutputIndex specifies their output index in this commitment. In
// a remote commitment transaction, this is the to local output index.
TheirOutputIndex uint16
	// CommitTxHash is the hash of the latest version of the commitment
	// state, broadcastable by us.
CommitTxHash [32]byte
// HTLCEntries is the set of HTLCEntry's that are pending at this
// particular commitment height.
HTLCEntries []*HTLCEntry
// OurBalance is the current available balance within the channel
// directly spendable by us. In other words, it is the value of the
// to_remote output on the remote parties' commitment transaction.
//
// NOTE: this is a pointer so that it is clear if the value is zero or
// nil. Since migration 30 of the channeldb initially did not include
// this field, it could be the case that the field is not present for
// all revocation logs.
OurBalance *lnwire.MilliSatoshi
// TheirBalance is the current available balance within the channel
// directly spendable by the remote node. In other words, it is the
// value of the to_local output on the remote parties' commitment.
//
// NOTE: this is a pointer so that it is clear if the value is zero or
// nil. Since migration 30 of the channeldb initially did not include
// this field, it could be the case that the field is not present for
// all revocation logs.
TheirBalance *lnwire.MilliSatoshi
}
// putRevocationLog uses the fields `CommitTx` and `Htlcs` from a
// ChannelCommitment to construct a revocation log entry and saves it to
// disk. It also saves our output index and their output index, which are
// useful when creating a breach retribution.
func putRevocationLog(bucket kvdb.RwBucket, commit *mig.ChannelCommitment,
ourOutputIndex, theirOutputIndex uint32, noAmtData bool) error {
// Sanity check that the output indexes can be safely converted.
if ourOutputIndex > math.MaxUint16 {
return ErrOutputIndexTooBig
}
if theirOutputIndex > math.MaxUint16 {
return ErrOutputIndexTooBig
}
rl := &RevocationLog{
OurOutputIndex: uint16(ourOutputIndex),
TheirOutputIndex: uint16(theirOutputIndex),
CommitTxHash: commit.CommitTx.TxHash(),
HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)),
}
if !noAmtData {
rl.OurBalance = &commit.LocalBalance
rl.TheirBalance = &commit.RemoteBalance
}
for _, htlc := range commit.Htlcs {
// Skip dust HTLCs.
if htlc.OutputIndex < 0 {
continue
}
// Sanity check that the output indexes can be safely
// converted.
if htlc.OutputIndex > math.MaxUint16 {
return ErrOutputIndexTooBig
}
entry := &HTLCEntry{
RHash: htlc.RHash,
RefundTimeout: htlc.RefundTimeout,
Incoming: htlc.Incoming,
OutputIndex: uint16(htlc.OutputIndex),
Amt: htlc.Amt.ToSatoshis(),
}
rl.HTLCEntries = append(rl.HTLCEntries, entry)
}
var b bytes.Buffer
err := serializeRevocationLog(&b, rl)
if err != nil {
return err
}
logEntrykey := mig24.MakeLogKey(commit.CommitHeight)
return bucket.Put(logEntrykey[:], b.Bytes())
}
// fetchRevocationLog queries the revocation log bucket to find a log entry.
// It returns an error if the entry is not found.
func fetchRevocationLog(log kvdb.RBucket,
updateNum uint64) (RevocationLog, error) {
logEntrykey := mig24.MakeLogKey(updateNum)
commitBytes := log.Get(logEntrykey[:])
if commitBytes == nil {
return RevocationLog{}, ErrLogEntryNotFound
}
commitReader := bytes.NewReader(commitBytes)
return deserializeRevocationLog(commitReader)
}
// serializeRevocationLog serializes a RevocationLog record based on tlv
// format.
func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
// Add the tlv records for all non-optional fields.
records := []tlv.Record{
tlv.MakePrimitiveRecord(
revLogOurOutputIndexType, &rl.OurOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogCommitTxHashType, &rl.CommitTxHash,
),
}
// Now we add any optional fields that are non-nil.
if rl.OurBalance != nil {
lb := uint64(*rl.OurBalance)
records = append(records, tlv.MakeBigSizeRecord(
revLogOurBalanceType, &lb,
))
}
if rl.TheirBalance != nil {
rb := uint64(*rl.TheirBalance)
records = append(records, tlv.MakeBigSizeRecord(
revLogTheirBalanceType, &rb,
))
}
// Create the tlv stream.
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
// Write the tlv stream.
if err := writeTlvStream(w, tlvStream); err != nil {
return err
}
// Write the HTLCs.
return serializeHTLCEntries(w, rl.HTLCEntries)
}
// serializeHTLCEntries serializes a list of HTLCEntry records based on tlv
// format.
func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error {
for _, htlc := range htlcs {
// Patch the incomingTlv field.
if htlc.Incoming {
htlc.incomingTlv = 1
}
// Patch the amtTlv field.
htlc.amtTlv = uint64(htlc.Amt)
// Create the tlv stream.
tlvStream, err := htlc.toTlvStream()
if err != nil {
return err
}
// Write the tlv stream.
if err := writeTlvStream(w, tlvStream); err != nil {
return err
}
}
return nil
}
// deserializeRevocationLog deserializes a RevocationLog based on tlv format.
func deserializeRevocationLog(r io.Reader) (RevocationLog, error) {
var (
rl RevocationLog
localBalance uint64
remoteBalance uint64
)
// Create the tlv stream.
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(
revLogOurOutputIndexType, &rl.OurOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogTheirOutputIndexType, &rl.TheirOutputIndex,
),
tlv.MakePrimitiveRecord(
revLogCommitTxHashType, &rl.CommitTxHash,
),
tlv.MakeBigSizeRecord(revLogOurBalanceType, &localBalance),
tlv.MakeBigSizeRecord(
revLogTheirBalanceType, &remoteBalance,
),
)
if err != nil {
return rl, err
}
// Read the tlv stream.
parsedTypes, err := readTlvStream(r, tlvStream)
if err != nil {
return rl, err
}
if t, ok := parsedTypes[revLogOurBalanceType]; ok && t == nil {
lb := lnwire.MilliSatoshi(localBalance)
rl.OurBalance = &lb
}
if t, ok := parsedTypes[revLogTheirBalanceType]; ok && t == nil {
rb := lnwire.MilliSatoshi(remoteBalance)
rl.TheirBalance = &rb
}
// Read the HTLC entries.
rl.HTLCEntries, err = deserializeHTLCEntries(r)
return rl, err
}
// deserializeHTLCEntries deserializes a list of HTLC entries based on tlv
// format.
func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
var htlcs []*HTLCEntry
for {
var htlc HTLCEntry
// Create the tlv stream.
tlvStream, err := htlc.toTlvStream()
if err != nil {
return nil, err
}
// Read the HTLC entry.
if _, err := readTlvStream(r, tlvStream); err != nil {
// We've reached the end when hitting an EOF.
if err == io.ErrUnexpectedEOF {
break
}
return nil, err
}
// Patch the Incoming field.
if htlc.incomingTlv == 1 {
htlc.Incoming = true
}
// Patch the Amt field.
htlc.Amt = btcutil.Amount(htlc.amtTlv)
// Append the entry.
htlcs = append(htlcs, &htlc)
}
return htlcs, nil
}
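// exampleRevocationLogRoundTrip is an illustrative sketch (hypothetical) of
// a round trip through the (de)serialization helpers above, using only
// identifiers defined or imported in this package.
func exampleRevocationLogRoundTrip() error {
	amt := lnwire.MilliSatoshi(1_000_000)
	rl := &RevocationLog{
		OurOutputIndex:   0,
		TheirOutputIndex: 1,
		OurBalance:       &amt,
		HTLCEntries: []*HTLCEntry{{
			RefundTimeout: 489,
			OutputIndex:   0,
			Incoming:      false,
			Amt:           btcutil.Amount(1_000),
		}},
	}

	var b bytes.Buffer
	if err := serializeRevocationLog(&b, rl); err != nil {
		return err
	}

	// Decoding should give back the same indexes, the OurBalance field
	// and the single HTLC entry. TheirBalance stays nil since its
	// optional record was never written.
	decoded, err := deserializeRevocationLog(bytes.NewReader(b.Bytes()))
	if err != nil {
		return err
	}
	if len(decoded.HTLCEntries) != 1 || decoded.OurBalance == nil {
		return errors.New("unexpected round-trip result")
	}

	return nil
}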
// writeTlvStream is a helper function that encodes the tlv stream into the
// writer.
func writeTlvStream(w io.Writer, s *tlv.Stream) error {
var b bytes.Buffer
if err := s.Encode(&b); err != nil {
return err
}
// Write the stream's length as a varint.
err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{})
if err != nil {
return err
}
if _, err = w.Write(b.Bytes()); err != nil {
return err
}
return nil
}
// readTlvStream is a helper function that decodes the tlv stream from the
// reader.
func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) {
var bodyLen uint64
// Read the stream's length.
bodyLen, err := tlv.ReadVarInt(r, &[8]byte{})
switch {
// We'll convert any EOFs to ErrUnexpectedEOF, since this results in an
// invalid record.
case err == io.EOF:
return nil, io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return nil, err
}
// TODO(yy): add overflow check.
lr := io.LimitReader(r, int64(bodyLen))
return s.DecodeWithParsedTypes(lr)
}
// fetchLogBucket returns a read bucket by visiting both the new and the old
// bucket.
func fetchLogBucket(chanBucket kvdb.RBucket) (kvdb.RBucket, error) {
logBucket := chanBucket.NestedReadBucket(revocationLogBucket)
if logBucket == nil {
logBucket = chanBucket.NestedReadBucket(
revocationLogBucketDeprecated,
)
if logBucket == nil {
return nil, mig25.ErrNoPastDeltas
}
}
return logBucket, nil
}
// putOldRevocationLog saves a revocation log using the old format.
func putOldRevocationLog(log kvdb.RwBucket,
commit *mig.ChannelCommitment) error {
var b bytes.Buffer
if err := mig.SerializeChanCommit(&b, commit); err != nil {
return err
}
logEntrykey := mig24.MakeLogKey(commit.CommitHeight)
return log.Put(logEntrykey[:], b.Bytes())
}
func putChanRevocationState(chanBucket kvdb.RwBucket,
channel *mig26.OpenChannel) error {
var b bytes.Buffer
err := mig.WriteElements(
&b, channel.RemoteCurrentRevocation, channel.RevocationProducer,
channel.RevocationStore,
)
if err != nil {
return err
}
// TODO(roasbeef): don't keep producer on disk
// If the next revocation is present, which is only the case after the
// FundingLocked message has been sent, then we'll write it to disk.
if channel.RemoteNextRevocation != nil {
err = mig.WriteElements(&b, channel.RemoteNextRevocation)
if err != nil {
return err
}
}
return chanBucket.Put(revocationStateKey, b.Bytes())
}
func fetchChanRevocationState(chanBucket kvdb.RBucket,
c *mig26.OpenChannel) error {
revBytes := chanBucket.Get(revocationStateKey)
if revBytes == nil {
return ErrNoRevocationsFound
}
r := bytes.NewReader(revBytes)
err := mig.ReadElements(
r, &c.RemoteCurrentRevocation, &c.RevocationProducer,
&c.RevocationStore,
)
if err != nil {
return err
}
// If there aren't any bytes left in the buffer, then we don't yet have
// the next remote revocation, so we can exit early here.
if r.Len() == 0 {
return nil
}
// Otherwise we'll read the next revocation for the remote party which
// is always the last item within the buffer.
return mig.ReadElements(r, &c.RemoteNextRevocation)
}
func findOutputIndexes(chanState *mig26.OpenChannel,
oldLog *mig.ChannelCommitment) (uint32, uint32, error) {
// With the state number broadcast known, we can now derive/restore the
// proper revocation preimage necessary to sweep the remote party's
// output.
revocationPreimage, err := chanState.RevocationStore.LookUp(
oldLog.CommitHeight,
)
if err != nil {
return 0, 0, err
}
return findOutputIndexesFromRemote(
revocationPreimage, chanState, oldLog,
)
}
package migration30
import (
"encoding/binary"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/shachain"
"github.com/stretchr/testify/mock"
)
// mockStore mocks the shachain.Store.
type mockStore struct {
mock.Mock
}
// A compile time check to ensure mockStore implements the Store interface.
var _ shachain.Store = (*mockStore)(nil)
func (m *mockStore) LookUp(height uint64) (*chainhash.Hash, error) {
args := m.Called(height)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*chainhash.Hash), args.Error(1)
}
func (m *mockStore) AddNextEntry(preimage *chainhash.Hash) error {
args := m.Called(preimage)
return args.Error(0)
}
// Encode encodes a series of dummy values to pass the serialize/deserialize
// process.
func (m *mockStore) Encode(w io.Writer) error {
err := binary.Write(w, binary.BigEndian, int8(1))
if err != nil {
return err
}
if err := binary.Write(w, binary.BigEndian, uint64(0)); err != nil {
return err
}
if _, err = w.Write(preimage2); err != nil {
return err
}
return binary.Write(w, binary.BigEndian, uint64(0))
}
package migration30
import (
"bytes"
"fmt"
"math/rand"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
mig25 "github.com/lightningnetwork/lnd/channeldb/migration25"
mig26 "github.com/lightningnetwork/lnd/channeldb/migration26"
mig "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/shachain"
)
var (
testChainHash = chainhash.Hash{1, 2, 3}
testChanType = mig25.SingleFunderTweaklessBit
	// testOurIndex and testTheirIndex are artificial indexes that are
	// saved to db during test setup. They are different from the indexes
	// populated by the actual migration process so we can check whether a
	// new revocation log has been overwritten or not.
testOurIndex = uint32(100)
testTheirIndex = uint32(200)
// dummyInput is used in our commit tx.
dummyInput = &wire.TxIn{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{},
Index: 0xffffffff,
},
Sequence: 0xffffffff,
}
// htlcScript is the PkScript used in the HTLC output. This script
// corresponds to revocation preimage2.
htlcScript = []byte{
0x0, 0x20, 0x3d, 0x51, 0x66, 0xda, 0x39, 0x93,
0x7b, 0x49, 0xaf, 0x2, 0xf2, 0x2f, 0x90, 0x52,
0x8e, 0x45, 0x24, 0x34, 0x8f, 0xd8, 0x76, 0x7,
0x5a, 0xfc, 0x52, 0x8d, 0x68, 0xdd, 0xbc, 0xce,
0x3e, 0x5d,
}
// toLocalScript is the PkScript used in to-local output.
toLocalScript = []byte{
0x0, 0x14, 0xc6, 0x9, 0x62, 0xab, 0x60, 0xbe,
0x40, 0xd, 0xab, 0x31, 0xc, 0x13, 0x14, 0x15,
0x93, 0xe6, 0xa2, 0x94, 0xe4, 0x2a,
}
// preimage1 defines a revocation preimage, generated from itest.
preimage1 = []byte{
0x95, 0xb4, 0x7c, 0x5a, 0x2b, 0xfd, 0x6f, 0xf4,
0x70, 0x8, 0xc, 0x70, 0x82, 0x36, 0xc8, 0x5,
0x88, 0x16, 0xaf, 0x29, 0xb5, 0x8, 0xfd, 0x5a,
0x40, 0x28, 0x24, 0xc, 0x2a, 0x7f, 0x96, 0xcd,
}
// commitTx1 is the tx saved in the first old revocation.
commitTx1 = &wire.MsgTx{
Version: 2,
// Add a dummy input.
TxIn: []*wire.TxIn{dummyInput},
TxOut: []*wire.TxOut{
{
Value: 990_950,
PkScript: toLocalScript,
},
},
}
// logHeight1 is the CommitHeight used by oldLog1.
logHeight1 = uint64(0)
// localBalance1 is the LocalBalance used in oldLog1.
localBalance1 = lnwire.MilliSatoshi(990_950_000)
// remoteBalance1 is the RemoteBalance used in oldLog1.
remoteBalance1 = lnwire.MilliSatoshi(0)
// oldLog1 defines an old revocation that has no HTLCs.
oldLog1 = mig.ChannelCommitment{
CommitHeight: logHeight1,
LocalLogIndex: 0,
LocalHtlcIndex: 0,
RemoteLogIndex: 0,
RemoteHtlcIndex: 0,
LocalBalance: localBalance1,
RemoteBalance: remoteBalance1,
CommitTx: commitTx1,
}
	// newLog1 is the new version of oldLog1 in the case where we don't
	// want to store any balance data.
newLog1 = RevocationLog{
OurOutputIndex: 0,
TheirOutputIndex: OutputIndexEmpty,
CommitTxHash: commitTx1.TxHash(),
}
	// newLog1WithAmts is the new version of oldLog1 in the case where we
	// do want to store balance data.
newLog1WithAmts = RevocationLog{
OurOutputIndex: 0,
TheirOutputIndex: OutputIndexEmpty,
CommitTxHash: commitTx1.TxHash(),
OurBalance: &localBalance1,
TheirBalance: &remoteBalance1,
}
// preimage2 defines the second revocation preimage used in the test,
// generated from itest.
preimage2 = []byte{
0xac, 0x60, 0x7a, 0x59, 0x9, 0xd6, 0x11, 0xb2,
0xf5, 0x6e, 0xaa, 0xc6, 0xb9, 0x0, 0x12, 0xdc,
0xf0, 0x89, 0x58, 0x90, 0x8a, 0xa2, 0xc6, 0xfc,
0xf1, 0x2, 0x74, 0x87, 0x30, 0x51, 0x5e, 0xea,
}
// commitTx2 is the tx saved in the second old revocation.
commitTx2 = &wire.MsgTx{
Version: 2,
// Add a dummy input.
TxIn: []*wire.TxIn{dummyInput},
TxOut: []*wire.TxOut{
{
Value: 100_000,
PkScript: htlcScript,
},
{
Value: 888_800,
PkScript: toLocalScript,
},
},
}
// rHash is the payment hash used in the htlc below.
rHash = [32]byte{
0x42, 0x5e, 0xd4, 0xe4, 0xa3, 0x6b, 0x30, 0xea,
0x21, 0xb9, 0xe, 0x21, 0xc7, 0x12, 0xc6, 0x49,
0xe8, 0x21, 0x4c, 0x29, 0xb7, 0xea, 0xf6, 0x80,
0x89, 0xd1, 0x3, 0x9c, 0x6e, 0x55, 0x38, 0x4c,
}
// htlc defines an HTLC that's saved in the old revocation log.
htlc = mig.HTLC{
RHash: rHash,
Amt: lnwire.MilliSatoshi(100_000_000),
RefundTimeout: 489,
OutputIndex: 0,
Incoming: false,
OnionBlob: bytes.Repeat([]byte{0xff}, 1366),
HtlcIndex: 0,
LogIndex: 0,
}
// logHeight2 is the CommitHeight used by oldLog2.
logHeight2 = uint64(1)
// localBalance2 is the LocalBalance used in oldLog2.
localBalance2 = lnwire.MilliSatoshi(888_800_000)
// remoteBalance2 is the RemoteBalance used in oldLog2.
remoteBalance2 = lnwire.MilliSatoshi(0)
// oldLog2 defines an old revocation that has one HTLC.
oldLog2 = mig.ChannelCommitment{
CommitHeight: logHeight2,
LocalLogIndex: 1,
LocalHtlcIndex: 1,
RemoteLogIndex: 0,
RemoteHtlcIndex: 0,
LocalBalance: localBalance2,
RemoteBalance: remoteBalance2,
CommitTx: commitTx2,
Htlcs: []mig.HTLC{htlc},
}
	// newLog2 is the new version of oldLog2 in the case where we don't
	// want to store any balance data.
newLog2 = RevocationLog{
OurOutputIndex: 1,
TheirOutputIndex: OutputIndexEmpty,
CommitTxHash: commitTx2.TxHash(),
HTLCEntries: []*HTLCEntry{
{
RHash: rHash,
RefundTimeout: 489,
OutputIndex: 0,
Incoming: false,
Amt: btcutil.Amount(100_000),
},
},
}
	// newLog2WithAmts is the new version of oldLog2 in the case where we
	// do want to store balance data.
newLog2WithAmts = RevocationLog{
OurOutputIndex: 1,
TheirOutputIndex: OutputIndexEmpty,
CommitTxHash: commitTx2.TxHash(),
OurBalance: &localBalance2,
TheirBalance: &remoteBalance2,
HTLCEntries: []*HTLCEntry{
{
RHash: rHash,
RefundTimeout: 489,
OutputIndex: 0,
Incoming: false,
Amt: btcutil.Amount(100_000),
},
},
}
	// newLog3 defines a revocation log that's been created after v0.15.0.
newLog3 = mig.ChannelCommitment{
CommitHeight: logHeight2 + 1,
LocalLogIndex: 1,
LocalHtlcIndex: 1,
RemoteLogIndex: 0,
RemoteHtlcIndex: 0,
LocalBalance: lnwire.MilliSatoshi(888_800_000),
RemoteBalance: 0,
CommitTx: commitTx2,
Htlcs: []mig.HTLC{htlc},
}
// The following public keys are taken from the itest results.
localMusigKey, _ = btcec.ParsePubKey([]byte{
0x2,
0xda, 0x42, 0xa4, 0x4a, 0x6b, 0x42, 0xfe, 0xcb,
0x2f, 0x7e, 0x35, 0x89, 0x99, 0xdd, 0x43, 0xba,
0x4b, 0xf1, 0x9c, 0xf, 0x18, 0xef, 0x9, 0x83,
0x35, 0x31, 0x59, 0xa4, 0x3b, 0xde, 0xa, 0xde,
})
localRevocationBasePoint, _ = btcec.ParsePubKey([]byte{
0x2,
0x6, 0x16, 0xd1, 0xb1, 0x4f, 0xee, 0x11, 0x86,
0x55, 0xfe, 0x31, 0x66, 0x6f, 0x43, 0x1, 0x80,
0xa8, 0xa7, 0x5c, 0x2, 0x92, 0xe5, 0x7c, 0x4,
0x31, 0xa6, 0xcf, 0x43, 0xb6, 0xdb, 0xe6, 0x10,
})
localPaymentBasePoint, _ = btcec.ParsePubKey([]byte{
0x2,
0x88, 0x65, 0x16, 0xc2, 0x37, 0x3f, 0xc5, 0x16,
0x62, 0x71, 0x0, 0xdd, 0x4d, 0x43, 0x28, 0x43,
0x32, 0x91, 0x75, 0xcc, 0xd8, 0x81, 0xb6, 0xb0,
0xd8, 0x96, 0x78, 0xad, 0x18, 0x3b, 0x16, 0xe1,
})
localDelayBasePoint, _ = btcec.ParsePubKey([]byte{
0x2,
0xea, 0x41, 0x48, 0x11, 0x2, 0x59, 0xe3, 0x5c,
0x51, 0x15, 0x90, 0x25, 0x4a, 0x61, 0x5, 0x51,
0xb3, 0x8, 0xe9, 0xd5, 0xf, 0xc6, 0x91, 0x25,
0x14, 0xd2, 0xcf, 0xc8, 0xc5, 0x5b, 0xd9, 0x88,
})
localHtlcBasePoint, _ = btcec.ParsePubKey([]byte{
0x3,
0xfa, 0x1f, 0x6, 0x3a, 0xa4, 0x75, 0x2e, 0x74,
0x3e, 0x55, 0x9, 0x20, 0x6e, 0xf6, 0xa8, 0xe1,
0xd7, 0x61, 0x50, 0x75, 0xa8, 0x34, 0x15, 0xc3,
0x6b, 0xdc, 0xb0, 0xbf, 0xaa, 0x66, 0xd7, 0xa7,
})
remoteMultiSigKey, _ = btcec.ParsePubKey([]byte{
0x2,
0x2b, 0x88, 0x7c, 0x6a, 0xf8, 0xb3, 0x51, 0x61,
0xd3, 0x1c, 0xf1, 0xe4, 0x43, 0xc2, 0x8c, 0x5e,
0xfa, 0x8e, 0xb5, 0xe9, 0xd0, 0x14, 0xb5, 0x33,
0x6a, 0xcc, 0xd, 0x11, 0x42, 0xb8, 0x4b, 0x7d,
})
remoteRevocationBasePoint, _ = btcec.ParsePubKey([]byte{
0x2,
0x6c, 0x39, 0xa3, 0x6d, 0x93, 0x69, 0xac, 0x14,
0x1f, 0xbb, 0x4, 0x86, 0x3, 0x82, 0x5, 0xe2,
0xcb, 0xb0, 0x62, 0x41, 0xa, 0x93, 0x3, 0x6c,
0x8d, 0xc0, 0x42, 0x4d, 0x9e, 0x51, 0x9b, 0x36,
})
remotePaymentBasePoint, _ = btcec.ParsePubKey([]byte{
0x3,
0xab, 0x74, 0x1e, 0x83, 0x48, 0xe3, 0xb5, 0x6,
0x25, 0x1c, 0x80, 0xe7, 0xf2, 0x3e, 0x7d, 0xb7,
0x7a, 0xc7, 0xd, 0x6, 0x3b, 0xbc, 0x74, 0x96,
0x8e, 0x9b, 0x2d, 0xd1, 0x42, 0x71, 0xa5, 0x2a,
})
remoteDelayBasePoint, _ = btcec.ParsePubKey([]byte{
0x2,
0x4b, 0xdd, 0x52, 0x46, 0x1b, 0x50, 0x89, 0xb9,
0x49, 0x4, 0xf2, 0xd2, 0x98, 0x7d, 0x51, 0xa1,
0xa6, 0x3f, 0x9b, 0xd0, 0x40, 0x7c, 0x93, 0x74,
0x3b, 0x8c, 0x4d, 0x63, 0x32, 0x90, 0xa, 0xca,
})
remoteHtlcBasePoint, _ = btcec.ParsePubKey([]byte{
0x3,
0x5b, 0x8f, 0x4a, 0x71, 0x4c, 0x2e, 0x71, 0x14,
0x86, 0x1f, 0x30, 0x96, 0xc0, 0xd4, 0x11, 0x76,
0xf8, 0xc3, 0xfc, 0x7, 0x2d, 0x15, 0x99, 0x55,
0x8, 0x69, 0xf6, 0x1, 0xa2, 0xcd, 0x6b, 0xa7,
})
	// withAmtData, if set, will result in the amount data of the revoked
// commitment transactions also being stored in the new revocation log.
// The value of this variable is set randomly in the init function of
// this package.
withAmtData bool
)
// setupTestLogs takes care of creating the related buckets and inserts testing
// records.
func setupTestLogs(db kvdb.Backend, c *mig26.OpenChannel,
oldLogs, newLogs []mig.ChannelCommitment) error {
return kvdb.Update(db, func(tx kvdb.RwTx) error {
// If the open channel is nil, only create the root
// bucket and skip creating the channel bucket.
if c == nil {
_, err := tx.CreateTopLevelBucket(openChannelBucket)
return err
}
// Create test buckets.
chanBucket, err := mig25.CreateChanBucket(tx, &c.OpenChannel)
if err != nil {
return err
}
// Save channel info.
if err := mig26.PutChanInfo(chanBucket, c, false); err != nil {
return fmt.Errorf("PutChanInfo got %w", err)
}
// Save revocation state.
if err := putChanRevocationState(chanBucket, c); err != nil {
return fmt.Errorf("putChanRevocationState got %w", err)
}
// Create old logs.
err = writeOldRevocationLogs(chanBucket, oldLogs)
if err != nil {
return fmt.Errorf("write old logs: %w", err)
}
// Create new logs.
return writeNewRevocationLogs(chanBucket, newLogs, !withAmtData)
}, func() {})
}
// createTestChannel creates an OpenChannel using the specified nodePub. If
// the param is nil, a random public key is used.
func createTestChannel(nodePub *btcec.PublicKey) *mig26.OpenChannel {
// Create a random private key that's used to provide randomness.
priv, _ := btcec.NewPrivateKey()
// If passed public key is nil, use the random public key.
if nodePub == nil {
nodePub = priv.PubKey()
}
// Create a random channel point.
var op wire.OutPoint
copy(op.Hash[:], priv.Serialize())
testProducer := shachain.NewRevocationProducer(op.Hash)
store, _ := createTestStore()
localCfg := mig.ChannelConfig{
ChannelConstraints: mig.ChannelConstraints{
DustLimit: btcutil.Amount(354),
MaxAcceptedHtlcs: 483,
CsvDelay: 4,
},
MultiSigKey: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 0,
Index: 0,
},
PubKey: localMusigKey,
},
RevocationBasePoint: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 1,
Index: 0,
},
PubKey: localRevocationBasePoint,
},
HtlcBasePoint: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 2,
Index: 0,
},
PubKey: localHtlcBasePoint,
},
PaymentBasePoint: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 3,
Index: 0,
},
PubKey: localPaymentBasePoint,
},
DelayBasePoint: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 4,
Index: 0,
},
PubKey: localDelayBasePoint,
},
}
remoteCfg := mig.ChannelConfig{
ChannelConstraints: mig.ChannelConstraints{
DustLimit: btcutil.Amount(354),
MaxAcceptedHtlcs: 483,
CsvDelay: 4,
},
MultiSigKey: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 0,
Index: 0,
},
PubKey: remoteMultiSigKey,
},
RevocationBasePoint: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 0,
Index: 0,
},
PubKey: remoteRevocationBasePoint,
},
HtlcBasePoint: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 0,
Index: 0,
},
PubKey: remoteHtlcBasePoint,
},
PaymentBasePoint: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 0,
Index: 0,
},
PubKey: remotePaymentBasePoint,
},
DelayBasePoint: keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: 0,
Index: 0,
},
PubKey: remoteDelayBasePoint,
},
}
c := &mig26.OpenChannel{
OpenChannel: mig25.OpenChannel{
OpenChannel: mig.OpenChannel{
ChainHash: testChainHash,
IdentityPub: nodePub,
FundingOutpoint: op,
LocalChanCfg: localCfg,
RemoteChanCfg: remoteCfg,
// Assign dummy values.
RemoteCurrentRevocation: nodePub,
RevocationProducer: testProducer,
RevocationStore: store,
},
ChanType: testChanType,
},
}
return c
}
// writeOldRevocationLogs saves the given old revocation logs to db.
func writeOldRevocationLogs(chanBucket kvdb.RwBucket,
oldLogs []mig.ChannelCommitment) error {
	// Don't bother continuing if the logs are empty.
if len(oldLogs) == 0 {
return nil
}
logBucket, err := chanBucket.CreateBucketIfNotExists(
revocationLogBucketDeprecated,
)
if err != nil {
return err
}
for _, c := range oldLogs {
if err := putOldRevocationLog(logBucket, &c); err != nil {
return err
}
}
return nil
}
// writeNewRevocationLogs saves the given new revocation logs to db.
func writeNewRevocationLogs(chanBucket kvdb.RwBucket,
oldLogs []mig.ChannelCommitment, noAmtData bool) error {
	// Don't bother continuing if the logs are empty.
if len(oldLogs) == 0 {
return nil
}
logBucket, err := chanBucket.CreateBucketIfNotExists(
revocationLogBucket,
)
if err != nil {
return err
}
for _, c := range oldLogs {
// NOTE: we just blindly write the output indexes to db here
// whereas normally, we would find the correct indexes from the
// old commit tx. We do this intentionally so we can
// distinguish a newly created log from an already saved one.
err := putRevocationLog(
logBucket, &c, testOurIndex, testTheirIndex, noAmtData,
)
if err != nil {
return err
}
}
return nil
}
// createTestStore creates a revocation store and always saves the two
// preimages defined above into the store.
func createTestStore() (shachain.Store, error) {
var p chainhash.Hash
copy(p[:], preimage1)
testStore := shachain.NewRevocationStore()
if err := testStore.AddNextEntry(&p); err != nil {
return nil, err
}
copy(p[:], preimage2)
if err := testStore.AddNextEntry(&p); err != nil {
return nil, err
}
return testStore, nil
}
// createNotStarted will set up a situation where we haven't started the
// migration for the channel. We use `legacy` to denote whether to simulate a
// pre-v0.15.0 node.
func createNotStarted(cdb kvdb.Backend, c *mig26.OpenChannel,
legacy bool) error {
var newLogs []mig.ChannelCommitment
// Create test logs.
oldLogs := []mig.ChannelCommitment{oldLog1, oldLog2}
// Add a new log if the node is running with v0.15.0.
if !legacy {
newLogs = []mig.ChannelCommitment{newLog3}
}
return setupTestLogs(cdb, c, oldLogs, newLogs)
}
// createNotFinished will set up a situation where we have un-migrated logs
// and return the next migration height. We use `legacy` to denote whether to
// simulate a pre-v0.15.0 node.
func createNotFinished(cdb kvdb.Backend, c *mig26.OpenChannel,
legacy bool) error {
// Create test logs.
oldLogs := []mig.ChannelCommitment{oldLog1, oldLog2}
newLogs := []mig.ChannelCommitment{oldLog1}
// Add a new log if the node is running with v0.15.0.
if !legacy {
newLogs = append(newLogs, newLog3)
}
return setupTestLogs(cdb, c, oldLogs, newLogs)
}
// createFinished will set up a situation where all the old logs have been
// migrated and return nil. We use `legacy` to denote whether to simulate a
// pre-v0.15.0 node.
func createFinished(cdb kvdb.Backend, c *mig26.OpenChannel,
legacy bool) error {
// Create test logs.
oldLogs := []mig.ChannelCommitment{oldLog1, oldLog2}
newLogs := []mig.ChannelCommitment{oldLog1, oldLog2}
// Add a new log if the node is running with v0.15.0.
if !legacy {
newLogs = append(newLogs, newLog3)
}
return setupTestLogs(cdb, c, oldLogs, newLogs)
}
func init() {
rand.Seed(time.Now().Unix())
if rand.Intn(2) == 0 {
withAmtData = true
}
if withAmtData {
newLog1 = newLog1WithAmts
newLog2 = newLog2WithAmts
}
}
package migration31
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration31
import (
"errors"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/kvdb"
)
// DeleteLastPublishedTxTLB deletes the top level bucket with the key
// "sweeper-last-tx".
func DeleteLastPublishedTxTLB(tx kvdb.RwTx) error {
log.Infof("Deleting top-level bucket: %x ...", lastTxBucketKey)
err := tx.DeleteTopLevelBucket(lastTxBucketKey)
if err != nil && !errors.Is(err, walletdb.ErrBucketNotFound) {
return err
}
log.Infof("Deleted top-level bucket: %x", lastTxBucketKey)
return nil
}
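// runDeleteLastPublishedTx is an illustrative sketch (hypothetical): since
// DeleteLastPublishedTxTLB already matches kvdb.Update's callback signature,
// the whole migration step can be run as a single rw transaction.
func runDeleteLastPublishedTx(db kvdb.Backend) error {
	return kvdb.Update(db, DeleteLastPublishedTxTLB, func() {})
}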
package migration32
import (
"encoding/binary"
"fmt"
"io"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
)
var (
// Big endian is the preferred byte order, due to cursor scans over
// integer keys iterating in order.
byteOrder = binary.BigEndian
)
// ReadElement is a one-stop utility function to deserialize any data
// structure encoded using the serialization format of the database.
func ReadElement(r io.Reader, element interface{}) error {
switch e := element.(type) {
case *uint32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *lnwire.MilliSatoshi:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = lnwire.MilliSatoshi(a)
case *[]byte:
bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte")
if err != nil {
return err
}
*e = bytes
case *int64, *uint64:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *bool:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *int32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
default:
return UnknownElementType{"ReadElement", e}
}
return nil
}
// ReadElements deserializes a variable number of elements from the passed
// io.Reader, with each element being deserialized according to the
// ReadElement function.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
err := ReadElement(r, element)
if err != nil {
return err
}
}
return nil
}
// UnknownElementType is an error returned when the codec is unable to encode or
// decode a particular type.
type UnknownElementType struct {
method string
element interface{}
}
// Error returns the name of the method that encountered the error, as well as
// the type that was unsupported.
func (e UnknownElementType) Error() string {
return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element)
}
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for storage on disk. The passed
// io.Writer should be backed by an appropriately sized byte slice, or be able
// to dynamically expand to accommodate additional data.
func WriteElement(w io.Writer, element interface{}) error {
switch e := element.(type) {
case int64, uint64:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case int32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case bool:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case lnwire.MilliSatoshi:
if err := binary.Write(w, byteOrder, uint64(e)); err != nil {
return err
}
case []byte:
if err := wire.WriteVarBytes(w, 0, e); err != nil {
return err
}
default:
return UnknownElementType{"WriteElement", e}
}
return nil
}
// WriteElements writes each element in the elements slice to the passed
// io.Writer using WriteElement.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
err := WriteElement(w, element)
if err != nil {
return err
}
}
return nil
}
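// exampleElementsRoundTrip is an illustrative sketch (hypothetical) of a
// round trip through the codec above. It assumes the bytes package is
// imported alongside the imports at the top of this file.
func exampleElementsRoundTrip() error {
	var b bytes.Buffer

	amt := lnwire.MilliSatoshi(42_000)
	blob := []byte{0x01, 0x02, 0x03}

	// Write the elements in a fixed order.
	if err := WriteElements(&b, amt, blob, uint32(7)); err != nil {
		return err
	}

	var (
		amtOut  lnwire.MilliSatoshi
		blobOut []byte
		u32Out  uint32
	)

	// Elements must be read back in exactly the order they were written,
	// since the encoding carries no type information.
	return ReadElements(&b, &amtOut, &blobOut, &u32Out)
}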
package migration32
import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// EncryptedDataOnionType is the type used to include encrypted data
// provided by the receiver in the onion for use in blinded paths.
EncryptedDataOnionType tlv.Type = 10
// BlindingPointOnionType is the type used to include receiver provided
// ephemeral keys in the onion that are used in blinded paths.
BlindingPointOnionType tlv.Type = 12
// MetadataOnionType is the type used in the onion for the payment
// metadata.
MetadataOnionType tlv.Type = 16
// TotalAmtMsatBlindedType is the type used in the onion for the total
// amount field that is included in the final hop for blinded payments.
TotalAmtMsatBlindedType tlv.Type = 18
)
// NewEncryptedDataRecord creates a tlv.Record that encodes the encrypted_data
// (type 10) record for an onion payload.
func NewEncryptedDataRecord(data *[]byte) tlv.Record {
return tlv.MakePrimitiveRecord(EncryptedDataOnionType, data)
}
// NewBlindingPointRecord creates a tlv.Record that encodes the blinding_point
// (type 12) record for an onion payload.
func NewBlindingPointRecord(point **btcec.PublicKey) tlv.Record {
return tlv.MakePrimitiveRecord(BlindingPointOnionType, point)
}
// NewMetadataRecord creates a tlv.Record that encodes the metadata (type 16)
// for an onion payload.
func NewMetadataRecord(metadata *[]byte) tlv.Record {
return tlv.MakeDynamicRecord(
MetadataOnionType, metadata,
func() uint64 {
return uint64(len(*metadata))
},
tlv.EVarBytes, tlv.DVarBytes,
)
}
// NewTotalAmtMsatBlinded creates a tlv.Record that encodes the
// total_amount_msat for the final hop of an onion payload within a blinded
// route.
func NewTotalAmtMsatBlinded(amt *uint64) tlv.Record {
return tlv.MakeDynamicRecord(
TotalAmtMsatBlindedType, amt, func() uint64 {
return tlv.SizeTUint64(*amt)
},
tlv.ETUint64, tlv.DTUint64,
)
}
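// exampleBlindedRecords is an illustrative sketch (added for this write-up,
// not part of the original migration) showing how the constructors above
// assemble the record set for a hypothetical final blinded hop. The data
// values are placeholders.
func exampleBlindedRecords() []tlv.Record {
	// Hypothetical receiver-provided encrypted blob and total amount.
	encryptedData := []byte{0x01, 0x02}
	total := uint64(1_000_000)

	// The resulting records would be merged into the hop's TLV stream.
	return []tlv.Record{
		NewEncryptedDataRecord(&encryptedData),
		NewTotalAmtMsatBlinded(&total),
	}
}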
package migration32
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration32
import (
"bytes"
"fmt"
"github.com/lightningnetwork/lnd/kvdb"
)
// MigrateMCRouteSerialisation reads all the mission control store entries and
// re-serialises them using a minimal route serialisation so that only the parts
// of the route that are actually required for mission control are persisted.
func MigrateMCRouteSerialisation(tx kvdb.RwTx) error {
log.Infof("Migrating Mission Control store to use a more minimal " +
"encoding for routes")
resultBucket := tx.ReadWriteBucket(resultsKey)
// If the results bucket does not exist then there are no entries in
// the mission control store yet and so there is nothing to migrate.
if resultBucket == nil {
return nil
}
// For each entry, read it into memory using the old encoding. Then,
// extract the more minimal route, re-encode and persist the entry.
return resultBucket.ForEach(func(k, v []byte) error {
// Read the entry using the old encoding.
resultOld, err := deserializeOldResult(k, v)
if err != nil {
return err
}
// Convert to the new payment result format with the minimal
// route.
resultNew := convertPaymentResult(resultOld)
// Serialise the new payment result using the new encoding.
key, resultNewBytes, err := serializeNewResult(resultNew)
if err != nil {
return err
}
// Make sure that the derived key is the same.
if !bytes.Equal(key, k) {
return fmt.Errorf("new payment result key (%v) is "+
"not the same as the old key (%v)", key, k)
}
// Finally, overwrite the previous value with the new encoding.
return resultBucket.Put(k, resultNewBytes)
})
}
package migration32
import (
"bytes"
"io"
"math"
"time"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// unknownFailureSourceIdx is the database encoding of an unknown error
// source.
unknownFailureSourceIdx = -1
)
var (
// resultsKey is the fixed key under which the attempt results are
// stored.
resultsKey = []byte("missioncontrol-results")
)
// paymentResultOld is the information that becomes available when a payment
// attempt completes.
type paymentResultOld struct {
id uint64
timeFwd, timeReply time.Time
route *Route
success bool
failureSourceIdx *int
failure lnwire.FailureMessage
}
// deserializeOldResult deserializes a payment result using the old encoding.
func deserializeOldResult(k, v []byte) (*paymentResultOld, error) {
// Parse payment id.
result := paymentResultOld{
id: byteOrder.Uint64(k[8:]),
}
r := bytes.NewReader(v)
// Read timestamps, success status and failure source index.
var (
timeFwd, timeReply uint64
dbFailureSourceIdx int32
)
err := ReadElements(
r, &timeFwd, &timeReply, &result.success, &dbFailureSourceIdx,
)
if err != nil {
return nil, err
}
// Convert time stamps to local time zone for consistent logging.
result.timeFwd = time.Unix(0, int64(timeFwd)).Local()
result.timeReply = time.Unix(0, int64(timeReply)).Local()
// Convert from unknown index magic number to nil value.
if dbFailureSourceIdx != unknownFailureSourceIdx {
failureSourceIdx := int(dbFailureSourceIdx)
result.failureSourceIdx = &failureSourceIdx
}
// Read route.
route, err := DeserializeRoute(r)
if err != nil {
return nil, err
}
result.route = &route
// Read failure.
failureBytes, err := wire.ReadVarBytes(r, 0, math.MaxUint16, "failure")
if err != nil {
return nil, err
}
if len(failureBytes) > 0 {
result.failure, err = lnwire.DecodeFailureMessage(
bytes.NewReader(failureBytes), 0,
)
if err != nil {
return nil, err
}
}
return &result, nil
}
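// For reference, the old (pre-migration) value layout decoded above is, in
// order (this summary is a reading aid, not part of the code):
//
//	8 bytes  - timeFwd (unix nanoseconds, big endian)
//	8 bytes  - timeReply (unix nanoseconds, big endian)
//	1 byte   - success flag
//	4 bytes  - failure source index (int32, -1 means unknown)
//	variable - serialized route (see DeserializeRoute)
//	varbytes - encoded failure message (empty if none)
//
// The payment id lives in bytes [8:16) of the key.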
// convertPaymentResult converts a paymentResultOld to a paymentResultNew.
func convertPaymentResult(old *paymentResultOld) *paymentResultNew {
var failure *paymentFailure
if !old.success {
failure = newPaymentFailure(old.failureSourceIdx, old.failure)
}
return newPaymentResult(
old.id, extractMCRoute(old.route), old.timeFwd, old.timeReply,
failure,
)
}
// newPaymentResult constructs a new paymentResult.
func newPaymentResult(id uint64, rt *mcRoute, timeFwd, timeReply time.Time,
failure *paymentFailure) *paymentResultNew {
result := &paymentResultNew{
id: id,
timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint64(timeFwd.UnixNano()),
),
timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1](
uint64(timeReply.UnixNano()),
),
route: tlv.NewRecordT[tlv.TlvType2](*rt),
}
if failure != nil {
result.failure = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3](*failure),
)
}
return result
}
// paymentResultNew is the information that becomes available when a payment
// attempt completes.
type paymentResultNew struct {
id uint64
timeFwd tlv.RecordT[tlv.TlvType0, uint64]
timeReply tlv.RecordT[tlv.TlvType1, uint64]
route tlv.RecordT[tlv.TlvType2, mcRoute]
// failure holds information related to the failure of a payment. The
// presence of this record indicates a payment failure. The absence of
// this record indicates a successful payment.
failure tlv.OptionalRecordT[tlv.TlvType3, paymentFailure]
}
// paymentFailure represents the presence of a payment failure. It may or may
// not include additional information about said failure.
type paymentFailure struct {
info tlv.OptionalRecordT[tlv.TlvType0, paymentFailureInfo]
}
// newPaymentFailure constructs a new paymentFailure struct. If the source
// index is nil, then an empty paymentFailure is returned. This represents a
// failure with unknown details. Otherwise, the index and failure message are
// used to populate the info field of the paymentFailure.
func newPaymentFailure(sourceIdx *int,
failureMsg lnwire.FailureMessage) *paymentFailure {
if sourceIdx == nil {
return &paymentFailure{}
}
info := paymentFailureInfo{
sourceIdx: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint8(*sourceIdx),
),
msg: tlv.NewRecordT[tlv.TlvType1](failureMessage{failureMsg}),
}
return &paymentFailure{
info: tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType0](info)),
}
}
// Record returns a TLV record that can be used to encode/decode a
// paymentFailure to/from a TLV stream.
func (r *paymentFailure) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodePaymentFailure(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodePaymentFailure, decodePaymentFailure,
)
}
func encodePaymentFailure(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*paymentFailure); ok {
var recordProducers []tlv.RecordProducer
v.info.WhenSome(
func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) {
recordProducers = append(recordProducers, &r)
},
)
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(recordProducers...),
)
}
return tlv.NewTypeForEncodingErr(val, "routing.paymentFailure")
}
func decodePaymentFailure(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*paymentFailure); ok {
var h paymentFailure
info := tlv.ZeroRecordT[tlv.TlvType0, paymentFailureInfo]()
typeMap, err := lnwire.DecodeRecords(
r, lnwire.ProduceRecordsSorted(&info)...,
)
if err != nil {
return err
}
if _, ok := typeMap[h.info.TlvType()]; ok {
h.info = tlv.SomeRecordT(info)
}
*v = h
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.paymentFailure", l, l)
}
// paymentFailureInfo holds additional information about a payment failure.
type paymentFailureInfo struct {
sourceIdx tlv.RecordT[tlv.TlvType0, uint8]
msg tlv.RecordT[tlv.TlvType1, failureMessage]
}
// Record returns a TLV record that can be used to encode/decode a
// paymentFailureInfo to/from a TLV stream.
func (r *paymentFailureInfo) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodePaymentFailureInfo(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodePaymentFailureInfo,
decodePaymentFailureInfo,
)
}
func encodePaymentFailureInfo(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*paymentFailureInfo); ok {
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(
&v.sourceIdx, &v.msg,
),
)
}
return tlv.NewTypeForEncodingErr(val, "routing.paymentFailureInfo")
}
func decodePaymentFailureInfo(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*paymentFailureInfo); ok {
var h paymentFailureInfo
_, err := lnwire.DecodeRecords(
r,
lnwire.ProduceRecordsSorted(&h.sourceIdx, &h.msg)...,
)
if err != nil {
return err
}
*v = h
return nil
}
return tlv.NewTypeForDecodingErr(
val, "routing.paymentFailureInfo", l, l,
)
}
type failureMessage struct {
lnwire.FailureMessage
}
// Record returns a TLV record that can be used to encode/decode a list of
// failureMessage to/from a TLV stream.
func (r *failureMessage) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeFailureMessage(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodeFailureMessage, decodeFailureMessage,
)
}
func encodeFailureMessage(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*failureMessage); ok {
var b bytes.Buffer
err := lnwire.EncodeFailureMessage(&b, v.FailureMessage, 0)
if err != nil {
return err
}
_, err = w.Write(b.Bytes())
return err
}
return tlv.NewTypeForEncodingErr(val, "routing.failureMessage")
}
func decodeFailureMessage(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*failureMessage); ok {
msg, err := lnwire.DecodeFailureMessage(r, 0)
if err != nil {
return err
}
*v = failureMessage{
FailureMessage: msg,
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.failureMessage", l, l)
}
// extractMCRoute extracts the fields required by MC from the Route struct to
// create the more minimal mcRoute struct.
func extractMCRoute(r *Route) *mcRoute {
return &mcRoute{
sourcePubKey: tlv.NewRecordT[tlv.TlvType0](r.SourcePubKey),
totalAmount: tlv.NewRecordT[tlv.TlvType1](r.TotalAmount),
hops: tlv.NewRecordT[tlv.TlvType2](
extractMCHops(r.Hops),
),
}
}
// extractMCHops extracts the Hop fields that MC actually uses from a slice of
// Hops.
func extractMCHops(hops []*Hop) mcHops {
return fn.Map(hops, extractMCHop)
}
// extractMCHop extracts the Hop fields that MC actually uses from a Hop.
func extractMCHop(hop *Hop) *mcHop {
h := mcHop{
channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](
hop.ChannelID,
),
pubKeyBytes: tlv.NewRecordT[tlv.TlvType1, Vertex](
hop.PubKeyBytes,
),
amtToFwd: tlv.NewRecordT[tlv.TlvType2, lnwire.MilliSatoshi](
hop.AmtToForward,
),
}
if hop.BlindingPoint != nil {
h.hasBlindingPoint = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3, lnwire.TrueBoolean](
lnwire.TrueBoolean{},
),
)
}
if len(hop.CustomRecords) != 0 {
h.hasCustomRecords = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType4, lnwire.TrueBoolean](
lnwire.TrueBoolean{},
),
)
}
return &h
}
// mcRoute holds the bare minimum info about a payment attempt route that MC
// requires.
type mcRoute struct {
sourcePubKey tlv.RecordT[tlv.TlvType0, Vertex]
totalAmount tlv.RecordT[tlv.TlvType1, lnwire.MilliSatoshi]
hops tlv.RecordT[tlv.TlvType2, mcHops]
}
// Record returns a TLV record that can be used to encode/decode an mcRoute
// to/from a TLV stream.
func (r *mcRoute) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeMCRoute(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodeMCRoute, decodeMCRoute,
)
}
func encodeMCRoute(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*mcRoute); ok {
return serializeRoute(w, v)
}
return tlv.NewTypeForEncodingErr(val, "routing.mcRoute")
}
func decodeMCRoute(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if v, ok := val.(*mcRoute); ok {
route, err := deserializeRoute(io.LimitReader(r, int64(l)))
if err != nil {
return err
}
*v = *route
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.mcRoute", l, l)
}
// mcHops is a list of mcHop records.
type mcHops []*mcHop
// Record returns a TLV record that can be used to encode/decode a list of
// mcHop to/from a TLV stream.
func (h *mcHops) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeMCHops(&b, h, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, h, recordSize, encodeMCHops, decodeMCHops,
)
}
func encodeMCHops(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*mcHops); ok {
// Encode the number of hops as a var int.
if err := tlv.WriteVarInt(w, uint64(len(*v)), buf); err != nil {
return err
}
// With that written out, we'll now encode the entries
// themselves as a sub-TLV record, which includes its _own_
// inner length prefix.
for _, hop := range *v {
var hopBytes bytes.Buffer
if err := serializeNewHop(&hopBytes, hop); err != nil {
return err
}
// We encode the record with a varint length followed by
// the _raw_ TLV bytes.
tlvLen := uint64(len(hopBytes.Bytes()))
if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
return err
}
if _, err := w.Write(hopBytes.Bytes()); err != nil {
return err
}
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "routing.mcHops")
}
func decodeMCHops(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*mcHops); ok {
// First, we'll decode the varint that encodes how many hops
// are encoded in the stream.
numHops, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Now that we know how many records we'll need to read, we can
// iterate and read them all out in series.
for i := uint64(0); i < numHops; i++ {
// Read out the varint that encodes the size of this
// inner TLV record.
hopSize, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Using this information, we'll create a new limited
// reader that'll return an EOF once the end has been
// reached so the stream stops consuming bytes.
innerTlvReader := &io.LimitedReader{
R: r,
N: int64(hopSize),
}
hop, err := deserializeNewHop(innerTlvReader)
if err != nil {
return err
}
*v = append(*v, hop)
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.mcHops", l, l)
}
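// To make the framing above concrete, a hypothetical two-hop list is laid
// out on the wire as follows (this is a reading aid, not normative):
//
//	varint: 2               - number of hops
//	varint: len(hop1 TLV)   - length prefix for the first hop
//	hop1 TLV stream         - types 0-2, plus optional types 3 and 4
//	varint: len(hop2 TLV)   - length prefix for the second hop
//	hop2 TLV stream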
// serializeRoute serializes a mcRoute and writes the resulting bytes to the
// given io.Writer.
func serializeRoute(w io.Writer, r *mcRoute) error {
records := lnwire.ProduceRecordsSorted(
&r.sourcePubKey,
&r.totalAmount,
&r.hops,
)
return lnwire.EncodeRecordsTo(w, records)
}
// deserializeRoute deserializes the mcRoute from the given io.Reader.
func deserializeRoute(r io.Reader) (*mcRoute, error) {
var rt mcRoute
records := lnwire.ProduceRecordsSorted(
&rt.sourcePubKey,
&rt.totalAmount,
&rt.hops,
)
_, err := lnwire.DecodeRecords(r, records...)
if err != nil {
return nil, err
}
return &rt, nil
}
// deserializeNewHop deserializes the mcHop from the given io.Reader.
func deserializeNewHop(r io.Reader) (*mcHop, error) {
var (
h mcHop
blinding = tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean]()
custom = tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean]()
)
records := lnwire.ProduceRecordsSorted(
&h.channelID,
&h.pubKeyBytes,
&h.amtToFwd,
&blinding,
&custom,
)
typeMap, err := lnwire.DecodeRecords(r, records...)
if err != nil {
return nil, err
}
if _, ok := typeMap[h.hasBlindingPoint.TlvType()]; ok {
h.hasBlindingPoint = tlv.SomeRecordT(blinding)
}
if _, ok := typeMap[h.hasCustomRecords.TlvType()]; ok {
h.hasCustomRecords = tlv.SomeRecordT(custom)
}
return &h, nil
}
// serializeNewHop serializes a mcHop and writes the resulting bytes to the
// given io.Writer.
func serializeNewHop(w io.Writer, h *mcHop) error {
recordProducers := []tlv.RecordProducer{
&h.channelID,
&h.pubKeyBytes,
&h.amtToFwd,
}
h.hasBlindingPoint.WhenSome(func(
hasBlinding tlv.RecordT[tlv.TlvType3, lnwire.TrueBoolean]) {
recordProducers = append(recordProducers, &hasBlinding)
})
h.hasCustomRecords.WhenSome(func(
hasCustom tlv.RecordT[tlv.TlvType4, lnwire.TrueBoolean]) {
recordProducers = append(recordProducers, &hasCustom)
})
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(recordProducers...),
)
}
// mcHop holds the bare minimum info about a payment attempt route hop that MC
// requires.
type mcHop struct {
channelID tlv.RecordT[tlv.TlvType0, uint64]
pubKeyBytes tlv.RecordT[tlv.TlvType1, Vertex]
amtToFwd tlv.RecordT[tlv.TlvType2, lnwire.MilliSatoshi]
hasBlindingPoint tlv.OptionalRecordT[tlv.TlvType3, lnwire.TrueBoolean]
hasCustomRecords tlv.OptionalRecordT[tlv.TlvType4, lnwire.TrueBoolean]
}
// serializeOldResult serializes a payment result and returns a key and value
// byte slice to insert into the bucket.
func serializeOldResult(rp *paymentResultOld) ([]byte, []byte, error) {
// Write timestamps, success status, failure source index and route.
var b bytes.Buffer
var dbFailureSourceIdx int32
if rp.failureSourceIdx == nil {
dbFailureSourceIdx = unknownFailureSourceIdx
} else {
dbFailureSourceIdx = int32(*rp.failureSourceIdx)
}
err := WriteElements(
&b,
uint64(rp.timeFwd.UnixNano()),
uint64(rp.timeReply.UnixNano()),
rp.success, dbFailureSourceIdx,
)
if err != nil {
return nil, nil, err
}
if err := SerializeRoute(&b, *rp.route); err != nil {
return nil, nil, err
}
// Write failure. If there is no failure message, write an empty
// byte slice.
var failureBytes bytes.Buffer
if rp.failure != nil {
err := lnwire.EncodeFailureMessage(&failureBytes, rp.failure, 0)
if err != nil {
return nil, nil, err
}
}
err = wire.WriteVarBytes(&b, 0, failureBytes.Bytes())
if err != nil {
return nil, nil, err
}
// Compose key that identifies this result.
key := getResultKeyOld(rp)
return key, b.Bytes(), nil
}
// getResultKeyOld returns a byte slice representing a unique key for this
// payment result.
func getResultKeyOld(rp *paymentResultOld) []byte {
var keyBytes [8 + 8 + 33]byte
// Identify records by a combination of time, payment id and sender pub
// key. This allows importing mission control data from an external
// source without key collisions and keeps the records sorted
// chronologically.
byteOrder.PutUint64(keyBytes[:], uint64(rp.timeReply.UnixNano()))
byteOrder.PutUint64(keyBytes[8:], rp.id)
copy(keyBytes[16:], rp.route.SourcePubKey[:])
return keyBytes[:]
}
// serializeNewResult serializes a payment result and returns a key and value
// byte slice to insert into the bucket.
func serializeNewResult(rp *paymentResultNew) ([]byte, []byte, error) {
recordProducers := []tlv.RecordProducer{
&rp.timeFwd,
&rp.timeReply,
&rp.route,
}
rp.failure.WhenSome(
func(failure tlv.RecordT[tlv.TlvType3, paymentFailure]) {
recordProducers = append(recordProducers, &failure)
},
)
// Compose key that identifies this result.
key := getResultKeyNew(rp)
var buff bytes.Buffer
err := lnwire.EncodeRecordsTo(
&buff, lnwire.ProduceRecordsSorted(recordProducers...),
)
if err != nil {
return nil, nil, err
}
return key, buff.Bytes(), nil
}
// getResultKeyNew returns a byte slice representing a unique key for this
// payment result.
func getResultKeyNew(rp *paymentResultNew) []byte {
var keyBytes [8 + 8 + 33]byte
// Identify records by a combination of time, payment id and sender pub
// key. This allows importing mission control data from an external
// source without key collisions and keeps the records sorted
// chronologically.
byteOrder.PutUint64(keyBytes[:], rp.timeReply.Val)
byteOrder.PutUint64(keyBytes[8:], rp.id)
copy(keyBytes[16:], rp.route.Val.sourcePubKey.Val[:])
return keyBytes[:]
}
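// As an illustration (not part of the code), the 49-byte key produced above
// is laid out as:
//
//	bytes [0:8)   - timeReply (unix nanoseconds, big endian)
//	bytes [8:16)  - payment id (big endian)
//	bytes [16:49) - sender's 33-byte compressed public key
//
// With the timestamp in the most significant position, a cursor scan over
// the bucket visits results in chronological order, which is why big endian
// is used throughout.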
package migration32
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// MPPOnionType is the type used in the onion to reference the MPP
// fields: total_amt and payment_addr.
MPPOnionType tlv.Type = 8
// AMPOnionType is the type used in the onion to reference the AMP
// fields: root_share, set_id, and child_index.
AMPOnionType tlv.Type = 14
)
// VertexSize is the size of the array to store a vertex.
const VertexSize = 33
// Vertex is a simple alias for the serialization of a compressed Bitcoin
// public key.
type Vertex [VertexSize]byte
// Record returns a TLV record that can be used to encode/decode a Vertex
// to/from a TLV stream.
func (v *Vertex) Record() tlv.Record {
return tlv.MakeStaticRecord(
0, v, VertexSize, encodeVertex, decodeVertex,
)
}
func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error {
if b, ok := val.(*Vertex); ok {
_, err := w.Write(b[:])
return err
}
return tlv.NewTypeForEncodingErr(val, "Vertex")
}
func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*Vertex); ok {
_, err := io.ReadFull(r, b[:])
return err
}
return tlv.NewTypeForDecodingErr(val, "Vertex", l, VertexSize)
}
// Route represents a path through the channel graph which runs over one or
// more channels in succession. This struct carries all the information
// required to craft the Sphinx onion packet, and send the payment along the
// first hop in the path. A route is only selected as valid if all the channels
// have sufficient capacity to carry the initial payment amount after fees are
// accounted for.
type Route struct {
// TotalTimeLock is the cumulative (final) time lock across the entire
// route. This is the CLTV value that should be extended to the first
// hop in the route. All other hops will decrement the time-lock as
// advertised, leaving enough time for all hops to wait for or present
// the payment preimage to complete the payment.
TotalTimeLock uint32
// TotalAmount is the total amount of funds required to complete a
// payment over this route. This value includes the cumulative fees at
// each hop. As a result, the HTLC extended to the first-hop in the
// route will need to have at least this many satoshis, otherwise the
// route will fail at an intermediate node due to an insufficient
// amount of fees.
TotalAmount lnwire.MilliSatoshi
// SourcePubKey is the pubkey of the node where this route originates
// from.
SourcePubKey Vertex
// Hops contains details concerning the specific forwarding details at
// each hop.
Hops []*Hop
// FirstHopAmount is the amount that should actually be sent to the
// first hop in the route. This is only different from TotalAmount above
// for custom channels where the on-chain amount doesn't necessarily
// reflect all the value of an outgoing payment.
FirstHopAmount tlv.RecordT[
tlv.TlvType0, tlv.BigSizeT[lnwire.MilliSatoshi],
]
// FirstHopWireCustomRecords is a set of custom records that should be
// included in the wire message sent to the first hop. This is only set
// on custom channels and is used to include additional information
// about the actual value of the payment.
//
// NOTE: Since these records already represent TLV records, and we
// enforce them to be in the custom range (e.g. >= 65536), we don't use
// another parent record type here. Instead, when serializing the Route
// we merge the TLV records together with the custom records and encode
// everything as a single TLV stream.
FirstHopWireCustomRecords lnwire.CustomRecords
}
// Hop represents an intermediate or final node of the route. This naming
// is in line with the definition given in BOLT #4: Onion Routing Protocol.
// The struct houses the channel along which this hop can be reached and
// the values necessary to create the HTLC that needs to be sent to the
// next hop. It is also used to encode the per-hop payload included within
// the Sphinx packet.
type Hop struct {
// PubKeyBytes is the raw bytes of the public key of the target node.
PubKeyBytes Vertex
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel.
ChannelID uint64
// OutgoingTimeLock is the timelock value that should be used when
// crafting the _outgoing_ HTLC from this hop.
OutgoingTimeLock uint32
// AmtToForward is the amount that this hop will forward to the next
// hop. This value is less than the value that the incoming HTLC
// carries as a fee will be subtracted by the hop.
AmtToForward lnwire.MilliSatoshi
// MPP encapsulates the data required for option_mpp. This field should
// only be set for the final hop.
MPP *MPP
// AMP encapsulates the data required for option_amp. This field should
// only be set for the final hop.
AMP *AMP
// CustomRecords if non-nil are a set of additional TLV records that
// should be included in the forwarding instructions for this node.
CustomRecords lnwire.CustomRecords
// LegacyPayload if true, then this signals that this node doesn't
// understand the new TLV payload, so we must instead use the legacy
// payload.
//
// NOTE: we should no longer ever create a Hop with Legacy set to true.
// The only reason we are keeping this member is that it could be the
// case that we have serialised hops persisted to disk where
// LegacyPayload is true.
LegacyPayload bool
// Metadata is additional data that is sent along with the payment to
// the payee.
Metadata []byte
// EncryptedData is an encrypted data blob included for hops that are
// part of a blinded route.
EncryptedData []byte
// BlindingPoint is an ephemeral public key used by introduction nodes
// in blinded routes to unblind their portion of the route and pass on
// the next ephemeral key to the next blinded node to do the same.
BlindingPoint *btcec.PublicKey
// TotalAmtMsat is the total amount for a blinded payment, potentially
// spread over more than one HTLC. This field should only be set for
// the final hop in a blinded path.
TotalAmtMsat lnwire.MilliSatoshi
}
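// exampleUnpackChannelID is an illustrative helper (added for this
// write-up, not part of the original migration) that unpacks the block
// height, transaction index, and output index packed into Hop.ChannelID as
// described on the field above.
func exampleUnpackChannelID(h *Hop) (uint32, uint32, uint16) {
	// Top 3 bytes: block height.
	blockHeight := uint32(h.ChannelID >> 40)

	// Middle 3 bytes: transaction index within the block.
	txIndex := uint32(h.ChannelID>>16) & 0xFFFFFF

	// Low 2 bytes: output index of the funding output.
	outputIndex := uint16(h.ChannelID)

	return blockHeight, txIndex, outputIndex
}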
// MPP is a record that encodes the fields necessary for multi-path payments.
type MPP struct {
// paymentAddr is a random, receiver-generated value used to avoid
// collisions with concurrent payers.
paymentAddr [32]byte
// totalMsat is the total value of the payment, potentially spread
// across more than one HTLC.
totalMsat lnwire.MilliSatoshi
}
// Record returns a tlv.Record that can be used to encode or decode this record.
func (r *MPP) Record() tlv.Record {
// Fixed-size, 32 byte payment address followed by truncated 64-bit
// total msat.
size := func() uint64 {
return 32 + tlv.SizeTUint64(uint64(r.totalMsat))
}
return tlv.MakeDynamicRecord(
MPPOnionType, r, size, MPPEncoder, MPPDecoder,
)
}
const (
// minMPPLength is the minimum length of a serialized MPP TLV record,
// which occurs when the truncated encoding of total_amt_msat takes 0
// bytes, leaving only the payment_addr.
minMPPLength = 32
// maxMPPLength is the maximum length of a serialized MPP TLV record,
// which occurs when the truncated encoding of total_amt_msat takes 8
// bytes.
maxMPPLength = 40
)
// MPPEncoder writes the MPP record to the provided io.Writer.
func MPPEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*MPP); ok {
err := tlv.EBytes32(w, &v.paymentAddr, buf)
if err != nil {
return err
}
return tlv.ETUint64T(w, uint64(v.totalMsat), buf)
}
return tlv.NewTypeForEncodingErr(val, "MPP")
}
// MPPDecoder reads the MPP record from the provided io.Reader.
func MPPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*MPP); ok && minMPPLength <= l && l <= maxMPPLength {
if err := tlv.DBytes32(r, &v.paymentAddr, buf, 32); err != nil {
return err
}
var total uint64
if err := tlv.DTUint64(r, &total, buf, l-32); err != nil {
return err
}
v.totalMsat = lnwire.MilliSatoshi(total)
return nil
}
return tlv.NewTypeForDecodingErr(val, "MPP", l, maxMPPLength)
}
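// Worked example (illustrative): the total_msat field uses the truncated
// uint64 encoding, so the record is the 32-byte payment_addr plus only as
// many bytes as the amount needs. A total of 50,000 msat is 0xC350 and
// encodes in 2 bytes, giving a 34-byte record; a zero total encodes in 0
// bytes, giving the 32-byte minimum.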
// AMP is a record that encodes the fields necessary for atomic multi-path
// payments.
type AMP struct {
rootShare [32]byte
setID [32]byte
childIndex uint32
}
// AMPEncoder writes the AMP record to the provided io.Writer.
func AMPEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*AMP); ok {
if err := tlv.EBytes32(w, &v.rootShare, buf); err != nil {
return err
}
if err := tlv.EBytes32(w, &v.setID, buf); err != nil {
return err
}
return tlv.ETUint32T(w, v.childIndex, buf)
}
return tlv.NewTypeForEncodingErr(val, "AMP")
}
const (
// minAMPLength is the minimum length of a serialized AMP TLV record,
// which occurs when the truncated encoding of child_index takes 0
// bytes, leaving only the root_share and set_id.
minAMPLength = 64
// maxAMPLength is the maximum length of a serialized AMP TLV record,
// which occurs when the truncated encoding of a child_index takes 2
// bytes.
maxAMPLength = 68
)
// AMPDecoder reads the AMP record from the provided io.Reader.
func AMPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*AMP); ok && minAMPLength <= l && l <= maxAMPLength {
if err := tlv.DBytes32(r, &v.rootShare, buf, 32); err != nil {
return err
}
if err := tlv.DBytes32(r, &v.setID, buf, 32); err != nil {
return err
}
return tlv.DTUint32(r, &v.childIndex, buf, l-minAMPLength)
}
return tlv.NewTypeForDecodingErr(val, "AMP", l, maxAMPLength)
}
// Record returns a tlv.Record that can be used to encode or decode this record.
func (a *AMP) Record() tlv.Record {
return tlv.MakeDynamicRecord(
AMPOnionType, a, a.PayloadSize, AMPEncoder, AMPDecoder,
)
}
// PayloadSize returns the size this record takes up in encoded form.
func (a *AMP) PayloadSize() uint64 {
return 32 + 32 + tlv.SizeTUint32(a.childIndex)
}
// SerializeRoute serializes a route.
func SerializeRoute(w io.Writer, r Route) error {
if err := WriteElements(w,
r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:],
); err != nil {
return err
}
if err := WriteElements(w, uint32(len(r.Hops))); err != nil {
return err
}
for _, h := range r.Hops {
if err := serializeHop(w, h); err != nil {
return err
}
}
// Any new/extra TLV data is encoded in serializeHTLCAttemptInfo!
return nil
}
func serializeHop(w io.Writer, h *Hop) error {
if err := WriteElements(w,
h.PubKeyBytes[:],
h.ChannelID,
h.OutgoingTimeLock,
h.AmtToForward,
); err != nil {
return err
}
if err := binary.Write(w, byteOrder, h.LegacyPayload); err != nil {
return err
}
// For legacy payloads, we don't need to write any TLV records, so
// we'll write a zero indicating that our serialized TLV map has no
// records.
if h.LegacyPayload {
return WriteElements(w, uint32(0))
}
// Gather all non-primitive TLV records so that they can be serialized
// as a single blob.
//
// TODO(conner): add migration to unify all fields in a single TLV
// blobs. The split approach will cause headaches down the road as more
// fields are added, which we can avoid by having a single TLV stream
// for all payload fields.
var records []tlv.Record
if h.MPP != nil {
records = append(records, h.MPP.Record())
}
// Add blinding point and encrypted data if present.
if h.EncryptedData != nil {
records = append(records, NewEncryptedDataRecord(
&h.EncryptedData,
))
}
if h.BlindingPoint != nil {
records = append(records, NewBlindingPointRecord(
&h.BlindingPoint,
))
}
if h.AMP != nil {
records = append(records, h.AMP.Record())
}
if h.Metadata != nil {
records = append(records, NewMetadataRecord(&h.Metadata))
}
if h.TotalAmtMsat != 0 {
totalMsatInt := uint64(h.TotalAmtMsat)
records = append(
records, NewTotalAmtMsatBlinded(&totalMsatInt),
)
}
// Final sanity check to absolutely rule out custom records that are not
// actually custom and would write into the standard range.
if err := h.CustomRecords.Validate(); err != nil {
return err
}
// Convert custom records to tlv and add to the record list.
// MapToRecords sorts the list, so adding it here will keep the list
// canonical.
tlvRecords := tlv.MapToRecords(h.CustomRecords)
records = append(records, tlvRecords...)
// Otherwise, we'll transform our slice of records into a map of the
// raw bytes, then serialize them in-line with a length (number of
// elements) prefix.
mapRecords, err := tlv.RecordsToMap(records)
if err != nil {
return err
}
numRecords := uint32(len(mapRecords))
if err := WriteElements(w, numRecords); err != nil {
return err
}
for recordType, rawBytes := range mapRecords {
if err := WriteElements(w, recordType); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, rawBytes); err != nil {
return err
}
}
return nil
}
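// For reference, the on-disk layout produced by serializeHop is (reading
// aid only, not normative):
//
//	varbytes - 33-byte pubkey (varint length prefix)
//	8 bytes  - channel ID
//	4 bytes  - outgoing timelock
//	8 bytes  - amount to forward (msat)
//	1 byte   - legacy payload flag
//	4 bytes  - number of TLV records (0 for legacy payloads)
//	per record:
//		8 bytes  - record type
//		varbytes - raw record value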
// DeserializeRoute deserializes a route.
func DeserializeRoute(r io.Reader) (Route, error) {
rt := Route{}
if err := ReadElements(r,
&rt.TotalTimeLock, &rt.TotalAmount,
); err != nil {
return rt, err
}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return rt, err
}
copy(rt.SourcePubKey[:], pub)
var numHops uint32
if err := ReadElements(r, &numHops); err != nil {
return rt, err
}
var hops []*Hop
for i := uint32(0); i < numHops; i++ {
hop, err := deserializeHop(r)
if err != nil {
return rt, err
}
hops = append(hops, hop)
}
rt.Hops = hops
// Any new/extra TLV data is decoded in deserializeHTLCAttemptInfo!
return rt, nil
}
// maxOnionPayloadSize is the largest Sphinx payload possible, so we don't need
// to read/write a TLV stream larger than this.
const maxOnionPayloadSize = 1300
func deserializeHop(r io.Reader) (*Hop, error) {
h := &Hop{}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return nil, err
}
copy(h.PubKeyBytes[:], pub)
if err := ReadElements(r,
&h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward,
); err != nil {
return nil, err
}
// TODO(roasbeef): change field to allow LegacyPayload false to be the
// legacy default?
err := binary.Read(r, byteOrder, &h.LegacyPayload)
if err != nil {
return nil, err
}
var numElements uint32
if err := ReadElements(r, &numElements); err != nil {
return nil, err
}
// If there are no elements, then we can return early.
if numElements == 0 {
return h, nil
}
tlvMap := make(map[uint64][]byte)
for i := uint32(0); i < numElements; i++ {
var tlvType uint64
if err := ReadElements(r, &tlvType); err != nil {
return nil, err
}
rawRecordBytes, err := wire.ReadVarBytes(
r, 0, maxOnionPayloadSize, "tlv",
)
if err != nil {
return nil, err
}
tlvMap[tlvType] = rawRecordBytes
}
// If the MPP type is present, remove it from the generic TLV map and
// parse it back into a proper MPP struct.
//
// TODO(conner): add migration to unify all fields in a single TLV
// blobs. The split approach will cause headaches down the road as more
// fields are added, which we can avoid by having a single TLV stream
// for all payload fields.
mppType := uint64(MPPOnionType)
if mppBytes, ok := tlvMap[mppType]; ok {
delete(tlvMap, mppType)
var (
mpp = &MPP{}
mppRec = mpp.Record()
r = bytes.NewReader(mppBytes)
)
err := mppRec.Decode(r, uint64(len(mppBytes)))
if err != nil {
return nil, err
}
h.MPP = mpp
}
// If encrypted data or blinding key are present, remove them from
// the TLV map and parse into proper types.
encryptedDataType := uint64(EncryptedDataOnionType)
if data, ok := tlvMap[encryptedDataType]; ok {
delete(tlvMap, encryptedDataType)
h.EncryptedData = data
}
blindingType := uint64(BlindingPointOnionType)
if blindingPoint, ok := tlvMap[blindingType]; ok {
delete(tlvMap, blindingType)
h.BlindingPoint, err = btcec.ParsePubKey(blindingPoint)
if err != nil {
return nil, fmt.Errorf("invalid blinding point: %w",
err)
}
}
ampType := uint64(AMPOnionType)
if ampBytes, ok := tlvMap[ampType]; ok {
delete(tlvMap, ampType)
var (
amp = &AMP{}
ampRec = amp.Record()
r = bytes.NewReader(ampBytes)
)
err := ampRec.Decode(r, uint64(len(ampBytes)))
if err != nil {
return nil, err
}
h.AMP = amp
}
// If the metadata type is present, remove it from the tlv map and
// populate directly on the hop.
metadataType := uint64(MetadataOnionType)
if metadata, ok := tlvMap[metadataType]; ok {
delete(tlvMap, metadataType)
h.Metadata = metadata
}
totalAmtMsatType := uint64(TotalAmtMsatBlindedType)
if totalAmtMsat, ok := tlvMap[totalAmtMsatType]; ok {
delete(tlvMap, totalAmtMsatType)
var (
totalAmtMsatInt uint64
buf [8]byte
)
if err := tlv.DTUint64(
bytes.NewReader(totalAmtMsat),
&totalAmtMsatInt,
&buf,
uint64(len(totalAmtMsat)),
); err != nil {
return nil, err
}
h.TotalAmtMsat = lnwire.MilliSatoshi(totalAmtMsatInt)
}
h.CustomRecords = tlvMap
return h, nil
}
package migration33
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration33
import (
"bytes"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// resultsKey is the fixed key under which the attempt results are
// stored.
resultsKey = []byte("missioncontrol-results")
// defaultMCNamespaceKey is the key of the default mission control store
// namespace.
defaultMCNamespaceKey = []byte("default")
)
// MigrateMCStoreNameSpacedResults reads in all the current mission control
// entries and re-writes them under a new default namespace.
func MigrateMCStoreNameSpacedResults(tx kvdb.RwTx) error {
log.Infof("Migrating Mission Control store to use namespaced results")
// Get the top level bucket. All the MC results are currently stored
// as KV pairs in this bucket.
topLevelBucket := tx.ReadWriteBucket(resultsKey)
// If the results bucket does not exist then there are no entries in
// the mission control store yet and so there is nothing to migrate.
if topLevelBucket == nil {
return nil
}
// Create a new default namespace bucket under the top-level bucket.
defaultNSBkt, err := topLevelBucket.CreateBucket(defaultMCNamespaceKey)
if err != nil {
return err
}
// Iterate through each of the existing result pairs, write them to the
// new namespaced bucket. Also collect the set of keys so that we can
// later delete them from the top level bucket.
var keys [][]byte
err = topLevelBucket.ForEach(func(k, v []byte) error {
// Skip the new default namespace key.
if bytes.Equal(k, defaultMCNamespaceKey) {
return nil
}
// Collect the key.
keys = append(keys, k)
// Write the pair under the default namespace.
return defaultNSBkt.Put(k, v)
})
if err != nil {
return err
}
// Finally, iterate through the set of keys and delete them from the
// top level bucket.
for _, k := range keys {
if err := topLevelBucket.Delete(k); err != nil {
return err
}
}
return err
}
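// Illustration of the layout change performed above (not part of the
// code):
//
//	Before: missioncontrol-results -> <key>: <result>
//	After:  missioncontrol-results -> default -> <key>: <result>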
package migration_01_to_11
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"github.com/lightningnetwork/lnd/tor"
)
// addressType specifies the network protocol and version that should be used
// when connecting to a node at a particular address.
type addressType uint8
const (
// tcp4Addr denotes an IPv4 TCP address.
tcp4Addr addressType = 0
// tcp6Addr denotes an IPv6 TCP address.
tcp6Addr addressType = 1
// v2OnionAddr denotes a version 2 Tor onion service address.
v2OnionAddr addressType = 2
// v3OnionAddr denotes a version 3 Tor (prop224) onion service address.
v3OnionAddr addressType = 3
)
// encodeTCPAddr serializes a TCP address into its compact raw bytes
// representation.
func encodeTCPAddr(w io.Writer, addr *net.TCPAddr) error {
var (
addrType byte
ip []byte
)
if addr.IP.To4() != nil {
addrType = byte(tcp4Addr)
ip = addr.IP.To4()
} else {
addrType = byte(tcp6Addr)
ip = addr.IP.To16()
}
if ip == nil {
return fmt.Errorf("unable to encode IP %v", addr.IP)
}
if _, err := w.Write([]byte{addrType}); err != nil {
return err
}
if _, err := w.Write(ip); err != nil {
return err
}
var port [2]byte
byteOrder.PutUint16(port[:], uint16(addr.Port))
if _, err := w.Write(port[:]); err != nil {
return err
}
return nil
}
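// For illustration, encodeTCPAddr produces one of two fixed-size layouts:
//
//	IPv4: 1-byte type (0) || 4-byte IP  || 2-byte big-endian port = 7 bytes
//	IPv6: 1-byte type (1) || 16-byte IP || 2-byte big-endian port = 19 bytes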
// encodeOnionAddr serializes an onion address into its compact raw bytes
// representation.
func encodeOnionAddr(w io.Writer, addr *tor.OnionAddr) error {
var suffixIndex int
hostLen := len(addr.OnionService)
switch hostLen {
case tor.V2Len:
if _, err := w.Write([]byte{byte(v2OnionAddr)}); err != nil {
return err
}
suffixIndex = tor.V2Len - tor.OnionSuffixLen
case tor.V3Len:
if _, err := w.Write([]byte{byte(v3OnionAddr)}); err != nil {
return err
}
suffixIndex = tor.V3Len - tor.OnionSuffixLen
default:
return errors.New("unknown onion service length")
}
suffix := addr.OnionService[suffixIndex:]
if suffix != tor.OnionSuffix {
return fmt.Errorf("invalid suffix \"%v\"", suffix)
}
host, err := tor.Base32Encoding.DecodeString(
addr.OnionService[:suffixIndex],
)
if err != nil {
return err
}
// Sanity check the decoded length.
switch {
case hostLen == tor.V2Len && len(host) != tor.V2DecodedLen:
return fmt.Errorf("onion service %v decoded to invalid host %x",
addr.OnionService, host)
case hostLen == tor.V3Len && len(host) != tor.V3DecodedLen:
return fmt.Errorf("onion service %v decoded to invalid host %x",
addr.OnionService, host)
}
if _, err := w.Write(host); err != nil {
return err
}
var port [2]byte
byteOrder.PutUint16(port[:], uint16(addr.Port))
if _, err := w.Write(port[:]); err != nil {
return err
}
return nil
}
// deserializeAddr reads the serialized raw representation of an address and
// deserializes it into the actual address. This allows us to avoid address
// resolution within the channeldb package.
func deserializeAddr(r io.Reader) (net.Addr, error) {
var addrType [1]byte
if _, err := r.Read(addrType[:]); err != nil {
return nil, err
}
var address net.Addr
switch addressType(addrType[0]) {
case tcp4Addr:
var ip [4]byte
if _, err := r.Read(ip[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
address = &net.TCPAddr{
IP: net.IP(ip[:]),
Port: int(binary.BigEndian.Uint16(port[:])),
}
case tcp6Addr:
var ip [16]byte
if _, err := r.Read(ip[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
address = &net.TCPAddr{
IP: net.IP(ip[:]),
Port: int(binary.BigEndian.Uint16(port[:])),
}
case v2OnionAddr:
var h [tor.V2DecodedLen]byte
if _, err := r.Read(h[:]); err != nil {
return nil, err
}
var p [2]byte
if _, err := r.Read(p[:]); err != nil {
return nil, err
}
onionService := tor.Base32Encoding.EncodeToString(h[:])
onionService += tor.OnionSuffix
port := int(binary.BigEndian.Uint16(p[:]))
address = &tor.OnionAddr{
OnionService: onionService,
Port: port,
}
case v3OnionAddr:
var h [tor.V3DecodedLen]byte
if _, err := r.Read(h[:]); err != nil {
return nil, err
}
var p [2]byte
if _, err := r.Read(p[:]); err != nil {
return nil, err
}
onionService := tor.Base32Encoding.EncodeToString(h[:])
onionService += tor.OnionSuffix
port := int(binary.BigEndian.Uint16(p[:]))
address = &tor.OnionAddr{
OnionService: onionService,
Port: port,
}
default:
return nil, ErrUnknownAddressType
}
return address, nil
}
// serializeAddr serializes an address into its raw bytes representation so that
// it can be deserialized without requiring address resolution.
func serializeAddr(w io.Writer, address net.Addr) error {
switch addr := address.(type) {
case *net.TCPAddr:
return encodeTCPAddr(w, addr)
case *tor.OnionAddr:
return encodeOnionAddr(w, addr)
default:
return ErrUnknownAddressType
}
}
package migration_01_to_11
import (
"errors"
"fmt"
"io"
"strconv"
"strings"
"sync"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/shachain"
)
var (
// closedChannelBucket stores summarization information concerning
// previously open, but now closed channels.
closedChannelBucket = []byte("closed-chan-bucket")
// openChanBucket stores all the currently open channels. This bucket
// has a second, nested bucket which is keyed by a node's ID. Within
// that node ID bucket, all attributes required to track, update, and
// close a channel are stored.
//
// openChan -> nodeID -> chanPoint
//
// TODO(roasbeef): flesh out comment
openChannelBucket = []byte("open-chan-bucket")
)
// ChannelType is an enum-like type that describes one of several possible
// channel types. Each open channel is associated with a particular type as the
// channel type may determine how higher level operations are conducted such as
// fee negotiation, channel closing, the format of HTLCs, etc.
// TODO(roasbeef): split up per-chain?
type ChannelType uint8
const (
// NOTE: iota isn't used here as this enum needs to be stable
// long-term as it will be persisted to the database.
// SingleFunder represents a channel wherein one party solely funds the
// entire capacity of the channel.
SingleFunder ChannelType = 0
)
// ChannelConstraints represents a set of constraints meant to allow a node to
// limit their exposure, enact flow control and ensure that all HTLCs are
// economically relevant. This struct will be mirrored for both sides of the
// channel, as each side will enforce various constraints that MUST be adhered
// to for the lifetime of the channel. The parameters for each of these
// constraints are static for the duration of the channel, meaning the channel
// must be torn down for them to change.
type ChannelConstraints struct {
// DustLimit is the threshold (in satoshis) below which any outputs
// should be trimmed. When an output is trimmed, it isn't materialized
// as an actual output, but is instead burned to miner's fees.
DustLimit btcutil.Amount
// ChanReserve is an absolute reservation on the channel for the
// owner of this set of constraints. This means that the current
// settled balance for this node CANNOT dip below the reservation
// amount. This acts as a defense against costless attacks when
// either side no longer has any skin in the game.
ChanReserve btcutil.Amount
// MaxPendingAmount is the maximum pending HTLC value that the
// owner of these constraints can offer the remote node at a
// particular time.
MaxPendingAmount lnwire.MilliSatoshi
// MinHTLC is the minimum HTLC value that the owner of these
// constraints can offer the remote node. If any HTLCs below this
// amount are offered, then the HTLC will be rejected. This, in
// tandem with the dust limit allows a node to regulate the
// smallest HTLC that it deems economically relevant.
MinHTLC lnwire.MilliSatoshi
// MaxAcceptedHtlcs is the maximum number of HTLCs that the owner of
// this set of constraints can offer the remote node. This allows each
// node to limit their overall exposure to HTLCs that may need to be
// acted upon in the case of a unilateral channel closure or a contract
// breach.
MaxAcceptedHtlcs uint16
// CsvDelay is the relative time lock delay expressed in blocks. Any
// settled outputs that pay to the owner of this channel configuration
// MUST ensure that the delay branch uses this value as the relative
// time lock. Similarly, any HTLC's offered by this node should use
// this value as well.
CsvDelay uint16
}
// ChannelConfig is a struct that houses the various configuration options
// for channels. Each side maintains an instance of this configuration as it
// governs: how the funding and commitment transactions are to be created, the
// nature of HTLC's allotted, the keys to be used for delivery, and relative
// time lock parameters.
type ChannelConfig struct {
// ChannelConstraints is the set of constraints that must be upheld for
// the duration of the channel for the owner of this channel
// configuration. Constraints govern a number of flow control related
// parameters, also including the smallest HTLC that will be accepted
// by a participant.
ChannelConstraints
// MultiSigKey is the key to be used within the 2-of-2 output script
// for the owner of this channel config.
MultiSigKey keychain.KeyDescriptor
// RevocationBasePoint is the base public key to be used when deriving
// revocation keys for the remote node's commitment transaction. This
// will be combined along with a per commitment secret to derive a
// unique revocation key for each state.
RevocationBasePoint keychain.KeyDescriptor
// PaymentBasePoint is the base public key to be used when deriving
// the key used within the non-delayed pay-to-self output on the
// commitment transaction for a node. This will be combined with a
// tweak derived from the per-commitment point to ensure unique keys
// for each commitment transaction.
PaymentBasePoint keychain.KeyDescriptor
// DelayBasePoint is the base public key to be used when deriving the
// key used within the delayed pay-to-self output on the commitment
// transaction for a node. This will be combined with a tweak derived
// from the per-commitment point to ensure unique keys for each
// commitment transaction.
DelayBasePoint keychain.KeyDescriptor
// HtlcBasePoint is the base public key to be used when deriving the
// local HTLC key. The derived key (combined with the tweak derived
// from the per-commitment point) is used within the "to self" clause
// within any HTLC output scripts.
HtlcBasePoint keychain.KeyDescriptor
}
// ChannelCommitment is a snapshot of the commitment state at a particular
// point in the commitment chain. With each state transition, a snapshot of the
// current state along with all non-settled HTLCs are recorded. These snapshots
// detail the state of the _remote_ party's commitment at a particular state
// number. For ourselves (the local node) we ONLY store our most recent
// (unrevoked) state for safety purposes.
type ChannelCommitment struct {
// CommitHeight is the update number that this ChannelDelta represents
// the total number of commitment updates to this point. This can be
// viewed as sort of a "commitment height" as this number is
// monotonically increasing.
CommitHeight uint64
// LocalLogIndex is the cumulative log index of the local node at
// this point in the commitment chain. This value will be incremented
// for each _update_ added to the local update log.
LocalLogIndex uint64
// LocalHtlcIndex is the current local running HTLC index. This value
// will be incremented for each outgoing HTLC the local node offers.
LocalHtlcIndex uint64
// RemoteLogIndex is the cumulative log index of the remote node
// at this point in the commitment chain. This value will be
// incremented for each _update_ added to the remote update log.
RemoteLogIndex uint64
// RemoteHtlcIndex is the current remote running HTLC index. This value
// will be incremented for each outgoing HTLC the remote node offers.
RemoteHtlcIndex uint64
// LocalBalance is the current available settled balance within the
// channel directly spendable by us.
LocalBalance lnwire.MilliSatoshi
// RemoteBalance is the current available settled balance within the
// channel directly spendable by the remote node.
RemoteBalance lnwire.MilliSatoshi
// CommitFee is the amount calculated to be paid in fees for the
// current set of commitment transactions. The fee amount is persisted
// with the channel in order to allow the fee amount to be removed and
// recalculated with each channel state update, including updates that
// happen after a system restart.
CommitFee btcutil.Amount
// FeePerKw is the min satoshis/kilo-weight that should be paid within
// the commitment transaction for the entire duration of the channel's
// lifetime. This field may be updated during normal operation of the
// channel as on-chain conditions change.
//
// TODO(halseth): make this SatPerKWeight. Cannot be done atm because
// this will cause the import cycle lnwallet<->channeldb. Fee
// estimation stuff should be in its own package.
FeePerKw btcutil.Amount
// CommitTx is the latest version of the commitment state, broadcastable
// by us.
CommitTx *wire.MsgTx
// CommitSig is one half of the signature required to fully complete
// the script for the commitment transaction above. This is the
// signature signed by the remote party for our version of the
// commitment transactions.
CommitSig []byte
// Htlcs is the set of HTLC's that are pending at this particular
// commitment height.
Htlcs []HTLC
// TODO(roasbeef): pending commit pointer?
// * lets just walk through
}
// ChannelStatus is a bit vector used to indicate whether an OpenChannel is in
// the default usable state, or a state where it shouldn't be used.
type ChannelStatus uint8
var (
// ChanStatusDefault is the normal state of an open channel.
ChanStatusDefault ChannelStatus
// ChanStatusBorked indicates that the channel has entered an
// irreconcilable state, triggered by a state desynchronization or
// channel breach. Channels in this state should never be added to the
// htlc switch.
ChanStatusBorked ChannelStatus = 1
// ChanStatusCommitBroadcasted indicates that a commitment for this
// channel has been broadcasted.
ChanStatusCommitBroadcasted ChannelStatus = 1 << 1
// ChanStatusLocalDataLoss indicates that we have lost channel state
// for this channel, and broadcasting our latest commitment might be
// considered a breach.
//
// TODO(halseth): actually enforce that we are not force closing such a
// channel.
ChanStatusLocalDataLoss ChannelStatus = 1 << 2
// ChanStatusRestored is a status flag that signals that the channel
// has been restored, and doesn't have all the fields a typical channel
// will have.
ChanStatusRestored ChannelStatus = 1 << 3
)
// chanStatusStrings maps a ChannelStatus to a human friendly string that
// describes that status.
var chanStatusStrings = map[ChannelStatus]string{
ChanStatusDefault: "ChanStatusDefault",
ChanStatusBorked: "ChanStatusBorked",
ChanStatusCommitBroadcasted: "ChanStatusCommitBroadcasted",
ChanStatusLocalDataLoss: "ChanStatusLocalDataLoss",
ChanStatusRestored: "ChanStatusRestored",
}
// orderedChanStatusFlags is an in-order list of all the channel status flags.
var orderedChanStatusFlags = []ChannelStatus{
ChanStatusDefault,
ChanStatusBorked,
ChanStatusCommitBroadcasted,
ChanStatusLocalDataLoss,
ChanStatusRestored,
}
// String returns a human-readable representation of the ChannelStatus.
func (c ChannelStatus) String() string {
// If no flags are set, then this is the default case.
if c == 0 {
return chanStatusStrings[ChanStatusDefault]
}
// Add individual bit flags.
statusStr := ""
for _, flag := range orderedChanStatusFlags {
if c&flag == flag {
statusStr += chanStatusStrings[flag] + "|"
c -= flag
}
}
// Remove anything to the right of the final bar, including it as well.
statusStr = strings.TrimRight(statusStr, "|")
// Add any remaining flags which aren't accounted for as hex.
if c != 0 {
statusStr += "|0x" + strconv.FormatUint(uint64(c), 16)
}
// If this was purely an unknown flag, then remove the extra bar at the
// start of the string.
statusStr = strings.TrimLeft(statusStr, "|")
return statusStr
}
// OpenChannel encapsulates the persistent and dynamic state of an open channel
// with a remote node. An open channel supports several options for on-disk
// serialization depending on the exact context. Full (upon channel creation)
// state commitments, and partial (due to a commitment update) writes are
// supported. Each partial write due to a state update appends the new update
// to an on-disk log, which can then subsequently be queried in order to
// "time-travel" to a prior state.
type OpenChannel struct {
// ChanType denotes which type of channel this is.
ChanType ChannelType
// ChainHash is a hash which represents the blockchain that this
// channel will be opened within. This value is typically the genesis
// hash. In the case that the original chain went through a contentious
// hard-fork, then this value will be tweaked using the unique fork
// point on each branch.
ChainHash chainhash.Hash
// FundingOutpoint is the outpoint of the final funding transaction.
// This value uniquely and globally identifies the channel within the
// target blockchain as specified by the chain hash parameter.
FundingOutpoint wire.OutPoint
// ShortChannelID encodes the exact location in the chain in which the
// channel was initially confirmed. This includes: the block height,
// transaction index, and the output within the target transaction.
ShortChannelID lnwire.ShortChannelID
// IsPending indicates whether a channel's funding transaction has been
// confirmed.
IsPending bool
// IsInitiator is a bool which indicates if we were the original
// initiator for the channel. This value may affect how higher levels
// negotiate fees, or close the channel.
IsInitiator bool
// FundingBroadcastHeight is the height in which the funding
// transaction was broadcast. This value can be used by higher level
// sub-systems to determine if a channel is stale and/or should have
// been confirmed before a certain height.
FundingBroadcastHeight uint32
// NumConfsRequired is the number of confirmations a channel's funding
// transaction must have received in order to be considered available
// for normal transactional use.
NumConfsRequired uint16
// ChannelFlags holds the flags that were sent as part of the
// open_channel message.
ChannelFlags lnwire.FundingFlag
// IdentityPub is the identity public key of the remote node this
// channel has been established with.
IdentityPub *btcec.PublicKey
// Capacity is the total capacity of this channel.
Capacity btcutil.Amount
// TotalMSatSent is the total number of milli-satoshis we've sent
// within this channel.
TotalMSatSent lnwire.MilliSatoshi
// TotalMSatReceived is the total number of milli-satoshis we've
// received within this channel.
TotalMSatReceived lnwire.MilliSatoshi
// LocalChanCfg is the channel configuration for the local node.
LocalChanCfg ChannelConfig
// RemoteChanCfg is the channel configuration for the remote node.
RemoteChanCfg ChannelConfig
// LocalCommitment is the current local commitment state for the local
// party. This is stored distinct from the state of the remote party
// as there are certain asymmetric parameters which affect the
// structure of each commitment.
LocalCommitment ChannelCommitment
// RemoteCommitment is the current remote commitment state for the
// remote party. This is stored distinct from the state of the local
// party as there are certain asymmetric parameters which affect the
// structure of each commitment.
RemoteCommitment ChannelCommitment
// RemoteCurrentRevocation is the current revocation for their
// commitment transaction. However, since this is the derived public key,
// we don't yet have the private key so we aren't yet able to verify
// that it's actually in the hash chain.
RemoteCurrentRevocation *btcec.PublicKey
// RemoteNextRevocation is the revocation key to be used for the *next*
// commitment transaction we create for the local node. Within the
// specification, this value is referred to as the
// per-commitment-point.
RemoteNextRevocation *btcec.PublicKey
// RevocationProducer is used to generate the revocation in such a way
// that remote side might store it efficiently and have the ability to
// restore the revocation by index if needed. Current implementation of
// secret producer is shachain producer.
RevocationProducer shachain.Producer
// RevocationStore is used to efficiently store the revocations for
// previous channels states sent to us by remote side. Current
// implementation of secret store is shachain store.
RevocationStore shachain.Store
// FundingTxn is the transaction containing this channel's funding
// outpoint. Upon restarts, this txn will be rebroadcast if the channel
// is found to be pending.
//
// NOTE: This value will only be populated for single-funder channels
// for which we are the initiator.
FundingTxn *wire.MsgTx
// TODO(roasbeef): eww
Db *DB
// TODO(roasbeef): just need to store local and remote HTLC's?
sync.RWMutex
}
// ShortChanID returns the current ShortChannelID of this channel.
func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID {
c.RLock()
defer c.RUnlock()
return c.ShortChannelID
}
// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are
// contained within ChannelDeltas which encode the current state of the
// commitment between state updates.
//
// TODO(roasbeef): save space by using smaller ints at tail end?
type HTLC struct {
// Signature is the signature for the second level covenant transaction
// for this HTLC. The second level transaction is a timeout tx in the
// case that this is an outgoing HTLC, and a success tx in the case
// that this is an incoming HTLC.
//
// TODO(roasbeef): make [64]byte instead?
Signature []byte
// RHash is the payment hash of the HTLC.
RHash [32]byte
// Amt is the amount of milli-satoshis this HTLC escrows.
Amt lnwire.MilliSatoshi
// RefundTimeout is the absolute timeout on the HTLC that the sender
// must wait before reclaiming the funds in limbo.
RefundTimeout uint32
// OutputIndex is the output index for this particular HTLC output
// within the commitment transaction.
OutputIndex int32
// Incoming denotes whether we're the receiver or the sender of this
// HTLC.
Incoming bool
// OnionBlob is an opaque blob which is used to complete multi-hop
// routing.
OnionBlob []byte
// HtlcIndex is the HTLC counter index of this active, outstanding
// HTLC. This differs from the LogIndex, as the HtlcIndex is only
// incremented for each offered HTLC, while the LogIndex is
// incremented for each update (includes settle+fail).
HtlcIndex uint64
// LogIndex is the cumulative log index of this HTLC. This differs
// from the HtlcIndex as this will be incremented for each new log
// update added.
LogIndex uint64
}
// CircuitKey is used by a channel to uniquely identify the HTLCs it receives
// from the switch, and is used to purge our in-memory state of HTLCs that have
// already been processed by a link. Two lists of CircuitKeys are included in
// each CommitDiff to allow a link to determine which in-memory htlcs directed
// the opening and closing of circuits in the switch's circuit map.
type CircuitKey struct {
// ChanID is the short chanid indicating the HTLC's origin.
//
// NOTE: It is fine for this value to be blank, as this indicates a
// locally-sourced payment.
ChanID lnwire.ShortChannelID
// HtlcID is the unique htlc index predominantly assigned by links,
// though it can also be assigned by the switch in the case of
// locally-sourced payments.
HtlcID uint64
}
// String returns a string representation of the CircuitKey.
func (k CircuitKey) String() string {
return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.ChanID, k.HtlcID)
}
// ClosureType is an enum-like type that details exactly _how_ a channel
// was closed. Five closure types are currently possible: none, cooperative,
// local force close, remote force close, and (remote) breach.
type ClosureType uint8
const (
// RemoteForceClose indicates that the remote peer has unilaterally
// broadcast their current commitment state on-chain.
RemoteForceClose ClosureType = 4
)
// ChannelCloseSummary contains the final state of a channel at the point it
// was closed. Once a channel is closed, all the information pertaining to that
// channel within the openChannelBucket is deleted, and a compact summary is
// put in place instead.
type ChannelCloseSummary struct {
// ChanPoint is the outpoint for this channel's funding transaction,
// and is used as a unique identifier for the channel.
ChanPoint wire.OutPoint
// ShortChanID encodes the exact location in the chain in which the
// channel was initially confirmed. This includes: the block height,
// transaction index, and the output within the target transaction.
ShortChanID lnwire.ShortChannelID
// ChainHash is the hash of the genesis block that this channel resides
// within.
ChainHash chainhash.Hash
// ClosingTXID is the txid of the transaction which ultimately closed
// this channel.
ClosingTXID chainhash.Hash
// RemotePub is the public key of the remote peer that we formerly had
// a channel with.
RemotePub *btcec.PublicKey
// Capacity was the total capacity of the channel.
Capacity btcutil.Amount
// CloseHeight is the height at which the funding transaction was
// spent.
CloseHeight uint32
// SettledBalance is our total settled balance at the time of
// channel closure. This _does not_ include the sum of any outputs that
// have been time-locked as a result of the unilateral channel closure.
SettledBalance btcutil.Amount
// TimeLockedBalance is the sum of all the time-locked outputs at the
// time of channel closure. If we triggered the force closure of this
// channel, then this value will be non-zero if our settled output is
// above the dust limit. If we were on the receiving side of a channel
// force closure, then this value will be non-zero if we had any
// outstanding outgoing HTLC's at the time of channel closure.
TimeLockedBalance btcutil.Amount
// CloseType details exactly _how_ the channel was closed. Five closure
// types are possible: cooperative, local force, remote force, breach
// and funding canceled.
CloseType ClosureType
// IsPending indicates whether this channel is in the 'pending close'
// state, which means the channel closing transaction has been
// confirmed, but not yet been fully resolved. In the case of a channel
// that has been cooperatively closed, it will go straight into the
// fully resolved state as soon as the closing transaction has been
// confirmed. However, for channels that have been force closed, they'll
// stay marked as "pending" until _all_ the pending funds have been
// swept.
IsPending bool
// RemoteCurrentRevocation is the current revocation for their
// commitment transaction. However, since this is the derived public key,
// we don't yet have the private key so we aren't yet able to verify
// that it's actually in the hash chain.
RemoteCurrentRevocation *btcec.PublicKey
// RemoteNextRevocation is the revocation key to be used for the *next*
// commitment transaction we create for the local node. Within the
// specification, this value is referred to as the
// per-commitment-point.
RemoteNextRevocation *btcec.PublicKey
// LocalChanConfig is the channel configuration for the local node.
LocalChanConfig ChannelConfig
// LastChanSyncMsg is the ChannelReestablish message for this channel
// for the state at the point where it was closed.
LastChanSyncMsg *lnwire.ChannelReestablish
}
func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
err := WriteElements(w,
cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID,
cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance,
cs.TimeLockedBalance, cs.CloseType, cs.IsPending,
)
if err != nil {
return err
}
// If this is a close channel summary created before the addition of
// the new fields, then we can exit here.
if cs.RemoteCurrentRevocation == nil {
return WriteElements(w, false)
}
// If fields are present, write boolean to indicate this, and continue.
if err := WriteElements(w, true); err != nil {
return err
}
if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil {
return err
}
if err := WriteChanConfig(w, &cs.LocalChanConfig); err != nil {
return err
}
// The RemoteNextRevocation field is optional, as it's possible for a
// channel to be closed before we learn of the next unrevoked
// revocation point for the remote party. Write a boolean indicating
// whether this field is present or not.
if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil {
return err
}
// Write the field, if present.
if cs.RemoteNextRevocation != nil {
if err = WriteElements(w, cs.RemoteNextRevocation); err != nil {
return err
}
}
// Write whether the channel sync message is present.
if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil {
return err
}
// Write the channel sync message, if present.
if cs.LastChanSyncMsg != nil {
if err := WriteElements(w, cs.LastChanSyncMsg); err != nil {
return err
}
}
return nil
}
func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
c := &ChannelCloseSummary{}
err := ReadElements(r,
&c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID,
&c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance,
&c.TimeLockedBalance, &c.CloseType, &c.IsPending,
)
if err != nil {
return nil, err
}
// We'll now check to see if the channel close summary was encoded with
// any of the additional optional fields.
var hasNewFields bool
err = ReadElements(r, &hasNewFields)
if err != nil {
return nil, err
}
// If fields are not present, we can return.
if !hasNewFields {
return c, nil
}
// Otherwise read the new fields.
if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil {
return nil, err
}
if err := ReadChanConfig(r, &c.LocalChanConfig); err != nil {
return nil, err
}
// Finally, we'll attempt to read the next unrevoked commitment point
// for the remote party. If we closed the channel before receiving a
// funding locked message then this might not be present. A boolean
// indicating whether the field is present will come first.
var hasRemoteNextRevocation bool
err = ReadElements(r, &hasRemoteNextRevocation)
if err != nil {
return nil, err
}
// If this field was written, read it.
if hasRemoteNextRevocation {
err = ReadElements(r, &c.RemoteNextRevocation)
if err != nil {
return nil, err
}
}
// Check if we have a channel sync message to read.
var hasChanSyncMsg bool
err = ReadElements(r, &hasChanSyncMsg)
if err == io.EOF {
return c, nil
} else if err != nil {
return nil, err
}
// If a chan sync message is present, read it.
if hasChanSyncMsg {
// We must pass in reference to a lnwire.Message for the codec
// to support it.
var msg lnwire.Message
if err := ReadElements(r, &msg); err != nil {
return nil, err
}
chanSync, ok := msg.(*lnwire.ChannelReestablish)
if !ok {
return nil, errors.New("unable cast db Message to " +
"ChannelReestablish")
}
c.LastChanSyncMsg = chanSync
}
return c, nil
}
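// exampleCloseSummaryRoundTrip is an illustrative sketch (not part of the
// original source) of the encoding's symmetry: optional fields are guarded
// by boolean presence flags, so a summary written by
// serializeChannelCloseSummary decodes to an equivalent value via
// deserializeCloseChannelSummary.
func exampleCloseSummaryRoundTrip(rw io.ReadWriter,
	cs *ChannelCloseSummary) (*ChannelCloseSummary, error) {

	if err := serializeChannelCloseSummary(rw, cs); err != nil {
		return nil, err
	}
	return deserializeCloseChannelSummary(rw)
}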
func WriteChanConfig(b io.Writer, c *ChannelConfig) error {
return WriteElements(b,
c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC,
c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey,
c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint,
c.HtlcBasePoint,
)
}
func ReadChanConfig(b io.Reader, c *ChannelConfig) error {
return ReadElements(b,
&c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve,
&c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay,
&c.MultiSigKey, &c.RevocationBasePoint,
&c.PaymentBasePoint, &c.DelayBasePoint,
&c.HtlcBasePoint,
)
}
func DeserializeChanCommit(r io.Reader) (ChannelCommitment, error) {
var c ChannelCommitment
err := ReadElements(r,
&c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex,
&c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance,
&c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig,
)
if err != nil {
return c, err
}
c.Htlcs, err = DeserializeHtlcs(r)
if err != nil {
return c, err
}
return c, nil
}
// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed
// io.Reader. The bytes within the passed reader MUST have been previously
// written to using the SerializeHtlcs function.
//
// NOTE: This API is NOT stable, the on-disk format will likely change in the
// future.
func DeserializeHtlcs(r io.Reader) ([]HTLC, error) {
var numHtlcs uint16
if err := ReadElement(r, &numHtlcs); err != nil {
return nil, err
}
var htlcs []HTLC
if numHtlcs == 0 {
return htlcs, nil
}
htlcs = make([]HTLC, numHtlcs)
for i := uint16(0); i < numHtlcs; i++ {
if err := ReadElements(r,
&htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt,
&htlcs[i].RefundTimeout, &htlcs[i].OutputIndex,
&htlcs[i].Incoming, &htlcs[i].OnionBlob,
&htlcs[i].HtlcIndex, &htlcs[i].LogIndex,
); err != nil {
return htlcs, err
}
}
return htlcs, nil
}
func SerializeChanCommit(w io.Writer, c *ChannelCommitment) error {
if err := WriteElements(w,
c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex,
c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance,
c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx,
c.CommitSig,
); err != nil {
return err
}
return SerializeHtlcs(w, c.Htlcs...)
}
// SerializeHtlcs writes out the passed set of HTLC's into the passed writer
// using the current default on-disk serialization format.
//
// NOTE: This API is NOT stable, the on-disk format will likely change in the
// future.
func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error {
numHtlcs := uint16(len(htlcs))
if err := WriteElement(b, numHtlcs); err != nil {
return err
}
for _, htlc := range htlcs {
if err := WriteElements(b,
htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout,
htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:],
htlc.HtlcIndex, htlc.LogIndex,
); err != nil {
return err
}
}
return nil
}
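// exampleHtlcRoundTrip is an illustrative sketch (not part of the original
// source) showing the intended pairing of the two functions above: HTLCs
// written with SerializeHtlcs (a uint16 count followed by each HTLC's
// fields) are read back by DeserializeHtlcs from the same stream.
func exampleHtlcRoundTrip(rw io.ReadWriter, htlcs []HTLC) ([]HTLC, error) {
	if err := SerializeHtlcs(rw, htlcs...); err != nil {
		return nil, err
	}
	return DeserializeHtlcs(rw)
}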
package migration_01_to_11
import (
"encoding/binary"
"fmt"
"io"
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/shachain"
)
// WriteOutpoint writes an outpoint to the passed writer using the minimal
// amount of bytes possible.
func WriteOutpoint(w io.Writer, o *wire.OutPoint) error {
if _, err := w.Write(o.Hash[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, o.Index); err != nil {
return err
}
return nil
}
// ReadOutpoint reads an outpoint from the passed reader that was previously
// written using the WriteOutpoint function.
func ReadOutpoint(r io.Reader, o *wire.OutPoint) error {
if _, err := io.ReadFull(r, o.Hash[:]); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &o.Index); err != nil {
return err
}
return nil
}
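// exampleOutpointRoundTrip is an illustrative sketch (not part of the
// original source) of the fixed 36-byte layout used above: the 32-byte
// txid followed by the big-endian 4-byte output index.
func exampleOutpointRoundTrip(rw io.ReadWriter,
	op *wire.OutPoint) (wire.OutPoint, error) {

	if err := WriteOutpoint(rw, op); err != nil {
		return wire.OutPoint{}, err
	}
	var decoded wire.OutPoint
	err := ReadOutpoint(rw, &decoded)
	return decoded, err
}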
// UnknownElementType is an error returned when the codec is unable to encode or
// decode a particular type.
type UnknownElementType struct {
method string
element interface{}
}
// Error returns the name of the method that encountered the error, as well as
// the type that was unsupported.
func (e UnknownElementType) Error() string {
return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element)
}
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for storage on disk. The passed
// io.Writer should be backed by an appropriately sized byte slice, or be able
// to dynamically expand to accommodate additional data.
func WriteElement(w io.Writer, element interface{}) error {
switch e := element.(type) {
case keychain.KeyDescriptor:
if err := binary.Write(w, byteOrder, e.Family); err != nil {
return err
}
if err := binary.Write(w, byteOrder, e.Index); err != nil {
return err
}
if e.PubKey != nil {
if err := binary.Write(w, byteOrder, true); err != nil {
return fmt.Errorf("error writing serialized element: %w", err)
}
return WriteElement(w, e.PubKey)
}
return binary.Write(w, byteOrder, false)
case ChannelType:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case chainhash.Hash:
if _, err := w.Write(e[:]); err != nil {
return err
}
case wire.OutPoint:
return WriteOutpoint(w, &e)
case lnwire.ShortChannelID:
if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil {
return err
}
case lnwire.ChannelID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case int64, uint64:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case int32:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint16:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case uint8:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case bool:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case btcutil.Amount:
if err := binary.Write(w, byteOrder, uint64(e)); err != nil {
return err
}
case lnwire.MilliSatoshi:
if err := binary.Write(w, byteOrder, uint64(e)); err != nil {
return err
}
case *btcec.PrivateKey:
b := e.Serialize()
if _, err := w.Write(b); err != nil {
return err
}
case *btcec.PublicKey:
b := e.SerializeCompressed()
if _, err := w.Write(b); err != nil {
return err
}
case shachain.Producer:
return e.Encode(w)
case shachain.Store:
return e.Encode(w)
case *wire.MsgTx:
return e.Serialize(w)
case [32]byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case []byte:
if err := wire.WriteVarBytes(w, 0, e); err != nil {
return err
}
case lnwire.Message:
if _, err := lnwire.WriteMessage(w, e, 0); err != nil {
return err
}
case ChannelStatus:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case ClosureType:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case lnwire.FundingFlag:
if err := binary.Write(w, byteOrder, e); err != nil {
return err
}
case net.Addr:
if err := serializeAddr(w, e); err != nil {
return err
}
case []net.Addr:
if err := WriteElement(w, uint32(len(e))); err != nil {
return err
}
for _, addr := range e {
if err := serializeAddr(w, addr); err != nil {
return err
}
}
default:
return UnknownElementType{"WriteElement", e}
}
return nil
}
// WriteElements writes each element in the elements slice to the passed
// io.Writer using WriteElement.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
err := WriteElement(w, element)
if err != nil {
return err
}
}
return nil
}
// ReadElement is a one-stop utility function to deserialize any data
// structure
// encoded using the serialization format of the database.
func ReadElement(r io.Reader, element interface{}) error {
switch e := element.(type) {
case *keychain.KeyDescriptor:
if err := binary.Read(r, byteOrder, &e.Family); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &e.Index); err != nil {
return err
}
var hasPubKey bool
if err := binary.Read(r, byteOrder, &hasPubKey); err != nil {
return err
}
if hasPubKey {
return ReadElement(r, &e.PubKey)
}
case *ChannelType:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *chainhash.Hash:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *wire.OutPoint:
return ReadOutpoint(r, e)
case *lnwire.ShortChannelID:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = lnwire.NewShortChanIDFromInt(a)
case *lnwire.ChannelID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *int64, *uint64:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *int32:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint16:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *uint8:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *bool:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *btcutil.Amount:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = btcutil.Amount(a)
case *lnwire.MilliSatoshi:
var a uint64
if err := binary.Read(r, byteOrder, &a); err != nil {
return err
}
*e = lnwire.MilliSatoshi(a)
case **btcec.PrivateKey:
var b [btcec.PrivKeyBytesLen]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
priv, _ := btcec.PrivKeyFromBytes(b[:])
*e = priv
case **btcec.PublicKey:
var b [btcec.PubKeyBytesLenCompressed]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
pubKey, err := btcec.ParsePubKey(b[:])
if err != nil {
return err
}
*e = pubKey
case *shachain.Producer:
var root [32]byte
if _, err := io.ReadFull(r, root[:]); err != nil {
return err
}
// TODO(roasbeef): remove
producer, err := shachain.NewRevocationProducerFromBytes(root[:])
if err != nil {
return err
}
*e = producer
case *shachain.Store:
store, err := shachain.NewRevocationStoreFromBytes(r)
if err != nil {
return err
}
*e = store
case **wire.MsgTx:
tx := wire.NewMsgTx(2)
if err := tx.Deserialize(r); err != nil {
return err
}
*e = tx
case *[32]byte:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *[]byte:
bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte")
if err != nil {
return err
}
*e = bytes
case *lnwire.Message:
msg, err := lnwire.ReadMessage(r, 0)
if err != nil {
return err
}
*e = msg
case *ChannelStatus:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *ClosureType:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *lnwire.FundingFlag:
if err := binary.Read(r, byteOrder, e); err != nil {
return err
}
case *net.Addr:
addr, err := deserializeAddr(r)
if err != nil {
return err
}
*e = addr
case *[]net.Addr:
var numAddrs uint32
if err := ReadElement(r, &numAddrs); err != nil {
return err
}
*e = make([]net.Addr, numAddrs)
for i := uint32(0); i < numAddrs; i++ {
addr, err := deserializeAddr(r)
if err != nil {
return err
}
(*e)[i] = addr
}
default:
return UnknownElementType{"ReadElement", e}
}
return nil
}
// ReadElements deserializes a variable number of elements from the passed
// io.Reader, with each element being deserialized according to the ReadElement
// function.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
err := ReadElement(r, element)
if err != nil {
return err
}
}
return nil
}
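// exampleElementsRoundTrip is an illustrative sketch (not part of the
// original source) of the codec's symmetry: values written with
// WriteElements are read back by passing pointers of the same types, in
// the same order, to ReadElements.
func exampleElementsRoundTrip(rw io.ReadWriter) (uint32, lnwire.MilliSatoshi,
	error) {

	err := WriteElements(rw, uint32(500_000), lnwire.MilliSatoshi(1_000))
	if err != nil {
		return 0, 0, err
	}
	var (
		height uint32
		amt    lnwire.MilliSatoshi
	)
	err = ReadElements(rw, &height, &amt)
	return height, amt, err
}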
package migration_01_to_11
import (
"bytes"
"encoding/binary"
"fmt"
"os"
"path/filepath"
"time"
"github.com/lightningnetwork/lnd/kvdb"
)
const (
dbName = "channel.db"
dbFilePermission = 0600
)
// migration is a function which takes a prior outdated version of the database
// instances and mutates the key/bucket structure to arrive at a more
// up-to-date version of the database.
type migration func(tx kvdb.RwTx) error
var (
// Big endian is the preferred byte order, due to cursor scans over
// integer keys iterating in order.
byteOrder = binary.BigEndian
)
// DB is the primary datastore for the lnd daemon. The database stores
// information related to nodes, routing data, open/closed channels, fee
// schedules, and reputation data.
type DB struct {
kvdb.Backend
dbPath string
graph *ChannelGraph
now func() time.Time
}
// Open opens an existing channeldb. Any necessary schemas migrations due to
// updates will take place as necessary.
func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) {
path := filepath.Join(dbPath, dbName)
if !fileExists(path) {
if err := createChannelDB(dbPath); err != nil {
return nil, err
}
}
opts := DefaultOptions()
for _, modifier := range modifiers {
modifier(&opts)
}
// Specify bbolt freelist options to reduce heap pressure in case the
// freelist grows to be very large.
bdb, err := kvdb.Open(
kvdb.BoltBackendName, path,
opts.NoFreelistSync, opts.DBTimeout, false,
)
if err != nil {
return nil, err
}
chanDB := &DB{
Backend: bdb,
dbPath: dbPath,
now: time.Now,
}
chanDB.graph = newChannelGraph(
chanDB, opts.RejectCacheSize, opts.ChannelCacheSize,
)
return chanDB, nil
}
// createChannelDB creates and initializes a fresh version of channeldb. In
// the case that the target path has not yet been created or doesn't yet exist,
// then the path is created. Additionally, all required top-level buckets used
// within the database are created.
func createChannelDB(dbPath string) error {
if !fileExists(dbPath) {
if err := os.MkdirAll(dbPath, 0700); err != nil {
return err
}
}
path := filepath.Join(dbPath, dbName)
bdb, err := kvdb.Create(
kvdb.BoltBackendName, path, false, kvdb.DefaultDBTimeout,
false,
)
if err != nil {
return err
}
err = kvdb.Update(bdb, func(tx kvdb.RwTx) error {
if _, err := tx.CreateTopLevelBucket(openChannelBucket); err != nil {
return err
}
if _, err := tx.CreateTopLevelBucket(closedChannelBucket); err != nil {
return err
}
if _, err := tx.CreateTopLevelBucket(invoiceBucket); err != nil {
return err
}
if _, err := tx.CreateTopLevelBucket(paymentBucket); err != nil {
return err
}
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return err
}
_, err = nodes.CreateBucket(aliasIndexBucket)
if err != nil {
return err
}
_, err = nodes.CreateBucket(nodeUpdateIndexBucket)
if err != nil {
return err
}
edges, err := tx.CreateTopLevelBucket(edgeBucket)
if err != nil {
return err
}
if _, err := edges.CreateBucket(edgeIndexBucket); err != nil {
return err
}
if _, err := edges.CreateBucket(edgeUpdateIndexBucket); err != nil {
return err
}
if _, err := edges.CreateBucket(channelPointBucket); err != nil {
return err
}
if _, err := edges.CreateBucket(zombieBucket); err != nil {
return err
}
graphMeta, err := tx.CreateTopLevelBucket(graphMetaBucket)
if err != nil {
return err
}
_, err = graphMeta.CreateBucket(pruneLogBucket)
if err != nil {
return err
}
if _, err := tx.CreateTopLevelBucket(metaBucket); err != nil {
return err
}
meta := &Meta{
DbVersionNumber: 0,
}
return putMeta(meta, tx)
}, func() {})
if err != nil {
return fmt.Errorf("unable to create new channeldb")
}
return bdb.Close()
}
// fileExists returns true if the file exists, and false otherwise.
func fileExists(path string) bool {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
// FetchClosedChannels attempts to fetch all closed channels from the database.
// The pendingOnly bool toggles if channels that aren't yet fully closed should
// be returned in the response or not. When a channel was cooperatively closed,
// it becomes fully closed after a single confirmation. When a channel was
// forcibly closed, it will become fully closed after _all_ the pending funds
// (if any) have been swept.
func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, error) {
var chanSummaries []*ChannelCloseSummary
if err := kvdb.View(d, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil {
return ErrNoClosedChannels
}
return closeBucket.ForEach(func(chanID []byte, summaryBytes []byte) error {
summaryReader := bytes.NewReader(summaryBytes)
chanSummary, err := deserializeCloseChannelSummary(summaryReader)
if err != nil {
return err
}
// If the query specified to only include pending
// channels, then we'll skip any channels which aren't
// currently pending.
if !chanSummary.IsPending && pendingOnly {
return nil
}
chanSummaries = append(chanSummaries, chanSummary)
return nil
})
}, func() {
chanSummaries = nil
}); err != nil {
return nil, err
}
return chanSummaries, nil
}
// ChannelGraph returns a new instance of the directed channel graph.
func (d *DB) ChannelGraph() *ChannelGraph {
return d.graph
}
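// exampleOpenAndQuery is an illustrative sketch (not part of the original
// source) of the intended usage: Open the database rooted at a directory,
// query it, and close the backend when done.
func exampleOpenAndQuery(dbPath string) ([]*ChannelCloseSummary, error) {
	db, err := Open(dbPath)
	if err != nil {
		return nil, err
	}
	defer db.Close()
	// Only fetch channels that are still pending full resolution.
	return db.FetchClosedChannels(true)
}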
package migration_01_to_11
import (
"fmt"
)
var (
// ErrNoInvoicesCreated is returned when we don't have invoices in
// our database to return.
ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices")
// ErrNoPaymentsCreated is returned when the bucket of payments hasn't been
// created.
ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments")
// ErrGraphNotFound is returned when at least one of the components of
// the graph doesn't exist.
ErrGraphNotFound = fmt.Errorf("graph bucket not initialized")
// ErrSourceNodeNotSet is returned if the source node of the graph
// hasn't been added. The source node is the center node within a
// star-graph.
ErrSourceNodeNotSet = fmt.Errorf("source node does not exist")
// ErrGraphNodeNotFound is returned when we're unable to find the target
// node.
ErrGraphNodeNotFound = fmt.Errorf("unable to find node")
// ErrEdgeNotFound is returned when an edge for the target chanID
// can't be found.
ErrEdgeNotFound = fmt.Errorf("edge not found")
// ErrUnknownAddressType is returned when a node's addressType is not
// an expected value.
ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved")
// ErrNoClosedChannels is returned when a node is queried for all the
// channels it has closed, but it hasn't yet closed any channels.
ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet")
// ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel
// policy field is not found in the db even though its message flags
// indicate it should be.
ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " +
"present")
)
// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the
// caller attempts to write an announcement message which bears too many extra
// opaque bytes. We limit this value in order to ensure that we don't waste
// disk space due to nodes unnecessarily padding out their announcements with
// garbage data.
func ErrTooManyExtraOpaqueBytes(numBytes int) error {
return fmt.Errorf("max allowed number of opaque bytes is %v, received "+
"%v bytes", MaxAllowedExtraOpaqueBytes, numBytes)
}
package migration_01_to_11
import (
"bytes"
"encoding/binary"
"fmt"
"image/color"
"io"
"net"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// nodeBucket is a bucket which houses all the vertices or nodes within
// the channel graph. This bucket has a single sub-bucket which adds an
// additional index from pubkey -> alias. Within the top-level of this
// bucket, the key space maps a node's compressed public key to the
// serialized information for that node. Additionally, there's a
// special key "source" which stores the pubkey of the source node. The
// source node is used as the starting point for all graph queries and
// traversals. The graph is formed as a star-graph with the source node
// at the center.
//
// maps: pubKey -> nodeInfo
// maps: source -> selfPubKey
nodeBucket = []byte("graph-node")
// nodeUpdateIndexBucket is a sub-bucket of the nodeBucket. This bucket
// will be used to quickly look up the "freshness" of a node's last
// update to the network. The bucket only contains keys, and no values,
// its mapping:
//
// maps: updateTime || nodeID -> nil
nodeUpdateIndexBucket = []byte("graph-node-update-index")
// sourceKey is a special key that resides within the nodeBucket. The
// sourceKey maps a key to the public key of the "self node".
sourceKey = []byte("source")
// aliasIndexBucket is a sub-bucket that's nested within the main
// nodeBucket. This bucket maps the public key of a node to its
// current alias. This bucket is provided as it can be used within a
// future UI layer to add an additional degree of confirmation.
aliasIndexBucket = []byte("alias")
// edgeBucket is a bucket which houses all of the edge or channel
// information within the channel graph. This bucket essentially acts
// as an adjacency list, which in conjunction with a range scan, can be
// used to iterate over all the incoming and outgoing edges for a
// particular node. Keys in the bucket use a prefix scheme which leads
// with the node's public key and ends with the compact edge ID; see the
// illustrative key sketch after this var block.
// For each chanID, there will be two entries within the bucket, as the
// graph is directed: nodes may have different policies w.r.t to fees
// for their respective directions.
//
// maps: pubKey || chanID -> channel edge policy for node
edgeBucket = []byte("graph-edge")
// unknownPolicy is represented as an empty slice. It is
// used as the value in edgeBucket for unknown channel edge policies.
// Unknown policies are still stored in the database to enable efficient
// lookup of incoming channel edges.
unknownPolicy = []byte{}
// edgeIndexBucket is an index which can be used to iterate all edges
// in the bucket, grouping them according to their in/out nodes.
// Additionally, the items in this bucket also contain the complete
// edge information for a channel. The edge information includes the
// capacity of the channel, the nodes that made the channel, etc. This
// bucket resides within the edgeBucket above. Creation of an edge
// proceeds in two phases: first the edge is added to the edge index,
// afterwards the edgeBucket can be updated with the latest details of
// the edge as they are announced on the network.
//
// maps: chanID -> pubKey1 || pubKey2 || restofEdgeInfo
edgeIndexBucket = []byte("edge-index")
// edgeUpdateIndexBucket is a sub-bucket of the main edgeBucket. This
// bucket contains an index which allows us to gauge the "freshness" of
// a channel's last updates.
//
// maps: updateTime || chanID -> nil
edgeUpdateIndexBucket = []byte("edge-update-index")
// channelPointBucket maps a channel's full outpoint (txid:index) to
// its short 8-byte channel ID. This bucket resides within the
// edgeBucket above, and can be used to quickly remove an edge due to
// the outpoint being spent, or to query for existence of a channel.
//
// maps: outPoint -> chanID
channelPointBucket = []byte("chan-index")
// zombieBucket is a sub-bucket of the main edgeBucket bucket
// responsible for maintaining an index of zombie channels. Each entry
// exists within the bucket as follows:
//
// maps: chanID -> pubKey1 || pubKey2
//
// The chanID represents the channel ID of the edge that is marked as a
// zombie and is used as the key, which maps to the public keys of the
// edge's participants.
zombieBucket = []byte("zombie-index")
// disabledEdgePolicyBucket is a sub-bucket of the main edgeBucket bucket
// responsible for maintaining an index of disabled edge policies. Each
// entry exists within the bucket as follows:
//
// maps: <chanID><direction> -> []byte{}
//
// The chanID represents the channel ID of the edge and the direction is
// one byte representing the direction of the edge. The main purpose of
// this index is to allow pruning disabled channels in a fast way without
// the need to iterate all over the graph.
disabledEdgePolicyBucket = []byte("disabled-edge-policy-index")
// graphMetaBucket is a top-level bucket which stores various metadata
// related to the on-disk channel graph. Data stored in this bucket
// includes the block to which the graph has been synced, the total
// number of channels, etc.
graphMetaBucket = []byte("graph-meta")
// pruneLogBucket is a bucket within the graphMetaBucket that stores
// a mapping from the block height to the hash for the blocks used to
// prune the graph.
// Once a new block is discovered, any channels that have been closed
// (by spending the outpoint) can safely be removed from the graph, and
// the block is added to the prune log. We need to keep such a log for
// the case where a reorg happens, and we must "rewind" the state of the
// graph by removing channels that were previously confirmed. In such a
// case we'll remove all entries from the prune log with a block height
// that no longer exists.
pruneLogBucket = []byte("prune-log")
)
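// exampleEdgeKey is an illustrative sketch (not part of the original
// source) of the edgeBucket key scheme described above: the 33-byte
// compressed public key of a node followed by the big-endian 8-byte
// channel ID.
func exampleEdgeKey(nodePub [33]byte, chanID uint64) [41]byte {
	var key [41]byte
	copy(key[:33], nodePub[:])
	byteOrder.PutUint64(key[33:], chanID)
	return key
}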
const (
// MaxAllowedExtraOpaqueBytes is the largest amount of opaque bytes that
// we'll permit to be written to disk. We limit this as otherwise, it
// would be possible for a node to create a ton of updates and slowly
// fill our disk, and also waste bandwidth due to relaying.
MaxAllowedExtraOpaqueBytes = 10000
)
// ChannelGraph is a persistent, on-disk graph representation of the Lightning
// Network. This struct can be used to implement path finding algorithms on top
// of, and also to update a node's view based on information received from the
// p2p network. Internally, the graph is stored using a modified adjacency list
// representation with some added object interaction possible with each
// serialized edge/node. The stored graph is directed, meaning that there are
// two edges stored for each channel: an inbound/outbound edge for each node
// pair.
// Nodes, edges, and edge information can all be added to the graph
// independently. Edge removal results in the deletion of all edge information
// for that edge.
type ChannelGraph struct {
db *DB
}
// newChannelGraph allocates a new ChannelGraph backed by a DB instance. The
// returned instance has its own unique reject cache and channel cache.
func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int) *ChannelGraph {
return &ChannelGraph{
db: db,
}
}
// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
func (c *ChannelGraph) SourceNode() (*LightningNode, error) {
var source *LightningNode
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
node, err := c.sourceNode(nodes)
if err != nil {
return err
}
source = node
return nil
}, func() {
source = nil
})
if err != nil {
return nil, err
}
return source, nil
}
// sourceNode uses an existing database transaction and returns the source node
// of the graph. The source node is treated as the center node within a
// star-graph. This method may be used to kick off a path finding algorithm in
// order to explore the reachability of another node based off the source node.
func (c *ChannelGraph) sourceNode(nodes kvdb.RBucket) (*LightningNode, error) {
selfPub := nodes.Get(sourceKey)
if selfPub == nil {
return nil, ErrSourceNodeNotSet
}
// With the pubKey of the source node retrieved, we're able to
// fetch the full node information.
node, err := fetchLightningNode(nodes, selfPub)
if err != nil {
return nil, err
}
node.db = c.db
return &node, nil
}
// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
func (c *ChannelGraph) SetSourceNode(node *LightningNode) error {
nodePubBytes := node.PubKeyBytes[:]
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return err
}
// Next we create the mapping from source to the targeted
// public key.
if err := nodes.Put(sourceKey, nodePubBytes); err != nil {
return err
}
// Finally, we commit the information of the lightning node
// itself.
return addLightningNode(tx, node)
}, func() {})
}
func addLightningNode(tx kvdb.RwTx, node *LightningNode) error {
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return err
}
aliases, err := nodes.CreateBucketIfNotExists(aliasIndexBucket)
if err != nil {
return err
}
updateIndex, err := nodes.CreateBucketIfNotExists(
nodeUpdateIndexBucket,
)
if err != nil {
return err
}
return putLightningNode(nodes, aliases, updateIndex, node)
}
// updateEdgePolicy attempts to update an edge's policy within the relevant
// buckets using an existing database transaction. The returned boolean will be
// true if the updated policy belongs to node1, and false if the policy belonged
// to node2.
func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) {
edges, err := tx.CreateTopLevelBucket(edgeBucket)
if err != nil {
return false, ErrEdgeNotFound
}
edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket)
if edgeIndex == nil {
return false, ErrEdgeNotFound
}
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return false, err
}
// Create the channelID key by converting the channel ID
// integer into a byte slice.
var chanID [8]byte
byteOrder.PutUint64(chanID[:], edge.ChannelID)
// With the channel ID, we then fetch the value storing the two
// nodes which connect this channel edge.
nodeInfo := edgeIndex.Get(chanID[:])
if nodeInfo == nil {
return false, ErrEdgeNotFound
}
// Depending on the flags value passed above, either the first
// or second edge policy is being updated.
var fromNode, toNode []byte
var isUpdate1 bool
if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
fromNode = nodeInfo[:33]
toNode = nodeInfo[33:66]
isUpdate1 = true
} else {
fromNode = nodeInfo[33:66]
toNode = nodeInfo[:33]
isUpdate1 = false
}
// Finally, with the direction of the edge being updated
// identified, we update the on-disk edge representation.
err = putChanEdgePolicy(edges, nodes, edge, fromNode, toNode)
if err != nil {
return false, err
}
return isUpdate1, nil
}
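// exampleEdgeDirection is an illustrative sketch (not part of the original
// source) of the direction handling above: the edge index value begins
// with pubKey1 || pubKey2 (at least 66 bytes), and the ChanUpdateDirection
// bit selects which of the two keys is the advertising (from) node.
func exampleEdgeDirection(nodeInfo []byte,
	flags lnwire.ChanUpdateChanFlags) (from, to []byte) {

	if flags&lnwire.ChanUpdateDirection == 0 {
		return nodeInfo[:33], nodeInfo[33:66]
	}
	return nodeInfo[33:66], nodeInfo[:33]
}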
// LightningNode represents an individual vertex/node within the channel graph.
// A node is connected to other nodes by one or more channel edges emanating
// from it. As the graph is directed, a node will also have an incoming edge
// attached to it for each outgoing edge.
type LightningNode struct {
// PubKeyBytes is the raw bytes of the public key of the target node.
PubKeyBytes [33]byte
pubKey *btcec.PublicKey
// HaveNodeAnnouncement indicates whether we received a node
// announcement for this particular node. If true, the remaining fields
// will be set, if false only the PubKey is known for this node.
HaveNodeAnnouncement bool
// LastUpdate is the last time the vertex information for this node has
// been updated.
LastUpdate time.Time
// Address is the TCP address this node is reachable over.
Addresses []net.Addr
// Color is the selected color for the node.
Color color.RGBA
// Alias is a nick-name for the node. The alias can be used to confirm
// a node's identity or to serve as a short ID for an address book.
Alias string
// AuthSigBytes is the raw signature under the advertised public key
// which serves to authenticate the attributes announced by this node.
AuthSigBytes []byte
// Features is the list of protocol features supported by this node.
Features *lnwire.FeatureVector
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
db *DB
// TODO(roasbeef): discovery will need storage to keep its last IP
// address and re-announce if interface changes?
// TODO(roasbeef): add update method and fetch?
}
// PubKey is the node's long-term identity public key. This key will be used to
// authenticate any advertisements/updates sent by the node.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the pubkey if absolutely necessary.
func (l *LightningNode) PubKey() (*btcec.PublicKey, error) {
if l.pubKey != nil {
return l.pubKey, nil
}
key, err := btcec.ParsePubKey(l.PubKeyBytes[:])
if err != nil {
return nil, err
}
l.pubKey = key
return key, nil
}
// ChannelEdgeInfo represents a fully authenticated channel along with all its
// unique attributes. Once an authenticated channel announcement has been
// processed on the network, then an instance of ChannelEdgeInfo encapsulating
// the channels attributes is stored. The other portions relevant to routing
// policy of a channel are stored within a ChannelEdgePolicy for each direction
// of the channel.
type ChannelEdgeInfo struct {
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel; see the
// packing sketch after this struct.
ChannelID uint64
// ChainHash is the hash that uniquely identifies the chain that this
// channel was opened within.
//
// TODO(roasbeef): need to modify db keying for multi-chain
// * must add chain hash to prefix as well
ChainHash chainhash.Hash
// NodeKey1Bytes is the raw public key of the first node.
NodeKey1Bytes [33]byte
// NodeKey2Bytes is the raw public key of the second node.
NodeKey2Bytes [33]byte
// BitcoinKey1Bytes is the raw public key of the first node.
BitcoinKey1Bytes [33]byte
// BitcoinKey2Bytes is the raw public key of the second node.
BitcoinKey2Bytes [33]byte
// Features is an opaque byte slice that encodes the set of channel
// specific features that this channel edge supports.
Features []byte
// AuthProof is the authentication proof for this channel. This proof
// contains a set of signatures binding four identities, which attests
// to the legitimacy of the advertised channel.
AuthProof *ChannelAuthProof
// ChannelPoint is the funding outpoint of the channel. This can be
// used to uniquely identify the channel within the channel graph.
ChannelPoint wire.OutPoint
// Capacity is the total capacity of the channel, this is determined by
// the value output in the outpoint that created this channel.
Capacity btcutil.Amount
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
}
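// exampleShortChanID is an illustrative sketch (not part of the original
// source) of the ChannelID packing described above: 3 bytes of block
// height, 3 bytes of transaction index, and 2 bytes of output index,
// matching lnwire.ShortChannelID.ToUint64.
func exampleShortChanID(blockHeight, txIndex uint32,
	outputIndex uint16) uint64 {

	return uint64(blockHeight)<<40 | uint64(txIndex)<<16 |
		uint64(outputIndex)
}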
// ChannelAuthProof is the authentication proof (the signature portion) for a
// channel. Using the four signatures contained in the struct, and some
// auxiliary knowledge (the funding script, node identities, and outpoint) nodes
// on the network are able to validate the authenticity and existence of a
// channel. Each of these signatures signs the following digest: chanID ||
// nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len ||
// features.
type ChannelAuthProof struct {
// NodeSig1Bytes are the raw bytes of the first node signature encoded
// in DER format.
NodeSig1Bytes []byte
// NodeSig2Bytes are the raw bytes of the second node signature
// encoded in DER format.
NodeSig2Bytes []byte
// BitcoinSig1Bytes are the raw bytes of the first bitcoin signature
// encoded in DER format.
BitcoinSig1Bytes []byte
// BitcoinSig2Bytes are the raw bytes of the second bitcoin signature
// encoded in DER format.
BitcoinSig2Bytes []byte
}
// IsEmpty checks whether the authentication proof is empty. The proof is
// considered empty if at least one of the signatures is nil.
func (c *ChannelAuthProof) IsEmpty() bool {
return len(c.NodeSig1Bytes) == 0 ||
len(c.NodeSig2Bytes) == 0 ||
len(c.BitcoinSig1Bytes) == 0 ||
len(c.BitcoinSig2Bytes) == 0
}
// ChannelEdgePolicy represents a *directed* edge within the channel graph. For
// each channel in the database, there are two distinct edges: one for each
// possible direction of travel along the channel. The edges themselves hold
// information concerning fees, and minimum time-lock information which is
// utilized during path finding.
type ChannelEdgePolicy struct {
// SigBytes is the raw bytes of the signature of the channel edge
// policy. We'll only parse these if the caller needs to access the
// signature for validation purposes. Do not set SigBytes directly, but
// use SetSigBytes instead to make sure that the cache is invalidated.
SigBytes []byte
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel.
ChannelID uint64
// LastUpdate is the last time an authenticated edge for this channel
// was received.
LastUpdate time.Time
// MessageFlags is a bitfield which indicates the presence of optional
// fields (like max_htlc) in the policy.
MessageFlags lnwire.ChanUpdateMsgFlags
// ChannelFlags is a bitfield which signals the capabilities of the
// channel as well as the directed edge this update applies to.
ChannelFlags lnwire.ChanUpdateChanFlags
// TimeLockDelta is the number of blocks this node will subtract from
// the expiry of an incoming HTLC. This value expresses the time buffer
// the node would like to apply to HTLC exchanges.
TimeLockDelta uint16
// MinHTLC is the smallest value HTLC this node will accept, expressed
// in millisatoshi.
MinHTLC lnwire.MilliSatoshi
// MaxHTLC is the largest value HTLC this node will accept, expressed
// in millisatoshi.
MaxHTLC lnwire.MilliSatoshi
// FeeBaseMSat is the base HTLC fee that will be charged for forwarding
// ANY HTLC, expressed in mSAT's.
FeeBaseMSat lnwire.MilliSatoshi
// FeeProportionalMillionths is the rate that the node will charge for
// HTLCs for each millionth of a satoshi forwarded.
FeeProportionalMillionths lnwire.MilliSatoshi
// Node is the LightningNode that this directed edge leads to. Using
// this pointer the channel graph can further be traversed.
Node *LightningNode
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
}
// IsDisabled determines whether the edge has the disabled bit set.
func (c *ChannelEdgePolicy) IsDisabled() bool {
return c.ChannelFlags&lnwire.ChanUpdateDisabled ==
lnwire.ChanUpdateDisabled
}
func putLightningNode(nodeBucket kvdb.RwBucket, aliasBucket kvdb.RwBucket,
updateIndex kvdb.RwBucket, node *LightningNode) error {
var (
scratch [16]byte
b bytes.Buffer
)
pub, err := node.PubKey()
if err != nil {
return err
}
nodePub := pub.SerializeCompressed()
// If the node has the update time set, write it, else write 0.
updateUnix := uint64(0)
if node.LastUpdate.Unix() > 0 {
updateUnix = uint64(node.LastUpdate.Unix())
}
byteOrder.PutUint64(scratch[:8], updateUnix)
if _, err := b.Write(scratch[:8]); err != nil {
return err
}
if _, err := b.Write(nodePub); err != nil {
return err
}
// If we got a node announcement for this node, we will have the rest
// of the data available. If not we don't have more data to write.
if !node.HaveNodeAnnouncement {
// Write HaveNodeAnnouncement=0.
byteOrder.PutUint16(scratch[:2], 0)
if _, err := b.Write(scratch[:2]); err != nil {
return err
}
return nodeBucket.Put(nodePub, b.Bytes())
}
// Write HaveNodeAnnouncement=1.
byteOrder.PutUint16(scratch[:2], 1)
if _, err := b.Write(scratch[:2]); err != nil {
return err
}
if err := binary.Write(&b, byteOrder, node.Color.R); err != nil {
return err
}
if err := binary.Write(&b, byteOrder, node.Color.G); err != nil {
return err
}
if err := binary.Write(&b, byteOrder, node.Color.B); err != nil {
return err
}
if err := wire.WriteVarString(&b, 0, node.Alias); err != nil {
return err
}
if err := node.Features.Encode(&b); err != nil {
return err
}
numAddresses := uint16(len(node.Addresses))
byteOrder.PutUint16(scratch[:2], numAddresses)
if _, err := b.Write(scratch[:2]); err != nil {
return err
}
for _, address := range node.Addresses {
if err := serializeAddr(&b, address); err != nil {
return err
}
}
sigLen := len(node.AuthSigBytes)
if sigLen > 80 {
return fmt.Errorf("max sig len allowed is 80, had %v",
sigLen)
}
err = wire.WriteVarBytes(&b, 0, node.AuthSigBytes)
if err != nil {
return err
}
if len(node.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes {
return ErrTooManyExtraOpaqueBytes(len(node.ExtraOpaqueData))
}
err = wire.WriteVarBytes(&b, 0, node.ExtraOpaqueData)
if err != nil {
return err
}
if err := aliasBucket.Put(nodePub, []byte(node.Alias)); err != nil {
return err
}
// With the alias bucket updated, we'll now update the index that
// tracks the time series of node updates.
var indexKey [8 + 33]byte
byteOrder.PutUint64(indexKey[:8], updateUnix)
copy(indexKey[8:], nodePub)
// If there was already an old index entry for this node, then we'll
// delete the old one before we write the new entry.
if nodeBytes := nodeBucket.Get(nodePub); nodeBytes != nil {
// Extract out the old update time so we can reconstruct the
// prior index key to delete it from the index.
oldUpdateTime := nodeBytes[:8]
var oldIndexKey [8 + 33]byte
copy(oldIndexKey[:8], oldUpdateTime)
copy(oldIndexKey[8:], nodePub)
if err := updateIndex.Delete(oldIndexKey[:]); err != nil {
return err
}
}
if err := updateIndex.Put(indexKey[:], nil); err != nil {
return err
}
return nodeBucket.Put(nodePub, b.Bytes())
}
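// exampleNodeIndexKey is an illustrative sketch (not part of the original
// source) of the nodeUpdateIndexBucket key built above: the big-endian
// 8-byte last-update timestamp followed by the 33-byte compressed node
// public key, so cursor scans iterate nodes in update-time order.
func exampleNodeIndexKey(updateUnix uint64, nodePub []byte) [41]byte {
	var key [41]byte
	byteOrder.PutUint64(key[:8], updateUnix)
	copy(key[8:], nodePub)
	return key
}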
func fetchLightningNode(nodeBucket kvdb.RBucket,
nodePub []byte) (LightningNode, error) {
nodeBytes := nodeBucket.Get(nodePub)
if nodeBytes == nil {
return LightningNode{}, ErrGraphNodeNotFound
}
nodeReader := bytes.NewReader(nodeBytes)
return deserializeLightningNode(nodeReader)
}
func deserializeLightningNode(r io.Reader) (LightningNode, error) {
var (
node LightningNode
scratch [8]byte
err error
)
if _, err := r.Read(scratch[:]); err != nil {
return LightningNode{}, err
}
unix := int64(byteOrder.Uint64(scratch[:]))
node.LastUpdate = time.Unix(unix, 0)
if _, err := io.ReadFull(r, node.PubKeyBytes[:]); err != nil {
return LightningNode{}, err
}
if _, err := r.Read(scratch[:2]); err != nil {
return LightningNode{}, err
}
hasNodeAnn := byteOrder.Uint16(scratch[:2])
if hasNodeAnn == 1 {
node.HaveNodeAnnouncement = true
} else {
node.HaveNodeAnnouncement = false
}
// The rest of the data is optional, and will only be there if we got a node
// announcement for this node.
if !node.HaveNodeAnnouncement {
return node, nil
}
// We did get a node announcement for this node, so we'll have the rest
// of the data available.
if err := binary.Read(r, byteOrder, &node.Color.R); err != nil {
return LightningNode{}, err
}
if err := binary.Read(r, byteOrder, &node.Color.G); err != nil {
return LightningNode{}, err
}
if err := binary.Read(r, byteOrder, &node.Color.B); err != nil {
return LightningNode{}, err
}
node.Alias, err = wire.ReadVarString(r, 0)
if err != nil {
return LightningNode{}, err
}
fv := lnwire.NewFeatureVector(nil, nil)
err = fv.Decode(r)
if err != nil {
return LightningNode{}, err
}
node.Features = fv
	if _, err := io.ReadFull(r, scratch[:2]); err != nil {
return LightningNode{}, err
}
numAddresses := int(byteOrder.Uint16(scratch[:2]))
var addresses []net.Addr
for i := 0; i < numAddresses; i++ {
address, err := deserializeAddr(r)
if err != nil {
return LightningNode{}, err
}
addresses = append(addresses, address)
}
node.Addresses = addresses
node.AuthSigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig")
if err != nil {
return LightningNode{}, err
}
// We'll try and see if there are any opaque bytes left, if not, then
// we'll ignore the EOF error and return the node as is.
node.ExtraOpaqueData, err = wire.ReadVarBytes(
r, 0, MaxAllowedExtraOpaqueBytes, "blob",
)
switch {
case err == io.ErrUnexpectedEOF:
case err == io.EOF:
case err != nil:
return LightningNode{}, err
}
return node, nil
}
func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) {
var (
err error
edgeInfo ChannelEdgeInfo
)
if _, err := io.ReadFull(r, edgeInfo.NodeKey1Bytes[:]); err != nil {
return ChannelEdgeInfo{}, err
}
if _, err := io.ReadFull(r, edgeInfo.NodeKey2Bytes[:]); err != nil {
return ChannelEdgeInfo{}, err
}
if _, err := io.ReadFull(r, edgeInfo.BitcoinKey1Bytes[:]); err != nil {
return ChannelEdgeInfo{}, err
}
if _, err := io.ReadFull(r, edgeInfo.BitcoinKey2Bytes[:]); err != nil {
return ChannelEdgeInfo{}, err
}
edgeInfo.Features, err = wire.ReadVarBytes(r, 0, 900, "features")
if err != nil {
return ChannelEdgeInfo{}, err
}
proof := &ChannelAuthProof{}
proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs")
if err != nil {
return ChannelEdgeInfo{}, err
}
proof.NodeSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs")
if err != nil {
return ChannelEdgeInfo{}, err
}
proof.BitcoinSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs")
if err != nil {
return ChannelEdgeInfo{}, err
}
proof.BitcoinSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs")
if err != nil {
return ChannelEdgeInfo{}, err
}
if !proof.IsEmpty() {
edgeInfo.AuthProof = proof
}
edgeInfo.ChannelPoint = wire.OutPoint{}
if err := ReadOutpoint(r, &edgeInfo.ChannelPoint); err != nil {
return ChannelEdgeInfo{}, err
}
if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil {
return ChannelEdgeInfo{}, err
}
if err := binary.Read(r, byteOrder, &edgeInfo.ChannelID); err != nil {
return ChannelEdgeInfo{}, err
}
if _, err := io.ReadFull(r, edgeInfo.ChainHash[:]); err != nil {
return ChannelEdgeInfo{}, err
}
// We'll try and see if there are any opaque bytes left, if not, then
// we'll ignore the EOF error and return the edge as is.
edgeInfo.ExtraOpaqueData, err = wire.ReadVarBytes(
r, 0, MaxAllowedExtraOpaqueBytes, "blob",
)
switch {
case err == io.ErrUnexpectedEOF:
case err == io.EOF:
case err != nil:
return ChannelEdgeInfo{}, err
}
return edgeInfo, nil
}
func putChanEdgePolicy(edges, nodes kvdb.RwBucket, edge *ChannelEdgePolicy,
from, to []byte) error {
var edgeKey [33 + 8]byte
copy(edgeKey[:], from)
byteOrder.PutUint64(edgeKey[33:], edge.ChannelID)
var b bytes.Buffer
if err := serializeChanEdgePolicy(&b, edge, to); err != nil {
return err
}
// Before we write out the new edge, we'll create a new entry in the
// update index in order to keep it fresh.
updateUnix := uint64(edge.LastUpdate.Unix())
var indexKey [8 + 8]byte
byteOrder.PutUint64(indexKey[:8], updateUnix)
byteOrder.PutUint64(indexKey[8:], edge.ChannelID)
updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket)
if err != nil {
return err
}
// If there was already an entry for this edge, then we'll need to
// delete the old one to ensure we don't leave around any after-images.
	// An unknown policy value does not have an update time recorded, so
// it also does not need to be removed.
if edgeBytes := edges.Get(edgeKey[:]); edgeBytes != nil &&
!bytes.Equal(edgeBytes[:], unknownPolicy) {
// In order to delete the old entry, we'll need to obtain the
// *prior* update time in order to delete it. To do this, we'll
// need to deserialize the existing policy within the database
// (now outdated by the new one), and delete its corresponding
// entry within the update index. We'll ignore any
// ErrEdgePolicyOptionalFieldNotFound error, as we only need
// the channel ID and update time to delete the entry.
// TODO(halseth): get rid of these invalid policies in a
// migration.
oldEdgePolicy, err := deserializeChanEdgePolicy(
bytes.NewReader(edgeBytes), nodes,
)
if err != nil && err != ErrEdgePolicyOptionalFieldNotFound {
return err
}
oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix())
var oldIndexKey [8 + 8]byte
byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime)
byteOrder.PutUint64(oldIndexKey[8:], edge.ChannelID)
if err := updateIndex.Delete(oldIndexKey[:]); err != nil {
return err
}
}
if err := updateIndex.Put(indexKey[:], nil); err != nil {
return err
}
	err = updateEdgePolicyDisabledIndex(
		edges, edge.ChannelID,
		edge.ChannelFlags&lnwire.ChanUpdateDirection > 0,
		edge.IsDisabled(),
	)
	if err != nil {
		return err
	}
	return edges.Put(edgeKey[:], b.Bytes())
}
// updateEdgePolicyDisabledIndex is used to update the disabledEdgePolicyIndex
// bucket by either adding a new disabled ChannelEdgePolicy or removing an
// existing one.
// The direction represents the direction of the edge and disabled is used to
// decide whether to add or remove an entry from the bucket.
// In general a channel is disabled if two entries, one for each direction,
// exist for the same chanID in this bucket.
// Maintaining the bucket this way allows fast retrieval of disabled channels,
// for example when pruning is needed.
func updateEdgePolicyDisabledIndex(edges kvdb.RwBucket, chanID uint64,
direction bool, disabled bool) error {
var disabledEdgeKey [8 + 1]byte
byteOrder.PutUint64(disabledEdgeKey[0:], chanID)
if direction {
disabledEdgeKey[8] = 1
}
disabledEdgePolicyIndex, err := edges.CreateBucketIfNotExists(
disabledEdgePolicyBucket,
)
if err != nil {
return err
}
if disabled {
return disabledEdgePolicyIndex.Put(disabledEdgeKey[:], []byte{})
}
return disabledEdgePolicyIndex.Delete(disabledEdgeKey[:])
}
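// Since each direction of a channel writes its own entry under
// (chanID || direction), a channel is only fully disabled once both
// entries are present. The helper below is a minimal sketch (not part
// of the original source) of how a caller could check that condition.
func isChanFullyDisabled(edges kvdb.RBucket, chanID uint64) bool {
	index := edges.NestedReadBucket(disabledEdgePolicyBucket)
	if index == nil {
		return false
	}
	var key [8 + 1]byte
	byteOrder.PutUint64(key[:8], chanID)
	// Direction one: key[8] == 0.
	disabledOne := index.Get(key[:]) != nil
	// Direction two: key[8] == 1.
	key[8] = 1
	disabledTwo := index.Get(key[:]) != nil
	return disabledOne && disabledTwo
}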
// putChanEdgePolicyUnknown marks the edge policy as unknown
// in the edges bucket.
func putChanEdgePolicyUnknown(edges kvdb.RwBucket, channelID uint64,
from []byte) error {
var edgeKey [33 + 8]byte
copy(edgeKey[:], from)
byteOrder.PutUint64(edgeKey[33:], channelID)
if edges.Get(edgeKey[:]) != nil {
return fmt.Errorf("Cannot write unknown policy for channel %v "+
" when there is already a policy present", channelID)
}
return edges.Put(edgeKey[:], unknownPolicy)
}
func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte,
nodePub []byte, nodes kvdb.RBucket) (*ChannelEdgePolicy, error) {
var edgeKey [33 + 8]byte
copy(edgeKey[:], nodePub)
copy(edgeKey[33:], chanID[:])
edgeBytes := edges.Get(edgeKey[:])
if edgeBytes == nil {
return nil, ErrEdgeNotFound
}
// No need to deserialize unknown policy.
if bytes.Equal(edgeBytes[:], unknownPolicy) {
return nil, nil
}
edgeReader := bytes.NewReader(edgeBytes)
ep, err := deserializeChanEdgePolicy(edgeReader, nodes)
switch {
// If the db policy was missing an expected optional field, we return
// nil as if the policy was unknown.
case err == ErrEdgePolicyOptionalFieldNotFound:
return nil, nil
case err != nil:
return nil, err
}
return ep, nil
}
func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy,
to []byte) error {
err := wire.WriteVarBytes(w, 0, edge.SigBytes)
if err != nil {
return err
}
if err := binary.Write(w, byteOrder, edge.ChannelID); err != nil {
return err
}
var scratch [8]byte
updateUnix := uint64(edge.LastUpdate.Unix())
byteOrder.PutUint64(scratch[:], updateUnix)
if _, err := w.Write(scratch[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, edge.MessageFlags); err != nil {
return err
}
if err := binary.Write(w, byteOrder, edge.ChannelFlags); err != nil {
return err
}
if err := binary.Write(w, byteOrder, edge.TimeLockDelta); err != nil {
return err
}
if err := binary.Write(w, byteOrder, uint64(edge.MinHTLC)); err != nil {
return err
}
	err = binary.Write(w, byteOrder, uint64(edge.FeeBaseMSat))
	if err != nil {
		return err
	}
	err = binary.Write(
		w, byteOrder, uint64(edge.FeeProportionalMillionths),
	)
	if err != nil {
		return err
	}
if _, err := w.Write(to); err != nil {
return err
}
	// If the max_htlc field is present, we write it. To be compatible with
	// older versions that weren't aware of this field, we write it as part
// of the opaque data.
// TODO(halseth): clean up when moving to TLV.
var opaqueBuf bytes.Buffer
if edge.MessageFlags.HasMaxHtlc() {
err := binary.Write(&opaqueBuf, byteOrder, uint64(edge.MaxHTLC))
if err != nil {
return err
}
}
if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes {
return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData))
}
if _, err := opaqueBuf.Write(edge.ExtraOpaqueData); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, opaqueBuf.Bytes()); err != nil {
return err
}
return nil
}
func deserializeChanEdgePolicy(r io.Reader,
nodes kvdb.RBucket) (*ChannelEdgePolicy, error) {
edge := &ChannelEdgePolicy{}
var err error
edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig")
if err != nil {
return nil, err
}
if err := binary.Read(r, byteOrder, &edge.ChannelID); err != nil {
return nil, err
}
var scratch [8]byte
	if _, err := io.ReadFull(r, scratch[:]); err != nil {
return nil, err
}
unix := int64(byteOrder.Uint64(scratch[:]))
edge.LastUpdate = time.Unix(unix, 0)
if err := binary.Read(r, byteOrder, &edge.MessageFlags); err != nil {
return nil, err
}
if err := binary.Read(r, byteOrder, &edge.ChannelFlags); err != nil {
return nil, err
}
if err := binary.Read(r, byteOrder, &edge.TimeLockDelta); err != nil {
return nil, err
}
var n uint64
if err := binary.Read(r, byteOrder, &n); err != nil {
return nil, err
}
edge.MinHTLC = lnwire.MilliSatoshi(n)
if err := binary.Read(r, byteOrder, &n); err != nil {
return nil, err
}
edge.FeeBaseMSat = lnwire.MilliSatoshi(n)
if err := binary.Read(r, byteOrder, &n); err != nil {
return nil, err
}
edge.FeeProportionalMillionths = lnwire.MilliSatoshi(n)
var pub [33]byte
	if _, err := io.ReadFull(r, pub[:]); err != nil {
return nil, err
}
node, err := fetchLightningNode(nodes, pub[:])
if err != nil {
return nil, fmt.Errorf("unable to fetch node: %x, %w",
pub[:], err)
}
edge.Node = &node
// We'll try and see if there are any opaque bytes left, if not, then
// we'll ignore the EOF error and return the edge as is.
edge.ExtraOpaqueData, err = wire.ReadVarBytes(
r, 0, MaxAllowedExtraOpaqueBytes, "blob",
)
switch {
case err == io.ErrUnexpectedEOF:
case err == io.EOF:
case err != nil:
return nil, err
}
// See if optional fields are present.
if edge.MessageFlags.HasMaxHtlc() {
// The max_htlc field should be at the beginning of the opaque
// bytes.
opq := edge.ExtraOpaqueData
// If the max_htlc field is not present, it might be old data
// stored before this field was validated. We'll return the
// edge along with an error.
if len(opq) < 8 {
return edge, ErrEdgePolicyOptionalFieldNotFound
}
maxHtlc := byteOrder.Uint64(opq[:8])
edge.MaxHTLC = lnwire.MilliSatoshi(maxHtlc)
// Exclude the parsed field from the rest of the opaque data.
edge.ExtraOpaqueData = opq[8:]
}
return edge, nil
}
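// As shown above, when MessageFlags.HasMaxHtlc is set the first 8 bytes
// of the stored opaque data carry max_htlc as a big-endian uint64, and
// only the remainder is genuine extra opaque data. The helper below is
// an illustrative sketch (not part of the original source) of that
// split.
func splitMaxHtlc(opq []byte) (lnwire.MilliSatoshi, []byte, error) {
	if len(opq) < 8 {
		// Older policies may predate validation of this field.
		return 0, nil, ErrEdgePolicyOptionalFieldNotFound
	}
	maxHtlc := lnwire.MilliSatoshi(byteOrder.Uint64(opq[:8]))
	return maxHtlc, opq[8:], nil
}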
package migration_01_to_11
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"time"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// invoiceBucket is the name of the bucket within the database that
// stores all data related to invoices no matter their final state.
// Within the invoice bucket, each invoice is keyed by its invoice ID
// which is a monotonically increasing uint32.
invoiceBucket = []byte("invoices")
// addIndexBucket is an index bucket that we'll use to create a
// monotonically increasing set of add indexes. Each time we add a new
// invoice, this sequence number will be incremented and then populated
// within the new invoice.
//
// In addition to this sequence number, we map:
//
// addIndexNo => invoiceKey
addIndexBucket = []byte("invoice-add-index")
// settleIndexBucket is an index bucket that we'll use to create a
// monotonically increasing integer for tracking a "settle index". Each
// time an invoice is settled, this sequence number will be incremented
	// and populated within the newly settled invoice.
//
// In addition to this sequence number, we map:
//
// settleIndexNo => invoiceKey
settleIndexBucket = []byte("invoice-settle-index")
)
const (
// MaxMemoSize is maximum size of the memo field within invoices stored
// in the database.
MaxMemoSize = 1024
// MaxReceiptSize is the maximum size of the payment receipt stored
	// within the database alongside incoming/outgoing invoices.
MaxReceiptSize = 1024
// MaxPaymentRequestSize is the max size of a payment request for
// this invoice.
// TODO(halseth): determine the max length payment request when field
// lengths are final.
MaxPaymentRequestSize = 4096
// A set of tlv type definitions used to serialize invoice htlcs to the
// database.
chanIDType tlv.Type = 1
htlcIDType tlv.Type = 3
amtType tlv.Type = 5
acceptHeightType tlv.Type = 7
acceptTimeType tlv.Type = 9
resolveTimeType tlv.Type = 11
expiryHeightType tlv.Type = 13
stateType tlv.Type = 15
)
// ContractState describes the state the invoice is in.
type ContractState uint8
const (
// ContractOpen means the invoice has only been created.
ContractOpen ContractState = 0
// ContractSettled means the htlc is settled and the invoice has been
// paid.
ContractSettled ContractState = 1
// ContractCanceled means the invoice has been canceled.
ContractCanceled ContractState = 2
// ContractAccepted means the HTLC has been accepted but not settled
// yet.
ContractAccepted ContractState = 3
)
// String returns a human readable identifier for the ContractState type.
func (c ContractState) String() string {
switch c {
case ContractOpen:
return "Open"
case ContractSettled:
return "Settled"
case ContractCanceled:
return "Canceled"
case ContractAccepted:
return "Accepted"
}
return "Unknown"
}
// ContractTerm is a companion struct to the Invoice struct. This struct houses
// the necessary conditions required before the invoice can be considered fully
// settled by the payee.
type ContractTerm struct {
// PaymentPreimage is the preimage which is to be revealed in the
// occasion that an HTLC paying to the hash of this preimage is
// extended.
PaymentPreimage lntypes.Preimage
// Value is the expected amount of milli-satoshis to be paid to an HTLC
// which can be satisfied by the above preimage.
Value lnwire.MilliSatoshi
// State describes the state the invoice is in.
State ContractState
}
// Invoice is a payment invoice generated by a payee in order to request
// payment for some good or service. The inclusion of invoices within Lightning
// creates a payment work flow for merchants very similar to that of the
// existing financial system within PayPal, etc. Invoices are added to the
// database when a payment is requested, then can be settled manually once the
// payment is received at the upper layer. For record keeping purposes,
// invoices are never deleted from the database, instead a bit is toggled
// denoting the invoice has been fully settled. Within the database, all
// invoices must have a unique payment hash which is generated by taking the
// sha256 of the payment preimage.
type Invoice struct {
	// Memo is an optional memo to be stored alongside an invoice. The
// memo may contain further details pertaining to the invoice itself,
// or any other message which fits within the size constraints.
Memo []byte
// Receipt is an optional field dedicated for storing a
// cryptographically binding receipt of payment.
//
// TODO(roasbeef): document scheme.
Receipt []byte
// PaymentRequest is an optional field where a payment request created
// for this invoice can be stored.
PaymentRequest []byte
// FinalCltvDelta is the minimum required number of blocks before htlc
// expiry when the invoice is accepted.
FinalCltvDelta int32
// Expiry defines how long after creation this invoice should expire.
Expiry time.Duration
// CreationDate is the exact time the invoice was created.
CreationDate time.Time
// SettleDate is the exact time the invoice was settled.
SettleDate time.Time
// Terms are the contractual payment terms of the invoice. Once all the
// terms have been satisfied by the payer, then the invoice can be
// considered fully fulfilled.
//
// TODO(roasbeef): later allow for multiple terms to fulfill the final
// invoice: payment fragmentation, etc.
Terms ContractTerm
// AddIndex is an auto-incrementing integer that acts as a
// monotonically increasing sequence number for all invoices created.
// Clients can then use this field as a "checkpoint" of sorts when
// implementing a streaming RPC to notify consumers of instances where
// an invoice has been added before they re-connected.
//
// NOTE: This index starts at 1.
AddIndex uint64
// SettleIndex is an auto-incrementing integer that acts as a
// monotonically increasing sequence number for all settled invoices.
// Clients can then use this field as a "checkpoint" of sorts when
// implementing a streaming RPC to notify consumers of instances where
// an invoice has been settled before they re-connected.
//
// NOTE: This index starts at 1.
SettleIndex uint64
	// AmtPaid is the final amount that we ultimately accepted as payment
// this invoice. We specify this value independently as it's possible
// that the invoice originally didn't specify an amount, or the sender
// overpaid.
AmtPaid lnwire.MilliSatoshi
// Htlcs records all htlcs that paid to this invoice. Some of these
// htlcs may have been marked as canceled.
Htlcs map[CircuitKey]*InvoiceHTLC
}
// HtlcState defines the states an htlc paying to an invoice can be in.
type HtlcState uint8
// InvoiceHTLC contains details about an htlc paying to this invoice.
type InvoiceHTLC struct {
// Amt is the amount that is carried by this htlc.
Amt lnwire.MilliSatoshi
// AcceptHeight is the block height at which the invoice registry
// decided to accept this htlc as a payment to the invoice. At this
// height, the invoice cltv delay must have been met.
AcceptHeight uint32
// AcceptTime is the wall clock time at which the invoice registry
// decided to accept the htlc.
AcceptTime time.Time
// ResolveTime is the wall clock time at which the invoice registry
// decided to settle the htlc.
ResolveTime time.Time
// Expiry is the expiry height of this htlc.
Expiry uint32
// State indicates the state the invoice htlc is currently in. A
// canceled htlc isn't just removed from the invoice htlcs map, because
// we need AcceptHeight to properly cancel the htlc back.
State HtlcState
}
func validateInvoice(i *Invoice) error {
	if len(i.Memo) > MaxMemoSize {
		return fmt.Errorf("max length of a memo is %v, and a memo "+
			"of length %v was provided", MaxMemoSize, len(i.Memo))
	}
	if len(i.Receipt) > MaxReceiptSize {
		return fmt.Errorf("max length of a receipt is %v, and a "+
			"receipt of length %v was provided", MaxReceiptSize,
			len(i.Receipt))
	}
if len(i.PaymentRequest) > MaxPaymentRequestSize {
return fmt.Errorf("max length of payment request is %v, length "+
"provided was %v", MaxPaymentRequestSize,
len(i.PaymentRequest))
}
return nil
}
// FetchAllInvoices returns all invoices currently stored within the database.
// If the pendingOnly param is true, then only unsettled invoices will be
// returned, skipping all invoices that are fully settled.
func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) {
var invoices []Invoice
err := kvdb.View(d, func(tx kvdb.RTx) error {
invoiceB := tx.ReadBucket(invoiceBucket)
if invoiceB == nil {
return ErrNoInvoicesCreated
}
// Iterate through the entire key space of the top-level
		// invoice bucket. Each key with a non-nil value is an invoice
		// ID that maps to the corresponding invoice.
return invoiceB.ForEach(func(k, v []byte) error {
if v == nil {
return nil
}
invoiceReader := bytes.NewReader(v)
invoice, err := deserializeInvoice(invoiceReader)
if err != nil {
return err
}
if pendingOnly &&
invoice.Terms.State == ContractSettled {
return nil
}
invoices = append(invoices, invoice)
return nil
})
}, func() {
invoices = nil
})
if err != nil {
return nil, err
}
return invoices, nil
}
// serializeInvoice serializes an invoice to a writer.
//
// Note: this function is in use for a migration. Before making changes that
// would modify the on-disk format, make a copy of the original code and store
// it with the migration.
func serializeInvoice(w io.Writer, i *Invoice) error {
if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, i.FinalCltvDelta); err != nil {
return err
}
if err := binary.Write(w, byteOrder, int64(i.Expiry)); err != nil {
return err
}
birthBytes, err := i.CreationDate.MarshalBinary()
if err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil {
return err
}
settleBytes, err := i.SettleDate.MarshalBinary()
if err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil {
return err
}
if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil {
return err
}
var scratch [8]byte
byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, i.Terms.State); err != nil {
return err
}
if err := binary.Write(w, byteOrder, i.AddIndex); err != nil {
return err
}
if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil {
return err
}
if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil {
return err
}
if err := serializeHtlcs(w, i.Htlcs); err != nil {
return err
}
return nil
}
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
// a writer.
func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
for key, htlc := range htlcs {
// Encode the htlc in a tlv stream.
chanID := key.ChanID.ToUint64()
amt := uint64(htlc.Amt)
acceptTime := uint64(htlc.AcceptTime.UnixNano())
resolveTime := uint64(htlc.ResolveTime.UnixNano())
state := uint8(htlc.State)
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(chanIDType, &chanID),
tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
tlv.MakePrimitiveRecord(amtType, &amt),
tlv.MakePrimitiveRecord(
acceptHeightType, &htlc.AcceptHeight,
),
tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(stateType, &state),
)
if err != nil {
return err
}
var b bytes.Buffer
if err := tlvStream.Encode(&b); err != nil {
return err
}
// Write the length of the tlv stream followed by the stream
// bytes.
err = binary.Write(w, byteOrder, uint64(b.Len()))
if err != nil {
return err
}
if _, err := w.Write(b.Bytes()); err != nil {
return err
}
}
return nil
}
func deserializeInvoice(r io.Reader) (Invoice, error) {
var err error
invoice := Invoice{}
// TODO(roasbeef): use read full everywhere
invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "")
if err != nil {
return invoice, err
}
invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "")
if err != nil {
return invoice, err
}
	invoice.PaymentRequest, err = wire.ReadVarBytes(
		r, 0, MaxPaymentRequestSize, "",
	)
if err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.FinalCltvDelta); err != nil {
return invoice, err
}
var expiry int64
if err := binary.Read(r, byteOrder, &expiry); err != nil {
return invoice, err
}
invoice.Expiry = time.Duration(expiry)
birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth")
if err != nil {
return invoice, err
}
if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil {
return invoice, err
}
settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled")
if err != nil {
return invoice, err
}
if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil {
return invoice, err
}
if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil {
return invoice, err
}
var scratch [8]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return invoice, err
}
invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil {
return invoice, err
}
invoice.Htlcs, err = deserializeHtlcs(r)
if err != nil {
return Invoice{}, err
}
return invoice, nil
}
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
// as a map.
func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
htlcs := make(map[CircuitKey]*InvoiceHTLC, 0)
for {
// Read the length of the tlv stream for this htlc.
var streamLen uint64
if err := binary.Read(r, byteOrder, &streamLen); err != nil {
if err == io.EOF {
break
}
return nil, err
}
streamBytes := make([]byte, streamLen)
		if _, err := io.ReadFull(r, streamBytes); err != nil {
return nil, err
}
streamReader := bytes.NewReader(streamBytes)
// Decode the contents into the htlc fields.
var (
htlc InvoiceHTLC
key CircuitKey
chanID uint64
state uint8
acceptTime, resolveTime uint64
amt uint64
)
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(chanIDType, &chanID),
tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
tlv.MakePrimitiveRecord(amtType, &amt),
tlv.MakePrimitiveRecord(
acceptHeightType, &htlc.AcceptHeight,
),
tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(stateType, &state),
)
if err != nil {
return nil, err
}
if err := tlvStream.Decode(streamReader); err != nil {
return nil, err
}
key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
htlc.AcceptTime = time.Unix(0, int64(acceptTime))
htlc.ResolveTime = time.Unix(0, int64(resolveTime))
htlc.State = HtlcState(state)
htlc.Amt = lnwire.MilliSatoshi(amt)
htlcs[key] = &htlc
}
return htlcs, nil
}
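// Because every htlc is framed as a length-prefixed tlv stream, the two
// functions above form an exact round trip. The sketch below (with
// assumed example values, not part of the original source) shows the
// pairing.
func exampleHtlcRoundTrip() (map[CircuitKey]*InvoiceHTLC, error) {
	htlcs := map[CircuitKey]*InvoiceHTLC{
		{
			ChanID: lnwire.NewShortChanIDFromInt(42),
			HtlcID: 7,
		}: {
			Amt:          1000,
			AcceptHeight: 600000,
			AcceptTime:   time.Unix(0, 1000),
			ResolveTime:  time.Unix(0, 2000),
			Expiry:       600144,
			State:        HtlcState(1),
		},
	}
	var b bytes.Buffer
	if err := serializeHtlcs(&b, htlcs); err != nil {
		return nil, err
	}
	// Decoding the buffer yields a map equal to the input.
	return deserializeHtlcs(&b)
}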
package migration_01_to_11
import (
"io"
)
// deserializeCloseChannelSummaryV6 reads the v6 database format for
// ChannelCloseSummary.
//
// NOTE: deprecated, only for migration.
func deserializeCloseChannelSummaryV6(r io.Reader) (*ChannelCloseSummary, error) {
c := &ChannelCloseSummary{}
err := ReadElements(r,
&c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID,
&c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance,
&c.TimeLockedBalance, &c.CloseType, &c.IsPending,
)
if err != nil {
return nil, err
}
// We'll now check to see if the channel close summary was encoded with
// any of the additional optional fields.
err = ReadElements(r, &c.RemoteCurrentRevocation)
switch {
case err == io.EOF:
return c, nil
	// If we got a non-EOF error, then we know there's an actual issue.
// Otherwise, it may have been the case that this summary didn't have
// the set of optional fields.
case err != nil:
return nil, err
}
if err := ReadChanConfig(r, &c.LocalChanConfig); err != nil {
return nil, err
}
// Finally, we'll attempt to read the next unrevoked commitment point
// for the remote party. If we closed the channel before receiving a
// funding locked message, then this can be nil. As a result, we'll use
// the same technique to read the field, only if there's still data
// left in the buffer.
err = ReadElements(r, &c.RemoteNextRevocation)
if err != nil && err != io.EOF {
		// If we got a non-EOF error, then we know there's an actual
// issue. Otherwise, it may have been the case that this
// summary didn't have the set of optional fields.
return nil, err
}
return c, nil
}
package migration_01_to_11
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
package migration_01_to_11
import (
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// metaBucket stores all the meta information concerning the state of
// the database.
metaBucket = []byte("metadata")
// dbVersionKey is a boltdb key and it's used for storing/retrieving
// current database version.
dbVersionKey = []byte("dbp")
)
// Meta structure holds the database meta information.
type Meta struct {
// DbVersionNumber is the current schema version of the database.
DbVersionNumber uint32
}
// putMeta is an internal helper function used in order to allow callers to
// re-use a database transaction. See the publicly exported PutMeta method for
// more information.
func putMeta(meta *Meta, tx kvdb.RwTx) error {
metaBucket, err := tx.CreateTopLevelBucket(metaBucket)
if err != nil {
return err
}
return putDbVersion(metaBucket, meta)
}
func putDbVersion(metaBucket kvdb.RwBucket, meta *Meta) error {
scratch := make([]byte, 4)
byteOrder.PutUint32(scratch, meta.DbVersionNumber)
return metaBucket.Put(dbVersionKey, scratch)
}
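// The read-side counterpart simply interprets the stored 4 bytes as a
// big-endian uint32. The function below is a minimal illustrative
// sketch (not part of the original source) mirroring putDbVersion.
func fetchDbVersion(metaBucket kvdb.RBucket) (uint32, bool) {
	versionBytes := metaBucket.Get(dbVersionKey)
	if len(versionBytes) != 4 {
		return 0, false
	}
	return byteOrder.Uint32(versionBytes), true
}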
package migration_01_to_11
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"sort"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
)
var (
// paymentBucket is the name of the bucket within the database that
// stores all data related to payments.
//
	// Within the payments bucket, each payment is keyed by its payment
	// ID, which is a monotonically increasing uint64. BoltDB's sequence
	// feature is used to generate these monotonically increasing IDs.
//
// NOTE: Deprecated. Kept around for migration purposes.
paymentBucket = []byte("payments")
// paymentStatusBucket is the name of the bucket within the database
// that stores the status of a payment indexed by the payment's
// preimage.
//
// NOTE: Deprecated. Kept around for migration purposes.
paymentStatusBucket = []byte("payment-status")
)
// outgoingPayment represents a successful payment between the daemon and a
// remote node. Details such as the total fee paid, and the time of the payment
// are stored.
//
// NOTE: Deprecated. Kept around for migration purposes.
type outgoingPayment struct {
Invoice
// Fee is the total fee paid for the payment in milli-satoshis.
Fee lnwire.MilliSatoshi
	// TimeLockLength is the total cumulative time-lock in the HTLC extended
// from the second-to-last hop to the destination.
TimeLockLength uint32
// Path encodes the path the payment took through the network. The path
// excludes the outgoing node and consists of the hex-encoded
// compressed public key of each of the nodes involved in the payment.
Path [][33]byte
	// PaymentPreimage is the preimage of a successful payment. This is used
// to calculate the PaymentHash as well as serve as a proof of payment.
PaymentPreimage [32]byte
}
// addPayment saves a successful payment to the database. It is assumed that
// all payments are sent using unique payment hashes.
//
// NOTE: Deprecated. Kept around for migration purposes.
func (db *DB) addPayment(payment *outgoingPayment) error {
	// Validate the fields of the inner invoice within the outgoing
	// payment; these must also adhere to the same constraints as regular
	// invoices.
if err := validateInvoice(&payment.Invoice); err != nil {
return err
}
// We first serialize the payment before starting the database
// transaction so we can avoid creating a DB payment in the case of a
// serialization error.
var b bytes.Buffer
if err := serializeOutgoingPayment(&b, payment); err != nil {
return err
}
paymentBytes := b.Bytes()
return kvdb.Update(db, func(tx kvdb.RwTx) error {
payments, err := tx.CreateTopLevelBucket(paymentBucket)
if err != nil {
return err
}
// Obtain the new unique sequence number for this payment.
paymentID, err := payments.NextSequence()
if err != nil {
return err
}
		// We use big-endian encoding for the keys so that byte-wise
		// ordering matches numeric ordering. This allows bucket scans
		// to return payments in the order in which they were created.
paymentIDBytes := make([]byte, 8)
binary.BigEndian.PutUint64(paymentIDBytes, paymentID)
return payments.Put(paymentIDBytes, paymentBytes)
}, func() {})
}
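// To make the ordering property above concrete: big-endian encoding
// preserves numeric order under byte-wise comparison, which is what
// bucket scans use. A small illustrative check (not part of the
// original source):
func exampleBigEndianKeyOrder() bool {
	a := make([]byte, 8)
	b := make([]byte, 8)
	binary.BigEndian.PutUint64(a, 1)
	binary.BigEndian.PutUint64(b, 256)
	// With little-endian encoding the comparison would invert, since
	// the low byte of 1 (0x01) sorts after the low byte of 256 (0x00).
	return bytes.Compare(a, b) < 0
}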
// fetchAllPayments returns all outgoing payments in DB.
//
// NOTE: Deprecated. Kept around for migration purposes.
func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) {
var payments []*outgoingPayment
err := kvdb.View(db, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(paymentBucket)
if bucket == nil {
return ErrNoPaymentsCreated
}
return bucket.ForEach(func(k, v []byte) error {
// If the value is nil, then we ignore it as it may be
// a sub-bucket.
if v == nil {
return nil
}
r := bytes.NewReader(v)
payment, err := deserializeOutgoingPayment(r)
if err != nil {
return err
}
payments = append(payments, payment)
return nil
})
}, func() {
payments = nil
})
if err != nil {
return nil, err
}
return payments, nil
}
// fetchPaymentStatus returns the payment status for outgoing payment.
// If status of the payment isn't found, it will default to "StatusUnknown".
//
// NOTE: Deprecated. Kept around for migration purposes.
func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) {
var paymentStatus = StatusUnknown
err := kvdb.View(db, func(tx kvdb.RTx) error {
var err error
paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash)
return err
}, func() {
paymentStatus = StatusUnknown
})
if err != nil {
return StatusUnknown, err
}
return paymentStatus, nil
}
// fetchPaymentStatusTx is a helper method that returns the payment status for
// outgoing payment. If status of the payment isn't found, it will default to
// "StatusUnknown". It accepts the boltdb transactions such that this method
// can be composed into other atomic operations.
//
// NOTE: Deprecated. Kept around for migration purposes.
func fetchPaymentStatusTx(tx kvdb.RTx, paymentHash [32]byte) (PaymentStatus, error) {
// The default status for all payments that aren't recorded in database.
var paymentStatus = StatusUnknown
bucket := tx.ReadBucket(paymentStatusBucket)
if bucket == nil {
return paymentStatus, nil
}
paymentStatusBytes := bucket.Get(paymentHash[:])
if paymentStatusBytes == nil {
return paymentStatus, nil
}
paymentStatus.FromBytes(paymentStatusBytes)
return paymentStatus, nil
}
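// On disk, an outgoing payment is laid out as the legacy invoice
// encoding followed by: fee (8 bytes) || num_hops (4 bytes) ||
// hops (33 bytes each) || time_lock_length (4 bytes) ||
// payment_preimage (32 bytes), with all integers in big-endian order.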
func serializeOutgoingPayment(w io.Writer, p *outgoingPayment) error {
var scratch [8]byte
if err := serializeInvoiceLegacy(w, &p.Invoice); err != nil {
return err
}
byteOrder.PutUint64(scratch[:], uint64(p.Fee))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
// First write out the length of the bytes to prefix the value.
pathLen := uint32(len(p.Path))
byteOrder.PutUint32(scratch[:4], pathLen)
if _, err := w.Write(scratch[:4]); err != nil {
return err
}
// Then with the path written, we write out the series of public keys
// involved in the path.
for _, hop := range p.Path {
if _, err := w.Write(hop[:]); err != nil {
return err
}
}
byteOrder.PutUint32(scratch[:4], p.TimeLockLength)
if _, err := w.Write(scratch[:4]); err != nil {
return err
}
if _, err := w.Write(p.PaymentPreimage[:]); err != nil {
return err
}
return nil
}
func deserializeOutgoingPayment(r io.Reader) (*outgoingPayment, error) {
var scratch [8]byte
p := &outgoingPayment{}
inv, err := deserializeInvoiceLegacy(r)
if err != nil {
return nil, err
}
p.Invoice = inv
	if _, err := io.ReadFull(r, scratch[:]); err != nil {
return nil, err
}
p.Fee = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
	if _, err = io.ReadFull(r, scratch[:4]); err != nil {
return nil, err
}
pathLen := byteOrder.Uint32(scratch[:4])
path := make([][33]byte, pathLen)
for i := uint32(0); i < pathLen; i++ {
		if _, err := io.ReadFull(r, path[i][:]); err != nil {
return nil, err
}
}
p.Path = path
	if _, err = io.ReadFull(r, scratch[:4]); err != nil {
return nil, err
}
p.TimeLockLength = byteOrder.Uint32(scratch[:4])
	if _, err := io.ReadFull(r, p.PaymentPreimage[:]); err != nil {
return nil, err
}
return p, nil
}
// serializePaymentAttemptInfoMigration9 is the serializePaymentAttemptInfo
// version as existed when migration #9 was created. We keep this around, along
// with the methods below to ensure that clients that upgrade will use the
// correct version of this method.
func serializePaymentAttemptInfoMigration9(w io.Writer, a *PaymentAttemptInfo) error {
if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil {
return err
}
if err := serializeRouteMigration9(w, a.Route); err != nil {
return err
}
return nil
}
func serializeHopMigration9(w io.Writer, h *Hop) error {
if err := WriteElements(w,
h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock,
h.AmtToForward,
); err != nil {
return err
}
return nil
}
func serializeRouteMigration9(w io.Writer, r Route) error {
if err := WriteElements(w,
r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:],
); err != nil {
return err
}
if err := WriteElements(w, uint32(len(r.Hops))); err != nil {
return err
}
for _, h := range r.Hops {
if err := serializeHopMigration9(w, h); err != nil {
return err
}
}
return nil
}
func deserializePaymentAttemptInfoMigration9(r io.Reader) (*PaymentAttemptInfo, error) {
a := &PaymentAttemptInfo{}
err := ReadElements(r, &a.PaymentID, &a.SessionKey)
if err != nil {
return nil, err
}
a.Route, err = deserializeRouteMigration9(r)
if err != nil {
return nil, err
}
return a, nil
}
func deserializeRouteMigration9(r io.Reader) (Route, error) {
rt := Route{}
if err := ReadElements(r,
&rt.TotalTimeLock, &rt.TotalAmount,
); err != nil {
return rt, err
}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return rt, err
}
copy(rt.SourcePubKey[:], pub)
var numHops uint32
if err := ReadElements(r, &numHops); err != nil {
return rt, err
}
var hops []*Hop
for i := uint32(0); i < numHops; i++ {
hop, err := deserializeHopMigration9(r)
if err != nil {
return rt, err
}
hops = append(hops, hop)
}
rt.Hops = hops
return rt, nil
}
func deserializeHopMigration9(r io.Reader) (*Hop, error) {
h := &Hop{}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return nil, err
}
copy(h.PubKeyBytes[:], pub)
if err := ReadElements(r,
&h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward,
); err != nil {
return nil, err
}
return h, nil
}
// fetchPaymentsMigration9 returns all sent payments found in the DB using the
// payment attempt info format that was present as of migration #9. We need
// this as otherwise, the current FetchPayments version will use the latest
// decoding format. Note that we only need this for the
// TestOutgoingPaymentsMigration migration test case.
func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) {
var payments []*Payment
err := kvdb.View(db, func(tx kvdb.RTx) error {
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
if paymentsBucket == nil {
return nil
}
return paymentsBucket.ForEach(func(k, v []byte) error {
bucket := paymentsBucket.NestedReadBucket(k)
if bucket == nil {
// We only expect sub-buckets to be found in
// this top-level bucket.
return fmt.Errorf("non bucket element in " +
"payments bucket")
}
p, err := fetchPaymentMigration9(bucket)
if err != nil {
return err
}
payments = append(payments, p)
			// For older versions of lnd, duplicate payments to a
			// payment hash were possible. These will be found in a
// sub-bucket indexed by their sequence number if
// available.
dup := bucket.NestedReadBucket(paymentDuplicateBucket)
if dup == nil {
return nil
}
return dup.ForEach(func(k, v []byte) error {
subBucket := dup.NestedReadBucket(k)
if subBucket == nil {
				// We expect one bucket for each duplicate
				// payment.
				return fmt.Errorf("non bucket element " +
					"in duplicate bucket")
}
p, err := fetchPaymentMigration9(subBucket)
if err != nil {
return err
}
payments = append(payments, p)
return nil
})
})
}, func() {
payments = nil
})
if err != nil {
return nil, err
}
// Before returning, sort the payments by their sequence number.
sort.Slice(payments, func(i, j int) bool {
return payments[i].sequenceNum < payments[j].sequenceNum
})
return payments, nil
}
func fetchPaymentMigration9(bucket kvdb.RBucket) (*Payment, error) {
var (
err error
p = &Payment{}
)
seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes == nil {
return nil, fmt.Errorf("sequence number not found")
}
p.sequenceNum = binary.BigEndian.Uint64(seqBytes)
// Get the payment status.
p.Status = fetchPaymentStatus(bucket)
// Get the PaymentCreationInfo.
b := bucket.Get(paymentCreationInfoKey)
if b == nil {
return nil, fmt.Errorf("creation info not found")
}
r := bytes.NewReader(b)
p.Info, err = deserializePaymentCreationInfo(r)
if err != nil {
return nil, err
}
// Get the PaymentAttemptInfo. This can be unset.
b = bucket.Get(paymentAttemptInfoKey)
if b != nil {
r = bytes.NewReader(b)
p.Attempt, err = deserializePaymentAttemptInfoMigration9(r)
if err != nil {
return nil, err
}
}
// Get the payment preimage. This is only found for
// completed payments.
b = bucket.Get(paymentSettleInfoKey)
if b != nil {
var preimg lntypes.Preimage
copy(preimg[:], b[:])
p.PaymentPreimage = &preimg
}
// Get failure reason if available.
b = bucket.Get(paymentFailInfoKey)
if b != nil {
reason := FailureReason(b[0])
p.Failure = &reason
}
return p, nil
}
package migration_01_to_11
import (
"bytes"
"io"
"github.com/lightningnetwork/lnd/kvdb"
)
// MigrateRouteSerialization migrates the way we serialize routes across the
// entire database. At the time of writing of this migration, this includes our
// payment attempts, as well as the payment results in mission control.
func MigrateRouteSerialization(tx kvdb.RwTx) error {
// First, we'll do all the payment attempts.
rootPaymentBucket := tx.ReadWriteBucket(paymentsRootBucket)
if rootPaymentBucket == nil {
return nil
}
// As we can't mutate a bucket while we're iterating over it with
// ForEach, we'll need to collect all the known payment hashes in
// memory first.
var payHashes [][]byte
err := rootPaymentBucket.ForEach(func(k, v []byte) error {
if v != nil {
return nil
}
payHashes = append(payHashes, k)
return nil
})
if err != nil {
return err
}
// Now that we have all the payment hashes, we can carry out the
// migration itself.
for _, payHash := range payHashes {
payHashBucket := rootPaymentBucket.NestedReadWriteBucket(payHash)
// First, we'll migrate the main (non duplicate) payment to
// this hash.
err := migrateAttemptEncoding(tx, payHashBucket)
if err != nil {
return err
}
// Now that we've migrated the main payment, we'll also check
// for any duplicate payments to the same payment hash.
dupBucket := payHashBucket.NestedReadWriteBucket(paymentDuplicateBucket)
// If there's no dup bucket, then we can move on to the next
// payment.
if dupBucket == nil {
continue
}
// Otherwise, we'll now iterate through all the duplicate pay
// hashes and migrate those.
var dupSeqNos [][]byte
err = dupBucket.ForEach(func(k, v []byte) error {
dupSeqNos = append(dupSeqNos, k)
return nil
})
if err != nil {
return err
}
// Now in this second pass, we'll re-serialize their duplicate
// payment attempts under the new encoding.
for _, seqNo := range dupSeqNos {
dupPayHashBucket := dupBucket.NestedReadWriteBucket(seqNo)
err := migrateAttemptEncoding(tx, dupPayHashBucket)
if err != nil {
return err
}
}
}
log.Infof("Migration of route/hop serialization complete!")
log.Infof("Migrating to new mission control store by clearing " +
"existing data")
resultsKey := []byte("missioncontrol-results")
err = tx.DeleteTopLevelBucket(resultsKey)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
log.Infof("Migration to new mission control completed!")
return nil
}
// migrateAttemptEncoding migrates payment attempts using the legacy format to
// the new format.
func migrateAttemptEncoding(tx kvdb.RwTx, payHashBucket kvdb.RwBucket) error {
payAttemptBytes := payHashBucket.Get(paymentAttemptInfoKey)
if payAttemptBytes == nil {
return nil
}
// For our migration, we'll first read out the existing payment attempt
// using the legacy serialization of the attempt.
payAttemptReader := bytes.NewReader(payAttemptBytes)
payAttempt, err := deserializePaymentAttemptInfoLegacy(
payAttemptReader,
)
if err != nil {
return err
}
// Now that we have the old attempts, we'll explicitly mark this as
// needing a legacy payload, since after this migration, the modern
// payload will be the default if signalled.
for _, hop := range payAttempt.Route.Hops {
hop.LegacyPayload = true
}
// Finally, we'll write out the payment attempt using the new encoding.
var b bytes.Buffer
err = serializePaymentAttemptInfo(&b, payAttempt)
if err != nil {
return err
}
return payHashBucket.Put(paymentAttemptInfoKey, b.Bytes())
}
func deserializePaymentAttemptInfoLegacy(r io.Reader) (*PaymentAttemptInfo, error) {
a := &PaymentAttemptInfo{}
err := ReadElements(r, &a.PaymentID, &a.SessionKey)
if err != nil {
return nil, err
}
a.Route, err = deserializeRouteLegacy(r)
if err != nil {
return nil, err
}
return a, nil
}
func serializePaymentAttemptInfoLegacy(w io.Writer, a *PaymentAttemptInfo) error {
if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil {
return err
}
if err := serializeRouteLegacy(w, a.Route); err != nil {
return err
}
return nil
}
func deserializeHopLegacy(r io.Reader) (*Hop, error) {
h := &Hop{}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return nil, err
}
copy(h.PubKeyBytes[:], pub)
if err := ReadElements(r,
&h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward,
); err != nil {
return nil, err
}
return h, nil
}
func serializeHopLegacy(w io.Writer, h *Hop) error {
if err := WriteElements(w,
h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock,
h.AmtToForward,
); err != nil {
return err
}
return nil
}
func deserializeRouteLegacy(r io.Reader) (Route, error) {
rt := Route{}
if err := ReadElements(r,
&rt.TotalTimeLock, &rt.TotalAmount,
); err != nil {
return rt, err
}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return rt, err
}
copy(rt.SourcePubKey[:], pub)
var numHops uint32
if err := ReadElements(r, &numHops); err != nil {
return rt, err
}
var hops []*Hop
for i := uint32(0); i < numHops; i++ {
hop, err := deserializeHopLegacy(r)
if err != nil {
return rt, err
}
hops = append(hops, hop)
}
rt.Hops = hops
return rt, nil
}
func serializeRouteLegacy(w io.Writer, r Route) error {
if err := WriteElements(w,
r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:],
); err != nil {
return err
}
if err := WriteElements(w, uint32(len(r.Hops))); err != nil {
return err
}
for _, h := range r.Hops {
if err := serializeHopLegacy(w, h); err != nil {
return err
}
}
return nil
}
package migration_01_to_11
import (
"bytes"
"encoding/binary"
"fmt"
"io"
bitcoinCfg "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11/zpay32"
"github.com/lightningnetwork/lnd/kvdb"
litecoinCfg "github.com/ltcsuite/ltcd/chaincfg"
)
// MigrateInvoices adds invoice htlcs and a separate cltv delta field to the
// invoices.
func MigrateInvoices(tx kvdb.RwTx) error {
log.Infof("Migrating invoices to new invoice format")
invoiceB := tx.ReadWriteBucket(invoiceBucket)
if invoiceB == nil {
return nil
}
// Iterate through the entire key space of the top-level invoice bucket.
	// Each key with a non-nil value is an invoice ID that maps to the
	// corresponding invoice. Store those keys first, because it isn't
// safe to modify the bucket inside a ForEach loop.
var invoiceKeys [][]byte
err := invoiceB.ForEach(func(k, v []byte) error {
if v == nil {
return nil
}
invoiceKeys = append(invoiceKeys, k)
return nil
})
if err != nil {
return err
}
nets := []*bitcoinCfg.Params{
&bitcoinCfg.MainNetParams, &bitcoinCfg.SimNetParams,
&bitcoinCfg.RegressionNetParams, &bitcoinCfg.TestNet3Params,
}
ltcNets := []*litecoinCfg.Params{
&litecoinCfg.MainNetParams, &litecoinCfg.SimNetParams,
&litecoinCfg.RegressionNetParams, &litecoinCfg.TestNet4Params,
}
for _, net := range ltcNets {
var convertedNet bitcoinCfg.Params
convertedNet.Bech32HRPSegwit = net.Bech32HRPSegwit
nets = append(nets, &convertedNet)
}
// Iterate over all stored keys and migrate the invoices.
for _, k := range invoiceKeys {
v := invoiceB.Get(k)
// Deserialize the invoice with the deserializing function that
// was in use for this version of the database.
invoiceReader := bytes.NewReader(v)
invoice, err := deserializeInvoiceLegacy(invoiceReader)
if err != nil {
return err
}
if invoice.Terms.State == ContractAccepted {
return fmt.Errorf("cannot upgrade with invoice(s) " +
"in accepted state, see release notes")
}
		// Try to decode the payment request for every possible net to
		// avoid passing the active network to channeldb, which would
		// be a layering violation. This migration only runs once and
		// will likely be removed in the future.
var payReq *zpay32.Invoice
for _, net := range nets {
payReq, err = zpay32.Decode(
string(invoice.PaymentRequest), net,
)
if err == nil {
break
}
}
if payReq == nil {
return fmt.Errorf("cannot decode payreq")
}
invoice.FinalCltvDelta = int32(payReq.MinFinalCLTVExpiry())
invoice.Expiry = payReq.Expiry()
// Serialize the invoice in the new format and use it to replace
// the old invoice in the database.
var buf bytes.Buffer
if err := serializeInvoice(&buf, &invoice); err != nil {
return err
}
err = invoiceB.Put(k, buf.Bytes())
if err != nil {
return err
}
}
log.Infof("Migration of invoices completed!")
return nil
}
func deserializeInvoiceLegacy(r io.Reader) (Invoice, error) {
var err error
invoice := Invoice{}
// TODO(roasbeef): use read full everywhere
invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "")
if err != nil {
return invoice, err
}
invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "")
if err != nil {
return invoice, err
}
	invoice.PaymentRequest, err = wire.ReadVarBytes(
		r, 0, MaxPaymentRequestSize, "",
	)
if err != nil {
return invoice, err
}
birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth")
if err != nil {
return invoice, err
}
if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil {
return invoice, err
}
settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled")
if err != nil {
return invoice, err
}
if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil {
return invoice, err
}
if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil {
return invoice, err
}
var scratch [8]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return invoice, err
}
invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil {
return invoice, err
}
if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil {
return invoice, err
}
return invoice, nil
}
// serializeInvoiceLegacy serializes an invoice in the format of the previous db
// version.
func serializeInvoiceLegacy(w io.Writer, i *Invoice) error {
if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil {
return err
}
birthBytes, err := i.CreationDate.MarshalBinary()
if err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil {
return err
}
settleBytes, err := i.SettleDate.MarshalBinary()
if err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil {
return err
}
if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil {
return err
}
var scratch [8]byte
byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, i.Terms.State); err != nil {
return err
}
if err := binary.Write(w, byteOrder, i.AddIndex); err != nil {
return err
}
if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil {
return err
}
if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil {
return err
}
return nil
}
package migration_01_to_11
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/kvdb"
)
// MigrateNodeAndEdgeUpdateIndex is a migration function that will update the
// database from version 0 to version 1. In version 1, we add two new indexes
// (one for nodes and one for edges) to keep track of the last time a node or
// edge was updated on the network. These new indexes allow us to implement the
// new graph sync protocol.
func MigrateNodeAndEdgeUpdateIndex(tx kvdb.RwTx) error {
	// First, we'll populate the node portion of the new index. Before we
// can add new values to the index, we'll first create the new bucket
// where these items will be housed.
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return fmt.Errorf("unable to create node bucket: %w", err)
}
nodeUpdateIndex, err := nodes.CreateBucketIfNotExists(
nodeUpdateIndexBucket,
)
if err != nil {
return fmt.Errorf("unable to create node update index: %w", err)
}
log.Infof("Populating new node update index bucket")
// Now that we know the bucket has been created, we'll iterate over the
// entire node bucket so we can add the (updateTime || nodePub) key
// into the node update index.
err = nodes.ForEach(func(nodePub, nodeInfo []byte) error {
if len(nodePub) != 33 {
return nil
}
log.Tracef("Adding %x to node update index", nodePub)
		// The first 8 bytes of a node's serialized data are the update
// time, so we can extract that without decoding the entire
// structure.
updateTime := nodeInfo[:8]
// Now that we have the update time, we can construct the key
// to insert into the index.
var indexKey [8 + 33]byte
copy(indexKey[:8], updateTime)
copy(indexKey[8:], nodePub)
return nodeUpdateIndex.Put(indexKey[:], nil)
})
if err != nil {
return fmt.Errorf("unable to update node indexes: %w", err)
}
log.Infof("Populating new edge update index bucket")
// With the set of nodes updated, we'll now update all edges to have a
// corresponding entry in the edge update index.
edges, err := tx.CreateTopLevelBucket(edgeBucket)
if err != nil {
return fmt.Errorf("unable to create edge bucket: %w", err)
}
edgeUpdateIndex, err := edges.CreateBucketIfNotExists(
edgeUpdateIndexBucket,
)
if err != nil {
return fmt.Errorf("unable to create edge update index: %w", err)
}
// We'll now run through each edge policy in the database, and update
// the index to ensure each edge has the proper record.
err = edges.ForEach(func(edgeKey, edgePolicyBytes []byte) error {
if len(edgeKey) != 41 {
return nil
}
// Now that we know this is the proper record, we'll grab the
// channel ID (last 8 bytes of the key), and then decode the
// edge policy so we can access the update time.
chanID := edgeKey[33:]
edgePolicyReader := bytes.NewReader(edgePolicyBytes)
edgePolicy, err := deserializeChanEdgePolicy(
edgePolicyReader, nodes,
)
if err != nil {
return err
}
log.Tracef("Adding chan_id=%v to edge update index",
edgePolicy.ChannelID)
// We'll now construct the index key using the channel ID, and
// the last time it was updated: (updateTime || chanID).
var indexKey [8 + 8]byte
byteOrder.PutUint64(
indexKey[:], uint64(edgePolicy.LastUpdate.Unix()),
)
copy(indexKey[8:], chanID)
return edgeUpdateIndex.Put(indexKey[:], nil)
})
if err != nil {
return fmt.Errorf("unable to update edge indexes: %w", err)
}
log.Infof("Migration to node and edge update indexes complete!")
return nil
}
// MigrateInvoiceTimeSeries is a database migration that assigns all existing
// invoices an index in the add and/or the settle index. Additionally, all
// existing invoices will have their bytes padded out in order to encode the
// add+settle index as well as the amount paid.
func MigrateInvoiceTimeSeries(tx kvdb.RwTx) error {
invoices, err := tx.CreateTopLevelBucket(invoiceBucket)
if err != nil {
return err
}
addIndex, err := invoices.CreateBucketIfNotExists(
addIndexBucket,
)
if err != nil {
return err
}
settleIndex, err := invoices.CreateBucketIfNotExists(
settleIndexBucket,
)
if err != nil {
return err
}
log.Infof("Migrating invoice database to new time series format")
// Now that we have all the buckets we need, we'll run through each
// invoice in the database, and update it to reflect the new format
// expected post migration.
// NOTE: we store the converted invoices and put them back into the
// database after the loop, since modifying the bucket within the
// ForEach loop is not safe.
var invoicesKeys [][]byte
var invoicesValues [][]byte
err = invoices.ForEach(func(invoiceNum, invoiceBytes []byte) error {
// If this is a sub bucket, then we'll skip it.
if invoiceBytes == nil {
return nil
}
// First, we'll make a copy of the encoded invoice bytes.
invoiceBytesCopy := make([]byte, len(invoiceBytes))
copy(invoiceBytesCopy, invoiceBytes)
// With the bytes copied over, we'll append 24 additional
// bytes. We do this so we can decode the invoice under the new
// serialization format.
padding := bytes.Repeat([]byte{0}, 24)
invoiceBytesCopy = append(invoiceBytesCopy, padding...)
invoiceReader := bytes.NewReader(invoiceBytesCopy)
invoice, err := deserializeInvoiceLegacy(invoiceReader)
if err != nil {
return fmt.Errorf("unable to decode invoice: %w", err)
}
// Now that we have the fully decoded invoice, we can update
// the various indexes that we've added, and finally the
// invoice itself before re-inserting it.
// First, we'll get the new sequence in the addIndex in order
// to create the proper mapping.
nextAddSeqNo, err := addIndex.NextSequence()
if err != nil {
return err
}
var seqNoBytes [8]byte
byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
err = addIndex.Put(seqNoBytes[:], invoiceNum[:])
if err != nil {
return err
}
log.Tracef("Adding invoice (preimage=%x, add_index=%v) to add "+
"time series", invoice.Terms.PaymentPreimage[:],
nextAddSeqNo)
// Next, we'll check if the invoice has been settled or not. If
// so, then we'll also add it to the settle index.
var nextSettleSeqNo uint64
if invoice.Terms.State == ContractSettled {
nextSettleSeqNo, err = settleIndex.NextSequence()
if err != nil {
return err
}
var seqNoBytes [8]byte
byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
err := settleIndex.Put(seqNoBytes[:], invoiceNum)
if err != nil {
return err
}
invoice.AmtPaid = invoice.Terms.Value
log.Tracef("Adding invoice (preimage=%x, "+
"settle_index=%v) to add time series",
invoice.Terms.PaymentPreimage[:],
nextSettleSeqNo)
}
// Finally, we'll update the invoice itself with the new
// indexing information as well as the amount paid if it has
// been settled or not.
invoice.AddIndex = nextAddSeqNo
invoice.SettleIndex = nextSettleSeqNo
// We've fully migrated an invoice, so we'll now update the
// invoice in-place.
var b bytes.Buffer
if err := serializeInvoiceLegacy(&b, &invoice); err != nil {
return err
}
// Save the key and value pending update for after the ForEach
// is done.
invoicesKeys = append(invoicesKeys, invoiceNum)
invoicesValues = append(invoicesValues, b.Bytes())
return nil
})
if err != nil {
return err
}
// Now put the converted invoices into the DB.
for i := range invoicesKeys {
key := invoicesKeys[i]
value := invoicesValues[i]
if err := invoices.Put(key, value); err != nil {
return err
}
}
log.Infof("Migration to invoice time series index complete!")
return nil
}
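// The helper below is illustrative only and not part of the migration: it
// shows the intended read path of the new time series format, replaying
// invoice keys in insertion order by walking the (sequence number ->
// invoice key) mapping populated above. The function name is an assumption.
func invoiceKeysInAddOrder(tx kvdb.RTx) ([][]byte, error) {
invoices := tx.ReadBucket(invoiceBucket)
if invoices == nil {
return nil, nil
}
addIndex := invoices.NestedReadBucket(addIndexBucket)
if addIndex == nil {
return nil, nil
}
var keys [][]byte
// Sequence numbers are stored as big-endian uint64 keys, so a plain
// ForEach already visits them in ascending add order.
err := addIndex.ForEach(func(_, invoiceKey []byte) error {
keys = append(keys, invoiceKey)
return nil
})
return keys, err
}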
// MigrateInvoiceTimeSeriesOutgoingPayments is a follow up to the
// MigrateInvoiceTimeSeries migration. At the time of writing, the
// OutgoingPayment struct embeds an instance of the Invoice struct. As a
// result, we also need to migrate the internal invoice to the new format.
func MigrateInvoiceTimeSeriesOutgoingPayments(tx kvdb.RwTx) error {
payBucket := tx.ReadWriteBucket(paymentBucket)
if payBucket == nil {
return nil
}
log.Infof("Migrating invoice database to new outgoing payment format")
// We store the keys and values we want to modify since it is not safe
// to modify them directly within the ForEach loop.
var paymentKeys [][]byte
var paymentValues [][]byte
err := payBucket.ForEach(func(payID, paymentBytes []byte) error {
log.Tracef("Migrating payment %x", payID[:])
// The internal invoices for each payment only contain a
// populated contract term and creation date. As a result,
// most of the bytes will be "empty".
// We'll calculate the end of the invoice index assuming a
// "minimal" index that's embedded within the greater
// OutgoingPayment. The breakdown is:
// 3 bytes empty var bytes, 16 bytes creation date, 16 bytes
// settled date, 32 bytes payment pre-image, 8 bytes value, 1
// byte settled.
endOfInvoiceIndex := 1 + 1 + 1 + 16 + 16 + 32 + 8 + 1
// We'll now extract the prefix of the pure invoice embedded
// within.
invoiceBytes := paymentBytes[:endOfInvoiceIndex]
// With the prefix extracted, we'll copy over the invoice, and
// also add padding for the new 24 bytes of fields, and finally
// append the remainder of the outgoing payment.
paymentCopy := make([]byte, len(invoiceBytes))
copy(paymentCopy[:], invoiceBytes)
padding := bytes.Repeat([]byte{0}, 24)
paymentCopy = append(paymentCopy, padding...)
paymentCopy = append(
paymentCopy, paymentBytes[endOfInvoiceIndex:]...,
)
// At this point, we now have the new format of the outgoing
// payments, we'll attempt to deserialize it to ensure the
// bytes are properly formatted.
paymentReader := bytes.NewReader(paymentCopy)
_, err := deserializeOutgoingPayment(paymentReader)
if err != nil {
return fmt.Errorf("unable to deserialize payment: %w",
err)
}
// Now that we know the modification was successful, we'll
// store it to our slice of keys and values, and write it back
// to disk in the new format after the ForEach loop is over.
paymentKeys = append(paymentKeys, payID)
paymentValues = append(paymentValues, paymentCopy)
return nil
})
if err != nil {
return err
}
// Finally store the updated payments to the bucket.
for i := range paymentKeys {
key := paymentKeys[i]
value := paymentValues[i]
if err := payBucket.Put(key, value); err != nil {
return err
}
}
log.Infof("Migration to outgoing payment invoices complete!")
return nil
}
// MigrateEdgePolicies is a migration function that will update the edges
// bucket. It ensures that edges with unknown policies will also have an entry
// in the bucket. After the migration, there will be two edge entries for
// every channel, regardless of whether the policies are known.
func MigrateEdgePolicies(tx kvdb.RwTx) error {
nodes := tx.ReadWriteBucket(nodeBucket)
if nodes == nil {
return nil
}
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return nil
}
edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket)
if edgeIndex == nil {
return nil
}
// checkKey gets the policy from the database with a low-level call
// so that it is still possible to distinguish between unknown and
// not present.
checkKey := func(channelId uint64, keyBytes []byte) error {
var channelID [8]byte
byteOrder.PutUint64(channelID[:], channelId)
_, err := fetchChanEdgePolicy(edges,
channelID[:], keyBytes, nodes)
if err == ErrEdgeNotFound {
log.Tracef("Adding unknown edge policy present for node %x, channel %v",
keyBytes, channelId)
err := putChanEdgePolicyUnknown(edges, channelId, keyBytes)
if err != nil {
return err
}
return nil
}
return err
}
// Iterate over all channels and check both edge policies.
err := edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error {
infoReader := bytes.NewReader(edgeInfoBytes)
edgeInfo, err := deserializeChanEdgeInfo(infoReader)
if err != nil {
return err
}
for _, key := range [][]byte{edgeInfo.NodeKey1Bytes[:],
edgeInfo.NodeKey2Bytes[:]} {
if err := checkKey(edgeInfo.ChannelID, key); err != nil {
return err
}
}
return nil
})
if err != nil {
return fmt.Errorf("unable to update edge policies: %w", err)
}
log.Infof("Migration of edge policies complete!")
return nil
}
// PaymentStatusesMigration is a database migration intended to add a payment
// status for each existing payment entity in the bucket, so that status
// transitions can be controlled and cases such as double payment prevented.
func PaymentStatusesMigration(tx kvdb.RwTx) error {
// Get the bucket dedicated to storing statuses of payments,
// where a key is payment hash, value is payment status.
paymentStatuses, err := tx.CreateTopLevelBucket(paymentStatusBucket)
if err != nil {
return err
}
log.Infof("Migrating database to support payment statuses")
circuitAddKey := []byte("circuit-adds")
circuits := tx.ReadWriteBucket(circuitAddKey)
if circuits != nil {
log.Infof("Marking all known circuits with status InFlight")
err = circuits.ForEach(func(k, v []byte) error {
// Parse the first 8 bytes as the short chan ID for the
// circuit. We'll skip all short chan IDs that are not
// locally initiated, which includes all non-zero short
// chan ids.
chanID := binary.BigEndian.Uint64(k[:8])
if chanID != 0 {
return nil
}
// The payment hash is the third item in the serialized
// payment circuit. The first two items are an AddRef
// (10 bytes) and the incoming circuit key (16 bytes).
const payHashOffset = 10 + 16
paymentHash := v[payHashOffset : payHashOffset+32]
return paymentStatuses.Put(
paymentHash[:], StatusInFlight.Bytes(),
)
})
if err != nil {
return err
}
}
log.Infof("Marking all existing payments with status Completed")
// Get the bucket dedicated to storing payments
bucket := tx.ReadWriteBucket(paymentBucket)
if bucket == nil {
return nil
}
// For each payment in the bucket, deserialize the payment and mark it
// as completed.
err = bucket.ForEach(func(k, v []byte) error {
// Skip if it is a sub-bucket.
if v == nil {
return nil
}
r := bytes.NewReader(v)
payment, err := deserializeOutgoingPayment(r)
if err != nil {
return err
}
// Calculate payment hash for current payment.
paymentHash := sha256.Sum256(payment.PaymentPreimage[:])
// Update status for current payment to completed. If it fails,
// the migration is aborted and the payment bucket is returned
// to its previous state.
return paymentStatuses.Put(paymentHash[:], StatusSucceeded.Bytes())
})
if err != nil {
return err
}
log.Infof("Migration of payment statuses complete!")
return nil
}
// MigratePruneEdgeUpdateIndex is a database migration that attempts to resolve
// some lingering bugs with regards to edge policies and their update index.
// Stale entries within the edge update index were not being properly pruned due
// to a miscalculation on the offset of an edge's policy last update. This
// migration also fixes the case where the public keys within edge policies were
// being serialized with an extra byte, causing an even greater error when
// attempting to perform the offset calculation described earlier.
func MigratePruneEdgeUpdateIndex(tx kvdb.RwTx) error {
// To begin the migration, we'll retrieve the update index bucket. If it
// does not exist, we have nothing left to do so we can simply exit.
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return nil
}
edgeUpdateIndex := edges.NestedReadWriteBucket(edgeUpdateIndexBucket)
if edgeUpdateIndex == nil {
return nil
}
// Retrieve some buckets that will be needed later on. These should
// already exist given the assumption that the buckets above do as
// well.
edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket)
if err != nil {
return fmt.Errorf("error creating edge index bucket: %w", err)
}
if edgeIndex == nil {
return fmt.Errorf("unable to create/fetch edge index " +
"bucket")
}
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return fmt.Errorf("unable to make node bucket")
}
log.Info("Migrating database to properly prune edge update index")
// We'll need to properly prune all the outdated entries within the edge
// update index. To do so, we'll gather all of the existing policies
// within the graph to re-populate them later on.
var edgeKeys [][]byte
err = edges.ForEach(func(edgeKey, edgePolicyBytes []byte) error {
// All valid entries are indexed by a public key (33 bytes)
// followed by a channel ID (8 bytes), so we'll skip any entries
// with keys that do not match this.
if len(edgeKey) != 33+8 {
return nil
}
edgeKeys = append(edgeKeys, edgeKey)
return nil
})
if err != nil {
return fmt.Errorf("unable to gather existing edge policies: %w",
err)
}
log.Info("Constructing set of edge update entries to purge.")
// Build the set of keys that we will remove from the edge update index.
// This will include all keys contained within the bucket.
var updateKeysToRemove [][]byte
err = edgeUpdateIndex.ForEach(func(updKey, _ []byte) error {
updateKeysToRemove = append(updateKeysToRemove, updKey)
return nil
})
if err != nil {
return fmt.Errorf("unable to gather existing edge updates: %w",
err)
}
log.Infof("Removing %d entries from edge update index.",
len(updateKeysToRemove))
// With the set of keys contained in the edge update index constructed,
// we'll proceed in purging all of them from the index.
for _, updKey := range updateKeysToRemove {
if err := edgeUpdateIndex.Delete(updKey); err != nil {
return err
}
}
log.Infof("Repopulating edge update index with %d valid entries.",
len(edgeKeys))
// For each edge key, we'll retrieve the policy, deserialize it, and
// re-add it to the different buckets. By doing so, we'll ensure that
// all existing edge policies are serialized correctly within their
// respective buckets and that the correct entries are populated within
// the edge update index.
for _, edgeKey := range edgeKeys {
edgePolicyBytes := edges.Get(edgeKey)
// Skip any entries with unknown policies as there will not be
// any entries for them in the edge update index.
if bytes.Equal(edgePolicyBytes[:], unknownPolicy) {
continue
}
edgePolicy, err := deserializeChanEdgePolicy(
bytes.NewReader(edgePolicyBytes), nodes,
)
if err != nil {
return err
}
_, err = updateEdgePolicy(tx, edgePolicy)
if err != nil {
return err
}
}
log.Info("Migration to properly prune edge update index complete!")
return nil
}
// MigrateOptionalChannelCloseSummaryFields migrates the serialized format of
// ChannelCloseSummary to a format where optional fields' presence is indicated
// with boolean markers.
func MigrateOptionalChannelCloseSummaryFields(tx kvdb.RwTx) error {
closedChanBucket := tx.ReadWriteBucket(closedChannelBucket)
if closedChanBucket == nil {
return nil
}
log.Info("Migrating to new closed channel format...")
// We store the converted keys and values and put them back into the
// database after the loop, since modifying the bucket within the
// ForEach loop is not safe.
var closedChansKeys [][]byte
var closedChansValues [][]byte
err := closedChanBucket.ForEach(func(chanID, summary []byte) error {
r := bytes.NewReader(summary)
// Read the old (v6) format from the database.
c, err := deserializeCloseChannelSummaryV6(r)
if err != nil {
return err
}
// Serialize using the new format, and put back into the
// bucket.
var b bytes.Buffer
if err := serializeChannelCloseSummary(&b, c); err != nil {
return err
}
// Now that we know the modifications was successful, we'll
// Store the key and value to our slices, and write it back to
// disk in the new format after the ForEach loop is over.
closedChansKeys = append(closedChansKeys, chanID)
closedChansValues = append(closedChansValues, b.Bytes())
return nil
})
if err != nil {
return fmt.Errorf("unable to update closed channels: %w", err)
}
// Now put the new format back into the DB.
for i := range closedChansKeys {
key := closedChansKeys[i]
value := closedChansValues[i]
if err := closedChanBucket.Put(key, value); err != nil {
return err
}
}
log.Info("Migration to new closed channel format complete!")
return nil
}
var messageStoreBucket = []byte("message-store")
// MigrateGossipMessageStoreKeys migrates the key format for gossip messages
// found in the message store to a new one that takes into consideration the
// type of the message being stored.
func MigrateGossipMessageStoreKeys(tx kvdb.RwTx) error {
// We'll start by retrieving the bucket in which these messages are
// stored within. If there isn't one, there's nothing left for us to do
// so we can avoid the migration.
messageStore := tx.ReadWriteBucket(messageStoreBucket)
if messageStore == nil {
return nil
}
log.Info("Migrating to the gossip message store new key format")
// Otherwise we'll proceed with the migration. We'll start by coalescing
// all the current messages within the store, which are indexed by the
// public key of the peer which they should be sent to, followed by the
// short channel ID of the channel to which the message belongs. We
// should only expect to find channel announcement signatures, as that
// was the only supported message type previously.
msgs := make(map[[33 + 8]byte]*lnwire.AnnounceSignatures)
err := messageStore.ForEach(func(k, v []byte) error {
var msgKey [33 + 8]byte
copy(msgKey[:], k)
msg := &lnwire.AnnounceSignatures{}
if err := msg.Decode(bytes.NewReader(v), 0); err != nil {
return err
}
msgs[msgKey] = msg
return nil
})
if err != nil {
return err
}
// Then, we'll go over all of our messages, remove their previous entry,
// and add another with the new key format. Once we've done this for
// every message, we can consider the migration complete.
for oldMsgKey, msg := range msgs {
if err := messageStore.Delete(oldMsgKey[:]); err != nil {
return err
}
// Construct the new key for which we'll find this message with
// in the store. It'll be the same as the old, but we'll also
// include the message type.
var msgType [2]byte
binary.BigEndian.PutUint16(msgType[:], uint16(msg.MsgType()))
newMsgKey := append(oldMsgKey[:], msgType[:]...)
// Serialize the message with its wire encoding.
var b bytes.Buffer
if _, err := lnwire.WriteMessage(&b, msg, 0); err != nil {
return err
}
if err := messageStore.Put(newMsgKey, b.Bytes()); err != nil {
return err
}
}
log.Info("Migration to the gossip message store new key format complete!")
return nil
}
// MigrateOutgoingPayments moves the OutgoingPayments into a new bucket format
// where they all reside in a top-level bucket indexed by the payment hash. In
// this sub-bucket we store information relevant to this payment, such as the
// payment status.
//
// Since the router cannot handle resumed payments that have the status
// InFlight (we have no PaymentAttemptInfo available for pre-migration
// payments) we delete those statuses, so only Completed payments remain in the
// new bucket structure.
func MigrateOutgoingPayments(tx kvdb.RwTx) error {
log.Infof("Migrating outgoing payments to new bucket structure")
oldPayments := tx.ReadWriteBucket(paymentBucket)
// Return early if there are no payments to migrate.
if oldPayments == nil {
log.Infof("No outgoing payments found, nothing to migrate.")
return nil
}
newPayments, err := tx.CreateTopLevelBucket(paymentsRootBucket)
if err != nil {
return err
}
// Helper method to get the source pubkey. We define it such that we
// only attempt to fetch it if needed.
sourcePub := func() ([33]byte, error) {
var pub [33]byte
nodes := tx.ReadWriteBucket(nodeBucket)
if nodes == nil {
return pub, ErrGraphNotFound
}
selfPub := nodes.Get(sourceKey)
if selfPub == nil {
return pub, ErrSourceNodeNotSet
}
copy(pub[:], selfPub[:])
return pub, nil
}
err = oldPayments.ForEach(func(k, v []byte) error {
// Skip if it is a sub-bucket.
if v == nil {
return nil
}
// Read the old payment format.
r := bytes.NewReader(v)
payment, err := deserializeOutgoingPayment(r)
if err != nil {
return err
}
// Calculate payment hash from the payment preimage.
paymentHash := sha256.Sum256(payment.PaymentPreimage[:])
// Now create and add a PaymentCreationInfo to the bucket.
c := &PaymentCreationInfo{
PaymentHash: paymentHash,
Value: payment.Terms.Value,
CreationDate: payment.CreationDate,
PaymentRequest: payment.PaymentRequest,
}
var infoBuf bytes.Buffer
if err := serializePaymentCreationInfo(&infoBuf, c); err != nil {
return err
}
sourcePubKey, err := sourcePub()
if err != nil {
return err
}
// Do the same for the PaymentAttemptInfo.
totalAmt := payment.Terms.Value + payment.Fee
rt := Route{
TotalTimeLock: payment.TimeLockLength,
TotalAmount: totalAmt,
SourcePubKey: sourcePubKey,
Hops: []*Hop{},
}
for _, hop := range payment.Path {
rt.Hops = append(rt.Hops, &Hop{
PubKeyBytes: hop,
AmtToForward: totalAmt,
})
}
// Since the old format didn't store the fee for individual
// hops, we let the last hop absorb the whole fee so that the
// total adds up.
if len(rt.Hops) > 0 {
rt.Hops[len(rt.Hops)-1].AmtToForward = payment.Terms.Value
}
// Since we don't have the session key for old payments, we
// create a random one to be able to serialize the attempt
// info.
priv, err := btcec.NewPrivateKey()
if err != nil {
return err
}
s := &PaymentAttemptInfo{
PaymentID: 0, // unknown.
SessionKey: priv, // unknown.
Route: rt,
}
var attemptBuf bytes.Buffer
if err := serializePaymentAttemptInfoMigration9(&attemptBuf, s); err != nil {
return err
}
// Reuse the existing payment sequence number.
var seqNum [8]byte
copy(seqNum[:], k)
// Create a bucket indexed by the payment hash.
bucket, err := newPayments.CreateBucket(paymentHash[:])
// If the bucket already exists, it means that we are migrating
// from a database containing duplicate payments to a payment
// hash. To keep this information, we store such duplicate
// payments in a sub-bucket.
if err == kvdb.ErrBucketExists {
pHashBucket := newPayments.NestedReadWriteBucket(paymentHash[:])
// Create a bucket for duplicate payments within this
// payment hash's bucket.
dup, err := pHashBucket.CreateBucketIfNotExists(
paymentDuplicateBucket,
)
if err != nil {
return err
}
// Each duplicate will get its own sub-bucket within
// this bucket, so use their sequence number to index
// them by.
bucket, err = dup.CreateBucket(seqNum[:])
if err != nil {
return err
}
} else if err != nil {
return err
}
// Store the payment's information to the bucket.
err = bucket.Put(paymentSequenceKey, seqNum[:])
if err != nil {
return err
}
err = bucket.Put(paymentCreationInfoKey, infoBuf.Bytes())
if err != nil {
return err
}
err = bucket.Put(paymentAttemptInfoKey, attemptBuf.Bytes())
if err != nil {
return err
}
err = bucket.Put(paymentSettleInfoKey, payment.PaymentPreimage[:])
if err != nil {
return err
}
return nil
})
if err != nil {
return err
}
// To continue producing unique sequence numbers, we set the sequence
// of the new bucket to that of the old one.
seq := oldPayments.Sequence()
if err := newPayments.SetSequence(seq); err != nil {
return err
}
// Now we delete the old buckets. Deleting the payment status buckets
// deletes all payment statuses other than Complete.
err = tx.DeleteTopLevelBucket(paymentStatusBucket)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
// Finally delete the old payment bucket.
err = tx.DeleteTopLevelBucket(paymentBucket)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
log.Infof("Migration of outgoing payment bucket structure completed!")
return nil
}
package migration_01_to_11
import "time"
const (
// DefaultRejectCacheSize is the default number of rejectCacheEntries to
// cache for use in the rejection cache of incoming gossip traffic. This
// produces a cache size of around 1MB.
DefaultRejectCacheSize = 50000
// DefaultChannelCacheSize is the default number of ChannelEdges cached
// in order to reply to gossip queries. This produces a cache size of
// around 40MB.
DefaultChannelCacheSize = 20000
// DefaultDBTimeout specifies the default timeout value when opening
// the bbolt database.
DefaultDBTimeout = time.Second * 60
)
// Options holds parameters for tuning and customizing a channeldb.DB.
type Options struct {
// RejectCacheSize is the maximum number of rejectCacheEntries to hold
// in the rejection cache.
RejectCacheSize int
// ChannelCacheSize is the maximum number of ChannelEdges to hold in the
// channel cache.
ChannelCacheSize int
// NoFreelistSync, if true, prevents the database from syncing its
// freelist to disk, resulting in improved performance at the expense of
// increased startup time.
NoFreelistSync bool
// DBTimeout specifies the timeout value to use when opening the wallet
// database.
DBTimeout time.Duration
}
// DefaultOptions returns an Options populated with default values.
func DefaultOptions() Options {
return Options{
RejectCacheSize: DefaultRejectCacheSize,
ChannelCacheSize: DefaultChannelCacheSize,
NoFreelistSync: true,
DBTimeout: DefaultDBTimeout,
}
}
// OptionModifier is a function signature for modifying the default Options.
type OptionModifier func(*Options)
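// A minimal sketch of how OptionModifier is meant to be composed with
// DefaultOptions; the modifier below is named here purely for illustration
// (the production channeldb package defines similar helpers).
func OptionSetRejectCacheSize(n int) OptionModifier {
return func(o *Options) {
o.RejectCacheSize = n
}
}
// Applying a modifier on top of the defaults then looks like:
//
//	opts := DefaultOptions()
//	OptionSetRejectCacheSize(100)(&opts)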
package migration_01_to_11
import "github.com/lightningnetwork/lnd/kvdb"
// fetchPaymentStatus fetches the payment status of the payment. If the payment
// isn't found, it will default to "StatusUnknown".
func fetchPaymentStatus(bucket kvdb.RBucket) PaymentStatus {
if bucket.Get(paymentSettleInfoKey) != nil {
return StatusSucceeded
}
if bucket.Get(paymentFailInfoKey) != nil {
return StatusFailed
}
if bucket.Get(paymentCreationInfoKey) != nil {
return StatusInFlight
}
return StatusUnknown
}
package migration_01_to_11
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sort"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// paymentsRootBucket is the name of the top-level bucket within the
// database that stores all data related to payments. Within this
// bucket, each payment has its own sub-bucket keyed by its payment
// hash.
//
// Bucket hierarchy:
//
// root-bucket
// |
// |-- <paymenthash>
// | |--sequence-key: <sequence number>
// | |--creation-info-key: <creation info>
// | |--attempt-info-key: <attempt info>
// | |--settle-info-key: <settle info>
// | |--fail-info-key: <fail info>
// | |
// | |--duplicate-bucket (only for old, completed payments)
// | |
// | |-- <seq-num>
// | | |--sequence-key: <sequence number>
// | | |--creation-info-key: <creation info>
// | | |--attempt-info-key: <attempt info>
// | | |--settle-info-key: <settle info>
// | | |--fail-info-key: <fail info>
// | |
// | |-- <seq-num>
// | | |
// | ... ...
// |
// |-- <paymenthash>
// | |
// | ...
// ...
//
paymentsRootBucket = []byte("payments-root-bucket")
// paymentDuplicateBucket is the name of an optional sub-bucket within
// the payment hash bucket, that is used to hold duplicate payments to
// a payment hash. This is needed to support information from earlier
// versions of lnd, where it was possible to pay to a payment hash more
// than once.
paymentDuplicateBucket = []byte("payment-duplicate-bucket")
// paymentSequenceKey is a key used in the payment's sub-bucket to
// store the sequence number of the payment.
paymentSequenceKey = []byte("payment-sequence-key")
// paymentCreationInfoKey is a key used in the payment's sub-bucket to
// store the creation info of the payment.
paymentCreationInfoKey = []byte("payment-creation-info")
// paymentAttemptInfoKey is a key used in the payment's sub-bucket to
// store the info about the latest attempt that was done for the
// payment in question.
paymentAttemptInfoKey = []byte("payment-attempt-info")
// paymentSettleInfoKey is a key used in the payment's sub-bucket to
// store the settle info of the payment.
paymentSettleInfoKey = []byte("payment-settle-info")
// paymentFailInfoKey is a key used in the payment's sub-bucket to
// store information about the reason a payment failed.
paymentFailInfoKey = []byte("payment-fail-info")
)
// FailureReason encodes the reason a payment ultimately failed.
type FailureReason byte
const (
// FailureReasonTimeout indicates that the payment timed out before a
// successful payment attempt was made.
FailureReasonTimeout FailureReason = 0
// FailureReasonNoRoute indicates no successful route to the
// destination was found during path finding.
FailureReasonNoRoute FailureReason = 1
// FailureReasonError indicates that an unexpected error happened during
// payment.
FailureReasonError FailureReason = 2
// FailureReasonIncorrectPaymentDetails indicates that either the hash
// is unknown or the final cltv delta or amount is incorrect.
FailureReasonIncorrectPaymentDetails FailureReason = 3
// TODO(halseth): cancel state.
// TODO(joostjager): Add failure reasons for:
// LocalLiquidityInsufficient, RemoteCapacityInsufficient.
)
// String returns a human readable FailureReason
func (r FailureReason) String() string {
switch r {
case FailureReasonTimeout:
return "timeout"
case FailureReasonNoRoute:
return "no_route"
case FailureReasonError:
return "error"
case FailureReasonIncorrectPaymentDetails:
return "incorrect_payment_details"
}
return "unknown"
}
// PaymentStatus represents the current status of a payment.
type PaymentStatus byte
const (
// StatusUnknown is the status where a payment has never been initiated
// and hence is unknown.
StatusUnknown PaymentStatus = 0
// StatusInFlight is the status where a payment has been initiated, but
// a response has not been received.
StatusInFlight PaymentStatus = 1
// StatusSucceeded is the status where a payment has been initiated and
// the payment was completed successfully.
StatusSucceeded PaymentStatus = 2
// StatusFailed is the status where a payment has been initiated and a
// failure result has come back.
StatusFailed PaymentStatus = 3
)
// Bytes returns status as slice of bytes.
func (ps PaymentStatus) Bytes() []byte {
return []byte{byte(ps)}
}
// FromBytes sets status from slice of bytes.
func (ps *PaymentStatus) FromBytes(status []byte) error {
if len(status) != 1 {
return errors.New("payment status is empty")
}
switch PaymentStatus(status[0]) {
case StatusUnknown, StatusInFlight, StatusSucceeded, StatusFailed:
*ps = PaymentStatus(status[0])
default:
return errors.New("unknown payment status")
}
return nil
}
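// decodePaymentStatus is an illustrative helper (not in the original file)
// showing the intended use of FromBytes when reading a raw status value, as
// produced by Bytes, back out of the database.
func decodePaymentStatus(raw []byte) (PaymentStatus, error) {
var status PaymentStatus
if err := status.FromBytes(raw); err != nil {
return StatusUnknown, err
}
return status, nil
}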
// String returns a readable representation of the payment status.
func (ps PaymentStatus) String() string {
switch ps {
case StatusUnknown:
return "Unknown"
case StatusInFlight:
return "In Flight"
case StatusSucceeded:
return "Succeeded"
case StatusFailed:
return "Failed"
default:
return "Unknown"
}
}
// PaymentCreationInfo is the information necessary to have ready when
// initiating a payment, moving it into state InFlight.
type PaymentCreationInfo struct {
// PaymentHash is the hash this payment is paying to.
PaymentHash lntypes.Hash
// Value is the amount we are paying.
Value lnwire.MilliSatoshi
// CreationDate is the time when this payment was initiated.
CreationDate time.Time
// PaymentRequest is the full payment request, if any.
PaymentRequest []byte
}
// PaymentAttemptInfo contains information about a specific payment attempt for
// a given payment. This information is used by the router to handle any errors
// coming back after an attempt is made, and to query the switch about the
// status of a payment. For settled payments this will be the information for
// the payment attempt that succeeded.
type PaymentAttemptInfo struct {
// PaymentID is the unique ID used for this attempt.
PaymentID uint64
// SessionKey is the ephemeral key used for this payment attempt.
SessionKey *btcec.PrivateKey
// Route is the route attempted to send the HTLC.
Route Route
}
// Payment is a wrapper around a payment's PaymentCreationInfo,
// PaymentAttemptInfo, and preimage. All payments will have the
// PaymentCreationInfo set, the PaymentAttemptInfo will be set only if at least
// one payment attempt has been made, while only completed payments will have a
// non-zero payment preimage.
type Payment struct {
// sequenceNum is a unique identifier used to sort the payments in
// order of creation.
sequenceNum uint64
// Status is the current PaymentStatus of this payment.
Status PaymentStatus
// Info holds all static information about this payment, and is
// populated when the payment is initiated.
Info *PaymentCreationInfo
// Attempt is the information about the last payment attempt made.
//
// NOTE: Can be nil if no attempt is yet made.
Attempt *PaymentAttemptInfo
// PaymentPreimage is the preimage of a successful payment. This serves
// as a proof of payment. It will only be non-nil for settled payments.
//
// NOTE: Can be nil if payment is not settled.
PaymentPreimage *lntypes.Preimage
// Failure is a failure reason code indicating the reason the payment
// failed. It is only non-nil for failed payments.
//
// NOTE: Can be nil if payment is not failed.
Failure *FailureReason
}
// FetchPayments returns all sent payments found in the DB.
func (db *DB) FetchPayments() ([]*Payment, error) {
var payments []*Payment
err := kvdb.View(db, func(tx kvdb.RTx) error {
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
if paymentsBucket == nil {
return nil
}
return paymentsBucket.ForEach(func(k, v []byte) error {
bucket := paymentsBucket.NestedReadBucket(k)
if bucket == nil {
// We only expect sub-buckets to be found in
// this top-level bucket.
return fmt.Errorf("non bucket element in " +
"payments bucket")
}
p, err := fetchPayment(bucket)
if err != nil {
return err
}
payments = append(payments, p)
// For older versions of lnd, duplicate payments to a
// payment hash were possible. These will be found in a
// sub-bucket indexed by their sequence number if
// available.
dup := bucket.NestedReadBucket(paymentDuplicateBucket)
if dup == nil {
return nil
}
return dup.ForEach(func(k, v []byte) error {
subBucket := dup.NestedReadBucket(k)
if subBucket == nil {
// We expect one bucket for each duplicate to
// be found.
return fmt.Errorf("non bucket element " +
"in duplicate bucket")
}
p, err := fetchPayment(subBucket)
if err != nil {
return err
}
payments = append(payments, p)
return nil
})
})
}, func() {
payments = nil
})
if err != nil {
return nil, err
}
// Before returning, sort the payments by their sequence number.
sort.Slice(payments, func(i, j int) bool {
return payments[i].sequenceNum < payments[j].sequenceNum
})
return payments, nil
}
func fetchPayment(bucket kvdb.RBucket) (*Payment, error) {
var (
err error
p = &Payment{}
)
seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes == nil {
return nil, fmt.Errorf("sequence number not found")
}
p.sequenceNum = binary.BigEndian.Uint64(seqBytes)
// Get the payment status.
p.Status = fetchPaymentStatus(bucket)
// Get the PaymentCreationInfo.
b := bucket.Get(paymentCreationInfoKey)
if b == nil {
return nil, fmt.Errorf("creation info not found")
}
r := bytes.NewReader(b)
p.Info, err = deserializePaymentCreationInfo(r)
if err != nil {
return nil, err
}
// Get the PaymentAttemptInfo. This can be unset.
b = bucket.Get(paymentAttemptInfoKey)
if b != nil {
r = bytes.NewReader(b)
p.Attempt, err = deserializePaymentAttemptInfo(r)
if err != nil {
return nil, err
}
}
// Get the payment preimage. This is only found for
// completed payments.
b = bucket.Get(paymentSettleInfoKey)
if b != nil {
var preimg lntypes.Preimage
copy(preimg[:], b[:])
p.PaymentPreimage = &preimg
}
// Get failure reason if available.
b = bucket.Get(paymentFailInfoKey)
if b != nil {
reason := FailureReason(b[0])
p.Failure = &reason
}
return p, nil
}
func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
var scratch [8]byte
if _, err := w.Write(c.PaymentHash[:]); err != nil {
return err
}
byteOrder.PutUint64(scratch[:], uint64(c.Value))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
byteOrder.PutUint64(scratch[:], uint64(c.CreationDate.Unix()))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
byteOrder.PutUint32(scratch[:4], uint32(len(c.PaymentRequest)))
if _, err := w.Write(scratch[:4]); err != nil {
return err
}
if _, err := w.Write(c.PaymentRequest[:]); err != nil {
return err
}
return nil
}
func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) {
var scratch [8]byte
c := &PaymentCreationInfo{}
if _, err := io.ReadFull(r, c.PaymentHash[:]); err != nil {
return nil, err
}
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return nil, err
}
c.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return nil, err
}
c.CreationDate = time.Unix(int64(byteOrder.Uint64(scratch[:])), 0)
if _, err := io.ReadFull(r, scratch[:4]); err != nil {
return nil, err
}
reqLen := byteOrder.Uint32(scratch[:4])
payReq := make([]byte, reqLen)
if reqLen > 0 {
if _, err := io.ReadFull(r, payReq[:]); err != nil {
return nil, err
}
}
c.PaymentRequest = payReq
return c, nil
}
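// The sketch below is not part of the original file; it demonstrates that
// the fixed-layout encoding above (32-byte hash, 8-byte value, 8-byte unix
// timestamp, 4-byte length-prefixed payment request) round-trips. Note that
// CreationDate only survives at second precision, since it is stored as a
// unix timestamp.
func paymentCreationInfoRoundTrip(c *PaymentCreationInfo) (*PaymentCreationInfo, error) {
var b bytes.Buffer
if err := serializePaymentCreationInfo(&b, c); err != nil {
return nil, err
}
return deserializePaymentCreationInfo(&b)
}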
func serializePaymentAttemptInfo(w io.Writer, a *PaymentAttemptInfo) error {
if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil {
return err
}
if err := SerializeRoute(w, a.Route); err != nil {
return err
}
return nil
}
func deserializePaymentAttemptInfo(r io.Reader) (*PaymentAttemptInfo, error) {
a := &PaymentAttemptInfo{}
err := ReadElements(r, &a.PaymentID, &a.SessionKey)
if err != nil {
return nil, err
}
a.Route, err = DeserializeRoute(r)
if err != nil {
return nil, err
}
return a, nil
}
func serializeHop(w io.Writer, h *Hop) error {
if err := WriteElements(w,
h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock,
h.AmtToForward,
); err != nil {
return err
}
if err := binary.Write(w, byteOrder, h.LegacyPayload); err != nil {
return err
}
// For legacy payloads, we don't need to write any TLV records, so
// we'll write a zero indicating that our serialized TLV map has no
// records.
if h.LegacyPayload {
return WriteElements(w, uint32(0))
}
// Otherwise, we'll transform our slice of records into a map of the
// raw bytes, then serialize them in-line with a length (number of
// elements) prefix.
mapRecords, err := tlv.RecordsToMap(h.TLVRecords)
if err != nil {
return err
}
numRecords := uint32(len(mapRecords))
if err := WriteElements(w, numRecords); err != nil {
return err
}
for recordType, rawBytes := range mapRecords {
if err := WriteElements(w, recordType); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, rawBytes); err != nil {
return err
}
}
return nil
}
// maxOnionPayloadSize is the largest Sphinx payload possible, so we don't need
// to read/write a TLV stream larger than this.
const maxOnionPayloadSize = 1300
func deserializeHop(r io.Reader) (*Hop, error) {
h := &Hop{}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return nil, err
}
copy(h.PubKeyBytes[:], pub)
if err := ReadElements(r,
&h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward,
); err != nil {
return nil, err
}
// TODO(roasbeef): change field to allow LegacyPayload false to be the
// legacy default?
err := binary.Read(r, byteOrder, &h.LegacyPayload)
if err != nil {
return nil, err
}
var numElements uint32
if err := ReadElements(r, &numElements); err != nil {
return nil, err
}
// If there are no elements, then we can return early.
if numElements == 0 {
return h, nil
}
tlvMap := make(map[uint64][]byte)
for i := uint32(0); i < numElements; i++ {
var tlvType uint64
if err := ReadElements(r, &tlvType); err != nil {
return nil, err
}
rawRecordBytes, err := wire.ReadVarBytes(
r, 0, maxOnionPayloadSize, "tlv",
)
if err != nil {
return nil, err
}
tlvMap[tlvType] = rawRecordBytes
}
h.TLVRecords = tlv.MapToRecords(tlvMap)
return h, nil
}
// SerializeRoute serializes a route.
func SerializeRoute(w io.Writer, r Route) error {
if err := WriteElements(w,
r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:],
); err != nil {
return err
}
if err := WriteElements(w, uint32(len(r.Hops))); err != nil {
return err
}
for _, h := range r.Hops {
if err := serializeHop(w, h); err != nil {
return err
}
}
return nil
}
// DeserializeRoute deserializes a route.
func DeserializeRoute(r io.Reader) (Route, error) {
rt := Route{}
if err := ReadElements(r,
&rt.TotalTimeLock, &rt.TotalAmount,
); err != nil {
return rt, err
}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return rt, err
}
copy(rt.SourcePubKey[:], pub)
var numHops uint32
if err := ReadElements(r, &numHops); err != nil {
return rt, err
}
var hops []*Hop
for i := uint32(0); i < numHops; i++ {
hop, err := deserializeHop(r)
if err != nil {
return rt, err
}
hops = append(hops, hop)
}
rt.Hops = hops
return rt, nil
}
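// Illustrative only (not part of the original file): a legacy-payload route
// with a single hop survives a SerializeRoute/DeserializeRoute round trip.
// All field values below are arbitrary demonstration data.
func exampleRouteRoundTrip() (Route, error) {
route := Route{
TotalTimeLock: 500_000,
TotalAmount:   1_000,
Hops: []*Hop{{
ChannelID:        12345,
OutgoingTimeLock: 499_990,
AmtToForward:     999,
LegacyPayload:    true,
}},
}
var b bytes.Buffer
if err := SerializeRoute(&b, route); err != nil {
return Route{}, err
}
return DeserializeRoute(&b)
}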
package migration_01_to_11
import (
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"strconv"
"strings"
"github.com/btcsuite/btcd/btcec/v2"
sphinx "github.com/lightningnetwork/lightning-onion"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
)
// VertexSize is the size of the array to store a vertex.
const VertexSize = 33
// ErrNoRouteHopsProvided is returned when a caller attempts to construct a new
// sphinx packet, but provides an empty set of hops for each route.
var ErrNoRouteHopsProvided = fmt.Errorf("empty route hops provided")
// Vertex is a simple alias for the serialization of a compressed Bitcoin
// public key.
type Vertex [VertexSize]byte
// NewVertex returns a new Vertex given a public key.
func NewVertex(pub *btcec.PublicKey) Vertex {
var v Vertex
copy(v[:], pub.SerializeCompressed())
return v
}
// NewVertexFromBytes returns a new Vertex based on a serialized pubkey in a
// byte slice.
func NewVertexFromBytes(b []byte) (Vertex, error) {
vertexLen := len(b)
if vertexLen != VertexSize {
return Vertex{}, fmt.Errorf("invalid vertex length of %v, "+
"want %v", vertexLen, VertexSize)
}
var v Vertex
copy(v[:], b)
return v, nil
}
// NewVertexFromStr returns a new Vertex given its hex-encoded string format.
func NewVertexFromStr(v string) (Vertex, error) {
// Return error if hex string is of incorrect length.
if len(v) != VertexSize*2 {
return Vertex{}, fmt.Errorf("invalid vertex string length of "+
"%v, want %v", len(v), VertexSize*2)
}
vertex, err := hex.DecodeString(v)
if err != nil {
return Vertex{}, err
}
return NewVertexFromBytes(vertex)
}
// String returns a human readable version of the Vertex which is the
// hex-encoding of the serialized compressed public key.
func (v Vertex) String() string {
return fmt.Sprintf("%x", v[:])
}
// Hop represents an intermediate or final node of the route. This naming
// is in line with the definition given in BOLT #4: Onion Routing Protocol.
// The struct houses the channel along which this hop can be reached and
// the values necessary to create the HTLC that needs to be sent to the
// next hop. It is also used to encode the per-hop payload included within
// the Sphinx packet.
type Hop struct {
// PubKeyBytes is the raw bytes of the public key of the target node.
PubKeyBytes Vertex
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel.
ChannelID uint64
// OutgoingTimeLock is the timelock value that should be used when
// crafting the _outgoing_ HTLC from this hop.
OutgoingTimeLock uint32
// AmtToForward is the amount that this hop will forward to the next
// hop. This value is less than the value that the incoming HTLC
// carries as a fee will be subtracted by the hop.
AmtToForward lnwire.MilliSatoshi
// TLVRecords if non-nil are a set of additional TLV records that
// should be included in the forwarding instructions for this node.
TLVRecords []tlv.Record
// LegacyPayload if true, then this signals that this node doesn't
// understand the new TLV payload, so we must instead use the legacy
// payload.
LegacyPayload bool
}
// PackHopPayload writes to the passed io.Writer the series of bytes that can
// be placed directly into the per-hop payload (EOB) for this hop. This will
// include the required routing fields, as well as serializing any of the
// passed optional TLVRecords. nextChanID is the unique channel ID that
// references the _outgoing_ channel ID that follows this hop. This field
// follows the same semantics as the NextAddress field in the onion: it should
// be set to zero to indicate the terminal hop.
func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error {
// If this is a legacy payload, then we'll exit here as this method
// shouldn't be called.
if h.LegacyPayload {
return fmt.Errorf("cannot pack hop payloads for legacy " +
"payloads")
}
// Otherwise, we'll need to make a new stream that includes our
// required routing fields, as well as these optional values.
var records []tlv.Record
// Every hop must have an amount to forward and CLTV expiry.
amt := uint64(h.AmtToForward)
records = append(records,
record.NewAmtToFwdRecord(&amt),
record.NewLockTimeRecord(&h.OutgoingTimeLock),
)
// BOLT 04 says the next_hop_id should be omitted for the final hop,
// but present for all others.
//
// TODO(conner): test using hop.Exit once available
if nextChanID != 0 {
records = append(records,
record.NewNextHopIDRecord(&nextChanID),
)
}
// Append any custom types destined for this hop.
records = append(records, h.TLVRecords...)
// To ensure we produce a canonical stream, we'll sort the records
// before encoding them as a stream in the hop payload.
tlv.SortRecords(records)
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
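// A small usage sketch (not part of the original file): packing the TLV
// payload for a non-final hop. The amounts and channel ID are arbitrary,
// and LegacyPayload is left false so the TLV path above is exercised.
func examplePackHopPayload() ([]byte, error) {
hop := &Hop{
AmtToForward:     1_000,
OutgoingTimeLock: 500_000,
}
var b bytes.Buffer
// A non-zero nextChanID marks this as a non-terminal hop, so the
// stream will contain amt_to_forward, outgoing_cltv_value and
// short_channel_id records.
if err := hop.PackHopPayload(&b, 12345); err != nil {
return nil, err
}
return b.Bytes(), nil
}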
// Route represents a path through the channel graph which runs over one or
// more channels in succession. This struct carries all the information
// required to craft the Sphinx onion packet, and send the payment along the
// first hop in the path. A route is only selected as valid if all the channels
// have sufficient capacity to carry the initial payment amount after fees are
// accounted for.
type Route struct {
// TotalTimeLock is the cumulative (final) time lock across the entire
// route. This is the CLTV value that should be extended to the first
// hop in the route. All other hops will decrement the time-lock as
// advertised, leaving enough time for all hops to wait for or present
// the payment preimage to complete the payment.
TotalTimeLock uint32
// TotalAmount is the total amount of funds required to complete a
// payment over this route. This value includes the cumulative fees at
// each hop. As a result, the HTLC extended to the first-hop in the
// route will need to have at least this many satoshis, otherwise the
// route will fail at an intermediate node due to an insufficient
// amount of fees.
TotalAmount lnwire.MilliSatoshi
// SourcePubKey is the pubkey of the node where this route originates
// from.
SourcePubKey Vertex
// Hops contains details concerning the specific forwarding details at
// each hop.
Hops []*Hop
}
// HopFee returns the fee charged by the route hop indicated by hopIndex.
func (r *Route) HopFee(hopIndex int) lnwire.MilliSatoshi {
var incomingAmt lnwire.MilliSatoshi
if hopIndex == 0 {
incomingAmt = r.TotalAmount
} else {
incomingAmt = r.Hops[hopIndex-1].AmtToForward
}
// Fee is calculated as difference between incoming and outgoing amount.
return incomingAmt - r.Hops[hopIndex].AmtToForward
}
// TotalFees is the sum of the fees paid at each hop within the final route. In
// the case of a one-hop payment, this value will be zero as we don't need to
// pay a fee to ourself.
func (r *Route) TotalFees() lnwire.MilliSatoshi {
if len(r.Hops) == 0 {
return 0
}
return r.TotalAmount - r.Hops[len(r.Hops)-1].AmtToForward
}
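// A worked example (not part of the original file) of the fee arithmetic
// above: with TotalAmount = 1010 msat and both hops forwarding 1000 msat,
// HopFee(0) is 1010 - 1000 = 10 msat and TotalFees is likewise 10 msat.
func exampleRouteFees() (lnwire.MilliSatoshi, lnwire.MilliSatoshi) {
route := Route{
TotalAmount: 1_010,
Hops: []*Hop{
{AmtToForward: 1_000}, // First hop: 1010 msat in, 1000 msat out.
{AmtToForward: 1_000}, // Final hop forwards the payment amount.
},
}
return route.HopFee(0), route.TotalFees()
}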
// NewRouteFromHops creates a new Route structure from the minimally required
// information to perform the payment: the amount to send, the total time
// lock, the source vertex, and the set of hops.
func NewRouteFromHops(amtToSend lnwire.MilliSatoshi, timeLock uint32,
sourceVertex Vertex, hops []*Hop) (*Route, error) {
if len(hops) == 0 {
return nil, ErrNoRouteHopsProvided
}
// First, we'll create a route struct and populate it with the fields
// for which the values are provided as arguments of this function.
// TotalFees is determined based on the difference between the amount
// that is sent from the source and the final amount that is received
// by the destination.
route := &Route{
SourcePubKey: sourceVertex,
Hops: hops,
TotalTimeLock: timeLock,
TotalAmount: amtToSend,
}
return route, nil
}
// ToSphinxPath converts a complete route into a sphinx PaymentPath that
// contains the per-hop payloads used to encode the HTLC routing data for each
// hop in the route. This method also accepts an optional EOB payload for the
// final hop.
func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
var path sphinx.PaymentPath
// For each hop encoded within the route, we'll convert the hop struct
// to an OnionHop with matching per-hop payload within the path as used
// by the sphinx package.
for i, hop := range r.Hops {
pub, err := btcec.ParsePubKey(hop.PubKeyBytes[:])
if err != nil {
return nil, err
}
// As a base case, the next hop is set to all zeroes in order
// to indicate that the "last hop" has no further hops after it.
nextHop := uint64(0)
// If we aren't on the last hop, then we set the "next address"
// field to be the channel that directly follows it.
if i != len(r.Hops)-1 {
nextHop = r.Hops[i+1].ChannelID
}
var payload sphinx.HopPayload
// If this is the legacy payload, then we can just include the
// hop data as normal.
if hop.LegacyPayload {
// Before we encode this value, we'll pack the next hop
// into the NextAddress field of the hop info to ensure
// we point to the right node.
hopData := sphinx.HopData{
ForwardAmount: uint64(hop.AmtToForward),
OutgoingCltv: hop.OutgoingTimeLock,
}
binary.BigEndian.PutUint64(
hopData.NextAddress[:], nextHop,
)
payload, err = sphinx.NewLegacyHopPayload(&hopData)
if err != nil {
return nil, err
}
} else {
// For non-legacy payloads, we'll need to pack the
// routing information, along with any extra TLV
// information into the new per-hop payload format.
// We'll also pass in the chan ID of the hop this
// channel should be forwarded to so we can construct a
// valid payload.
var b bytes.Buffer
err := hop.PackHopPayload(&b, nextHop)
if err != nil {
return nil, err
}
payload, err = sphinx.NewTLVHopPayload(b.Bytes())
if err != nil {
return nil, err
}
}
path[i] = sphinx.OnionHop{
NodePub: *pub,
HopPayload: payload,
}
}
return &path, nil
}
// String returns a human readable representation of the route.
func (r *Route) String() string {
var b strings.Builder
for i, hop := range r.Hops {
if i > 0 {
b.WriteString(",")
}
b.WriteString(strconv.FormatUint(hop.ChannelID, 10))
}
return fmt.Sprintf("amt=%v, fees=%v, tl=%v, chans=%v",
r.TotalAmount-r.TotalFees(), r.TotalFees(), r.TotalTimeLock,
b.String(),
)
}
package zpay32
import (
"fmt"
"strconv"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
)
var (
// toMSat is a map from a unit to a function that converts an amount
// of that unit to millisatoshis.
toMSat = map[byte]func(uint64) (lnwire.MilliSatoshi, error){
'm': mBtcToMSat,
'u': uBtcToMSat,
'n': nBtcToMSat,
'p': pBtcToMSat,
}
// fromMSat is a map from a unit to a function that converts an amount
// in millisatoshis to an amount of that unit.
fromMSat = map[byte]func(lnwire.MilliSatoshi) (uint64, error){
'm': mSatToMBtc,
'u': mSatToUBtc,
'n': mSatToNBtc,
'p': mSatToPBtc,
}
)
// mBtcToMSat converts the given amount in milliBTC to millisatoshis.
func mBtcToMSat(m uint64) (lnwire.MilliSatoshi, error) {
return lnwire.MilliSatoshi(m) * 100000000, nil
}
// uBtcToMSat converts the given amount in microBTC to millisatoshis.
func uBtcToMSat(u uint64) (lnwire.MilliSatoshi, error) {
return lnwire.MilliSatoshi(u * 100000), nil
}
// nBtcToMSat converts the given amount in nanoBTC to millisatoshis.
func nBtcToMSat(n uint64) (lnwire.MilliSatoshi, error) {
return lnwire.MilliSatoshi(n * 100), nil
}
// pBtcToMSat converts the given amount in picoBTC to millisatoshis.
func pBtcToMSat(p uint64) (lnwire.MilliSatoshi, error) {
if p < 10 {
return 0, fmt.Errorf("minimum amount is 10p")
}
if p%10 != 0 {
return 0, fmt.Errorf("amount %d pBTC not expressible in msat",
p)
}
return lnwire.MilliSatoshi(p / 10), nil
}
// mSatToMBtc converts the given amount in millisatoshis to milliBTC.
func mSatToMBtc(msat lnwire.MilliSatoshi) (uint64, error) {
if msat%100000000 != 0 {
return 0, fmt.Errorf("%d msat not expressible "+
"in mBTC", msat)
}
return uint64(msat / 100000000), nil
}
// mSatToUBtc converts the given amount in millisatoshis to microBTC.
func mSatToUBtc(msat lnwire.MilliSatoshi) (uint64, error) {
if msat%100000 != 0 {
return 0, fmt.Errorf("%d msat not expressible "+
"in uBTC", msat)
}
return uint64(msat / 100000), nil
}
// mSatToNBtc converts the given amount in millisatoshis to nanoBTC.
func mSatToNBtc(msat lnwire.MilliSatoshi) (uint64, error) {
if msat%100 != 0 {
return 0, fmt.Errorf("%d msat not expressible in nBTC", msat)
}
return uint64(msat / 100), nil
}
// mSatToPBtc converts the given amount in millisatoshis to picoBTC.
func mSatToPBtc(msat lnwire.MilliSatoshi) (uint64, error) {
return uint64(msat * 10), nil
}
// decodeAmount returns the amount encoded by the provided string in
// millisatoshi.
func decodeAmount(amount string) (lnwire.MilliSatoshi, error) {
if len(amount) < 1 {
return 0, fmt.Errorf("amount must be non-empty")
}
// If last character is a digit, then the amount can just be
// interpreted as BTC.
char := amount[len(amount)-1]
// Note that char is an unsigned byte, so for characters below '0'
// the subtraction wraps around to a large value; checking
// digit <= 9 alone is therefore sufficient.
digit := char - '0'
if digit <= 9 {
btc, err := strconv.ParseUint(amount, 10, 64)
if err != nil {
return 0, err
}
return lnwire.MilliSatoshi(btc) * mSatPerBtc, nil
}
// If not a digit, it must be part of the known units.
conv, ok := toMSat[char]
if !ok {
return 0, fmt.Errorf("unknown multiplier %c", char)
}
// Known unit.
num := amount[:len(amount)-1]
if len(num) < 1 {
return 0, fmt.Errorf("number must be non-empty")
}
am, err := strconv.ParseUint(num, 10, 64)
if err != nil {
return 0, err
}
return conv(am)
}
// encodeAmount encodes the provided millisatoshi amount using as few characters
// as possible.
func encodeAmount(msat lnwire.MilliSatoshi) (string, error) {
// If possible to express in BTC, that will always be the shortest
// representation.
if msat%mSatPerBtc == 0 {
return strconv.FormatInt(int64(msat/mSatPerBtc), 10), nil
}
// Should always be expressible in pico BTC.
pico, err := fromMSat['p'](msat)
if err != nil {
return "", fmt.Errorf("unable to express %d msat as pBTC: %w",
msat, err)
}
shortened := strconv.FormatUint(pico, 10) + "p"
for unit, conv := range fromMSat {
am, err := conv(msat)
if err != nil {
// Not expressible using this unit.
continue
}
// Save the shortest found representation.
str := strconv.FormatUint(am, 10) + string(unit)
if len(str) < len(shortened) {
shortened = str
}
}
return shortened, nil
}
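// Usage sketch (not part of the original file), using the familiar BOLT 11
// amount notation: "2500u" is 2500 microBTC, i.e. 250,000,000 msat, and the
// same amount encodes back to "2500u" since that is its shortest form.
func exampleAmountRoundTrip() error {
msat, err := decodeAmount("2500u")
if err != nil {
return err
}
if msat != 250_000_000 {
return fmt.Errorf("unexpected amount: %v", msat)
}
encoded, err := encodeAmount(msat)
if err != nil {
return err
}
if encoded != "2500u" {
return fmt.Errorf("unexpected encoding: %v", encoded)
}
return nil
}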
package zpay32
import (
"fmt"
"strings"
)
const charset = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
var gen = []int{0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3}
// NOTE: This method is a slight modification of the bech32.Decode method
// found in btcutil, allowing strings to be more than 90 characters.
// decodeBech32 decodes a bech32 encoded string, returning the human-readable
// part and the data part excluding the checksum.
// Note: the data will be base32 encoded, that is each element of the returned
// byte array will encode 5 bits of data. Use the ConvertBits method to convert
// this to 8-bit representation.
func decodeBech32(bech string) (string, []byte, error) {
// The maximum allowed length for a bech32 string is 90. It must also
// be at least 8 characters, since it needs a non-empty HRP, a
// separator, and a 6 character checksum.
// NB: The 90 character check specified in BIP173 is skipped here, to
// allow strings longer than 90 characters.
if len(bech) < 8 {
return "", nil, fmt.Errorf("invalid bech32 string length %d",
len(bech))
}
// Only ASCII characters between 33 and 126 are allowed.
for i := 0; i < len(bech); i++ {
if bech[i] < 33 || bech[i] > 126 {
return "", nil, fmt.Errorf("invalid character in "+
"string: '%c'", bech[i])
}
}
// The characters must be either all lowercase or all uppercase.
lower := strings.ToLower(bech)
upper := strings.ToUpper(bech)
if bech != lower && bech != upper {
return "", nil, fmt.Errorf("string not all lowercase or all " +
"uppercase")
}
// We'll work with the lowercase string from now on.
bech = lower
// The string is invalid if the last '1' is non-existent, if it is the
// first character of the string (no human-readable part), or if it is
// one of the last 6 characters of the string (since the checksum cannot
// contain '1').
one := strings.LastIndexByte(bech, '1')
if one < 1 || one+7 > len(bech) {
return "", nil, fmt.Errorf("invalid index of 1")
}
// The human-readable part is everything before the last '1'.
hrp := bech[:one]
data := bech[one+1:]
// Each character corresponds to the byte with value of the index in
// 'charset'.
decoded, err := toBytes(data)
if err != nil {
return "", nil, fmt.Errorf("failed converting data to bytes: "+
"%v", err)
}
if !bech32VerifyChecksum(hrp, decoded) {
moreInfo := ""
checksum := bech[len(bech)-6:]
expected, err := toChars(bech32Checksum(hrp,
decoded[:len(decoded)-6]))
if err == nil {
moreInfo = fmt.Sprintf("Expected %v, got %v.",
expected, checksum)
}
return "", nil, fmt.Errorf("checksum failed. " + moreInfo)
}
// We exclude the last 6 bytes, which is the checksum.
return hrp, decoded[:len(decoded)-6], nil
}
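// As a quick illustration (not part of the original file), decoding the
// well-known BIP173 test vector below yields the HRP "bc" and 33
// five-bit groups: one witness version plus 32 groups carrying the
// 160-bit witness program.
func exampleDecodeBech32() {
hrp, data, err := decodeBech32("bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4")
if err != nil {
return
}
// Prints "bc 33".
fmt.Println(hrp, len(data))
}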
// toBytes converts each character in the string 'chars' to the value of the
// index of the corresponding character in 'charset'.
func toBytes(chars string) ([]byte, error) {
decoded := make([]byte, 0, len(chars))
for i := 0; i < len(chars); i++ {
index := strings.IndexByte(charset, chars[i])
if index < 0 {
return nil, fmt.Errorf("invalid character not part of "+
"charset: %v", chars[i])
}
decoded = append(decoded, byte(index))
}
return decoded, nil
}
// toChars converts the byte slice 'data' to a string where each byte in 'data'
// encodes the index of a character in 'charset'.
func toChars(data []byte) (string, error) {
result := make([]byte, 0, len(data))
for _, b := range data {
if int(b) >= len(charset) {
return "", fmt.Errorf("invalid data byte: %v", b)
}
result = append(result, charset[b])
}
return string(result), nil
}
// For more details on the checksum calculation, please refer to BIP 173.
func bech32Checksum(hrp string, data []byte) []byte {
// Convert the bytes to list of integers, as this is needed for the
// checksum calculation.
integers := make([]int, len(data))
for i, b := range data {
integers[i] = int(b)
}
values := append(bech32HrpExpand(hrp), integers...)
values = append(values, []int{0, 0, 0, 0, 0, 0}...)
polymod := bech32Polymod(values) ^ 1
var res []byte
for i := 0; i < 6; i++ {
res = append(res, byte((polymod>>uint(5*(5-i)))&31))
}
return res
}
// For more details on the polymod calculation, please refer to BIP 173.
func bech32Polymod(values []int) int {
chk := 1
for _, v := range values {
b := chk >> 25
chk = (chk&0x1ffffff)<<5 ^ v
for i := 0; i < 5; i++ {
if (b>>uint(i))&1 == 1 {
chk ^= gen[i]
}
}
}
return chk
}
// For more details on HRP expansion, please refer to BIP 173.
func bech32HrpExpand(hrp string) []int {
v := make([]int, 0, len(hrp)*2+1)
for i := 0; i < len(hrp); i++ {
v = append(v, int(hrp[i]>>5))
}
v = append(v, 0)
for i := 0; i < len(hrp); i++ {
v = append(v, int(hrp[i]&31))
}
return v
}
// For more details on the checksum verification, please refer to BIP 173.
func bech32VerifyChecksum(hrp string, data []byte) bool {
integers := make([]int, len(data))
for i, b := range data {
integers[i] = int(b)
}
concat := append(bech32HrpExpand(hrp), integers...)
return bech32Polymod(concat) == 1
}
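// A small self-check (illustrative only): appending the six groups
// produced by bech32Checksum to the data makes bech32VerifyChecksum
// succeed, which is exactly the BIP173 invariant that
// polymod(hrpExpand(hrp) || data || checksum) == 1.
func exampleChecksumRoundTrip() bool {
hrp := "lnbc"
data := []byte{0, 1, 2, 3} // arbitrary 5-bit groups, each < 32
withChecksum := append(append([]byte{}, data...), bech32Checksum(hrp, data)...)
// Reports true.
return bech32VerifyChecksum(hrp, withChecksum)
}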
package zpay32
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"strings"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/bech32"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
)
// Decode parses the provided encoded invoice and returns a decoded Invoice if
// it is valid by BOLT-0011 and matches the provided active network.
func Decode(invoice string, net *chaincfg.Params) (*Invoice, error) {
decodedInvoice := Invoice{}
// Before bech32 decoding the invoice, make sure that it is not too large.
// This is done as an anti-DoS measure since bech32 decoding is expensive.
if len(invoice) > maxInvoiceLength {
return nil, ErrInvoiceTooLarge
}
// Decode the invoice using the modified bech32 decoder.
hrp, data, err := decodeBech32(invoice)
if err != nil {
return nil, err
}
// We expect the human-readable part to at least have ln + one char
// encoding the network.
if len(hrp) < 3 {
return nil, fmt.Errorf("hrp too short")
}
// First two characters of HRP should be "ln".
if hrp[:2] != "ln" {
return nil, fmt.Errorf("prefix should be \"ln\"")
}
// The next characters should be a valid prefix for a segwit BIP173
// address that matches the active network.
if !strings.HasPrefix(hrp[2:], net.Bech32HRPSegwit) {
return nil, fmt.Errorf(
"invoice not for current active network '%s'", net.Name)
}
decodedInvoice.Net = net
// Optionally, if there's anything left of the HRP after ln + the segwit
// prefix, we try to decode this as the payment amount.
var netPrefixLength = len(net.Bech32HRPSegwit) + 2
if len(hrp) > netPrefixLength {
amount, err := decodeAmount(hrp[netPrefixLength:])
if err != nil {
return nil, err
}
decodedInvoice.MilliSat = &amount
}
// Everything except the last 520 bits of the data encodes the invoice's
// timestamp and tagged fields.
if len(data) < signatureBase32Len {
return nil, errors.New("short invoice")
}
invoiceData := data[:len(data)-signatureBase32Len]
// Parse the timestamp and tagged fields, and fill the Invoice struct.
if err := parseData(&decodedInvoice, invoiceData, net); err != nil {
return nil, err
}
// The last 520 bits (104 groups) make up the signature.
sigBase32 := data[len(data)-signatureBase32Len:]
sigBase256, err := bech32.ConvertBits(sigBase32, 5, 8, true)
if err != nil {
return nil, err
}
var sig lnwire.Sig
copy(sig[:], sigBase256[:64])
recoveryID := sigBase256[64]
// The signature is over the hrp + the data of the invoice, encoded in
// base 256.
taggedDataBytes, err := bech32.ConvertBits(invoiceData, 5, 8, true)
if err != nil {
return nil, err
}
toSign := append([]byte(hrp), taggedDataBytes...)
// We expect the signature to be over the single SHA-256 hash of that
// data.
hash := chainhash.HashB(toSign)
// If the destination pubkey was provided as a tagged field, use that
// to verify the signature, if not do public key recovery.
if decodedInvoice.Destination != nil {
signature, err := sig.ToSignature()
if err != nil {
return nil, fmt.Errorf("unable to deserialize "+
"signature: %v", err)
}
if !signature.Verify(hash, decodedInvoice.Destination) {
return nil, fmt.Errorf("invalid invoice signature")
}
} else {
headerByte := recoveryID + 27 + 4
compactSign := append([]byte{headerByte}, sig[:]...)
pubkey, _, err := ecdsa.RecoverCompact(compactSign, hash)
if err != nil {
return nil, err
}
decodedInvoice.Destination = pubkey
}
// If no feature vector was decoded, populate an empty one.
if decodedInvoice.Features == nil {
decodedInvoice.Features = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
}
// Now that we have created the invoice, make sure it has the required
// fields set.
if err := validateInvoice(&decodedInvoice); err != nil {
return nil, err
}
return &decodedInvoice, nil
}
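// A hedged usage sketch: the invoice string below is a truncated
// placeholder rather than a decodable invoice; a real caller passes the
// full "lnbc..." string obtained from the payee.
func exampleDecode() {
inv, err := Decode("lnbc2500n1...", &chaincfg.MainNetParams)
if err != nil {
// Malformed, wrong-network, or oversized invoices all surface
// here.
return
}
if inv.MilliSat != nil {
fmt.Printf("amount: %d msat\n", uint64(*inv.MilliSat))
}
}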
// parseData parses the data part of the invoice. It expects the base32
// data returned from the bech32.Decode method, excluding the signature.
func parseData(invoice *Invoice, data []byte, net *chaincfg.Params) error {
// It must contain the timestamp, encoded using 35 bits (7 groups).
if len(data) < timestampBase32Len {
return fmt.Errorf("data too short: %d", len(data))
}
t, err := parseTimestamp(data[:timestampBase32Len])
if err != nil {
return err
}
invoice.Timestamp = time.Unix(int64(t), 0)
// The rest are tagged parts.
tagData := data[timestampBase32Len:]
return parseTaggedFields(invoice, tagData, net)
}
// parseTimestamp converts a 35-bit timestamp (encoded in base32) to uint64.
func parseTimestamp(data []byte) (uint64, error) {
if len(data) != timestampBase32Len {
return 0, fmt.Errorf("timestamp must be 35 bits, was %d",
len(data)*5)
}
return base32ToUint64(data)
}
// parseTaggedFields takes the base32 encoded tagged fields of the invoice, and
// fills the Invoice struct accordingly.
func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) error {
index := 0
for len(fields)-index > 0 {
// If there are less than 3 groups to read, there cannot be more
// interesting information, as we need the type (1 group) and
// length (2 groups).
//
// This means the last tagged field is broken.
if len(fields)-index < 3 {
return ErrBrokenTaggedField
}
typ := fields[index]
dataLength, err := parseFieldDataLength(fields[index+1 : index+3])
if err != nil {
return err
}
// If we don't have enough field data left to read this length,
// return error.
if len(fields) < index+3+int(dataLength) {
return ErrInvalidFieldLength
}
base32Data := fields[index+3 : index+3+int(dataLength)]
// Advance the index in preparation for the next iteration.
index += 3 + int(dataLength)
switch typ {
case fieldTypeP:
if invoice.PaymentHash != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.PaymentHash, err = parse32Bytes(base32Data)
case fieldTypeS:
if invoice.PaymentAddr != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.PaymentAddr, err = parse32Bytes(base32Data)
case fieldTypeD:
if invoice.Description != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.Description, err = parseDescription(base32Data)
case fieldTypeN:
if invoice.Destination != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.Destination, err = parseDestination(base32Data)
case fieldTypeH:
if invoice.DescriptionHash != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.DescriptionHash, err = parse32Bytes(base32Data)
case fieldTypeX:
if invoice.expiry != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.expiry, err = parseExpiry(base32Data)
case fieldTypeC:
if invoice.minFinalCLTVExpiry != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.minFinalCLTVExpiry, err = parseMinFinalCLTVExpiry(base32Data)
case fieldTypeF:
if invoice.FallbackAddr != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.FallbackAddr, err = parseFallbackAddr(base32Data, net)
case fieldTypeR:
// An `r` field can be included in an invoice multiple
// times, so we won't skip it if we have already seen
// one.
routeHint, err := parseRouteHint(base32Data)
if err != nil {
return err
}
invoice.RouteHints = append(invoice.RouteHints, routeHint)
case fieldType9:
if invoice.Features != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.Features, err = parseFeatures(base32Data)
default:
// Ignore unknown type.
}
// Check if there was an error from parsing any of the tagged
// fields and return it.
if err != nil {
return err
}
}
return nil
}
// parseFieldDataLength converts the two byte slice into a uint16.
func parseFieldDataLength(data []byte) (uint16, error) {
if len(data) != 2 {
return 0, fmt.Errorf("data length must be 2 bytes, was %d",
len(data))
}
return uint16(data[0])<<5 | uint16(data[1]), nil
}
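// As a worked example, the two groups {1, 20} decode to 1<<5 | 20 = 52,
// which is hashBase32Len: exactly the data_length announced by a
// payment-hash ('p') field.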
// parse32Bytes converts a 256-bit value (encoded in base32) to *[32]byte. This
// can be used for payment hashes, description hashes, payment addresses, etc.
func parse32Bytes(data []byte) (*[32]byte, error) {
var paymentHash [32]byte
// As BOLT-11 states, a reader must skip over a 32-byte field if it
// does not have a length of 52 groups, so avoid returning an error.
if len(data) != hashBase32Len {
return nil, nil
}
hash, err := bech32.ConvertBits(data, 5, 8, false)
if err != nil {
return nil, err
}
copy(paymentHash[:], hash)
return &paymentHash, nil
}
// parseDescription converts the data (encoded in base32) into a string to use
// as the description.
func parseDescription(data []byte) (*string, error) {
base256Data, err := bech32.ConvertBits(data, 5, 8, false)
if err != nil {
return nil, err
}
description := string(base256Data)
return &description, nil
}
// parseDestination converts the data (encoded in base32) into a 33-byte public
// key of the payee node.
func parseDestination(data []byte) (*btcec.PublicKey, error) {
// As BOLT-11 states, a reader must skip over the destination field
// if it does not have a length of 53, so avoid returning an error.
if len(data) != pubKeyBase32Len {
return nil, nil
}
base256Data, err := bech32.ConvertBits(data, 5, 8, false)
if err != nil {
return nil, err
}
return btcec.ParsePubKey(base256Data)
}
// parseExpiry converts the data (encoded in base32) into the expiry time.
func parseExpiry(data []byte) (*time.Duration, error) {
expiry, err := base32ToUint64(data)
if err != nil {
return nil, err
}
duration := time.Duration(expiry) * time.Second
return &duration, nil
}
// parseMinFinalCLTVExpiry converts the data (encoded in base32) into a uint64
// to use as the minFinalCLTVExpiry.
func parseMinFinalCLTVExpiry(data []byte) (*uint64, error) {
expiry, err := base32ToUint64(data)
if err != nil {
return nil, err
}
return &expiry, nil
}
// parseFallbackAddr converts the data (encoded in base32) into a fallback
// on-chain address.
func parseFallbackAddr(data []byte, net *chaincfg.Params) (btcutil.Address, error) {
// Checks if the data is empty or contains a version without an address.
if len(data) < 2 {
return nil, fmt.Errorf("empty fallback address field")
}
var addr btcutil.Address
version := data[0]
switch version {
case 0:
witness, err := bech32.ConvertBits(data[1:], 5, 8, false)
if err != nil {
return nil, err
}
switch len(witness) {
case 20:
addr, err = btcutil.NewAddressWitnessPubKeyHash(witness, net)
case 32:
addr, err = btcutil.NewAddressWitnessScriptHash(witness, net)
default:
return nil, fmt.Errorf("unknown witness program length %d",
len(witness))
}
if err != nil {
return nil, err
}
case 17:
pubKeyHash, err := bech32.ConvertBits(data[1:], 5, 8, false)
if err != nil {
return nil, err
}
addr, err = btcutil.NewAddressPubKeyHash(pubKeyHash, net)
if err != nil {
return nil, err
}
case 18:
scriptHash, err := bech32.ConvertBits(data[1:], 5, 8, false)
if err != nil {
return nil, err
}
addr, err = btcutil.NewAddressScriptHashFromHash(scriptHash, net)
if err != nil {
return nil, err
}
default:
// Ignore unknown version.
}
return addr, nil
}
// parseRouteHint converts the data (encoded in base32) into an array containing
// one or more routing hop hints that represent a single route hint.
func parseRouteHint(data []byte) ([]HopHint, error) {
base256Data, err := bech32.ConvertBits(data, 5, 8, false)
if err != nil {
return nil, err
}
// Check that base256Data is a multiple of hopHintLen.
if len(base256Data)%hopHintLen != 0 {
return nil, fmt.Errorf("expected length multiple of %d bytes, "+
"got %d", hopHintLen, len(base256Data))
}
var routeHint []HopHint
for len(base256Data) > 0 {
hopHint := HopHint{}
hopHint.NodeID, err = btcec.ParsePubKey(base256Data[:33])
if err != nil {
return nil, err
}
hopHint.ChannelID = binary.BigEndian.Uint64(base256Data[33:41])
hopHint.FeeBaseMSat = binary.BigEndian.Uint32(base256Data[41:45])
hopHint.FeeProportionalMillionths = binary.BigEndian.Uint32(base256Data[45:49])
hopHint.CLTVExpiryDelta = binary.BigEndian.Uint16(base256Data[49:51])
routeHint = append(routeHint, hopHint)
base256Data = base256Data[51:]
}
return routeHint, nil
}
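// For reference, the 51-byte (hopHintLen) layout sliced above is:
// pubkey[0:33] || short_channel_id[33:41] || fee_base_msat[41:45] ||
// fee_proportional_millionths[45:49] || cltv_expiry_delta[49:51].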
// parseFeatures decodes any feature bits directly from the base32
// representation.
func parseFeatures(data []byte) (*lnwire.FeatureVector, error) {
rawFeatures := lnwire.NewRawFeatureVector()
err := rawFeatures.DecodeBase32(bytes.NewReader(data), len(data))
if err != nil {
return nil, err
}
return lnwire.NewFeatureVector(rawFeatures, lnwire.Features), nil
}
// base32ToUint64 converts a base32 encoded number to uint64.
func base32ToUint64(data []byte) (uint64, error) {
// Maximum that fits in uint64 is ceil(64 / 5) = 13 groups.
if len(data) > 13 {
return 0, fmt.Errorf("cannot parse data of length %d as uint64",
len(data))
}
val := uint64(0)
for i := 0; i < len(data); i++ {
val = val<<5 | uint64(data[i])
}
return val, nil
}
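// For instance, the groups {3, 10} decode to 3<<5 | 10 = 106; each
// additional group shifts the accumulated value left by five bits.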
package zpay32
import "github.com/btcsuite/btcd/btcec/v2"
const (
// DefaultFinalCLTVDelta is the default value to be used as the final
// CLTV delta for a route if one is unspecified.
DefaultFinalCLTVDelta = 9
)
// HopHint is a routing hint that contains the minimum information of a channel
// required for an intermediate hop in a route to forward the payment to the
// next. This should be ideally used for private channels, since they are not
// publicly advertised to the network for routing.
type HopHint struct {
// NodeID is the public key of the node at the start of the channel.
NodeID *btcec.PublicKey
// ChannelID is the unique identifier of the channel.
ChannelID uint64
// FeeBaseMSat is the base fee of the channel in millisatoshis.
FeeBaseMSat uint32
// FeeProportionalMillionths is the fee rate, in millionths of a
// satoshi, for every satoshi sent through the channel.
FeeProportionalMillionths uint32
// CLTVExpiryDelta is the time-lock delta of the channel.
CLTVExpiryDelta uint16
}
// Copy returns a deep copy of the hop hint.
func (h HopHint) Copy() HopHint {
nodeID := *h.NodeID
return HopHint{
NodeID: &nodeID,
ChannelID: h.ChannelID,
FeeBaseMSat: h.FeeBaseMSat,
FeeProportionalMillionths: h.FeeProportionalMillionths,
CLTVExpiryDelta: h.CLTVExpiryDelta,
}
}
package zpay32
import (
"errors"
"fmt"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21"
)
const (
// mSatPerBtc is the number of millisatoshis in 1 BTC.
mSatPerBtc = 100000000000
// signatureBase32Len is the number of 5-bit groups needed to encode
// the 512 bit signature + 8 bit recovery ID.
signatureBase32Len = 104
// timestampBase32Len is the number of 5-bit groups needed to encode
// the 35-bit timestamp.
timestampBase32Len = 7
// hashBase32Len is the number of 5-bit groups needed to encode a
// 256-bit hash. Note that the last group will be padded with zeroes.
hashBase32Len = 52
// pubKeyBase32Len is the number of 5-bit groups needed to encode a
// 33-byte compressed pubkey. Note that the last group will be padded
// with zeroes.
pubKeyBase32Len = 53
// hopHintLen is the number of bytes needed to encode the hop hint of a
// single private route.
hopHintLen = 51
// The following byte values correspond to the supported field types.
// The field name is the character representing that 5-bit value in the
// bech32 string; for example, 'p' sits at index 1 of the bech32
// charset, hence fieldTypeP = 1.
// fieldTypeP is the field containing the payment hash.
fieldTypeP = 1
// fieldTypeD contains a short description of the payment.
fieldTypeD = 13
// fieldTypeN contains the pubkey of the target node.
fieldTypeN = 19
// fieldTypeH contains the hash of a description of the payment.
fieldTypeH = 23
// fieldTypeX contains the expiry in seconds of the invoice.
fieldTypeX = 6
// fieldTypeF contains a fallback on-chain address.
fieldTypeF = 9
// fieldTypeR contains extra routing information.
fieldTypeR = 3
// fieldTypeC contains an optional requested final CLTV delta.
fieldTypeC = 24
// fieldType9 contains one or more bytes for signaling features
// supported or required by the receiver.
fieldType9 = 5
// fieldTypeS contains a 32-byte payment address, which is a nonce
// included in the final hop's payload to prevent intermediaries from
// probing the recipient.
fieldTypeS = 16
// maxInvoiceLength is the maximum total length an invoice can have.
// This is chosen to be the maximum number of bytes that can fit into a
// single QR code: https://en.wikipedia.org/wiki/QR_code#Storage
maxInvoiceLength = 7089
// DefaultInvoiceExpiry is the default expiry duration from the creation
// timestamp if expiry is set to zero.
DefaultInvoiceExpiry = time.Hour
)
var (
// ErrInvoiceTooLarge is returned when an invoice exceeds
// maxInvoiceLength.
ErrInvoiceTooLarge = errors.New("invoice is too large")
// ErrInvalidFieldLength is returned when a tagged field was specified
// with a length larger than the left over bytes of the data field.
ErrInvalidFieldLength = errors.New("invalid field length")
// ErrBrokenTaggedField is returned when the last tagged field is
// incorrectly formatted and doesn't have enough bytes to be read.
ErrBrokenTaggedField = errors.New("last tagged field is broken")
)
// MessageSigner is passed to the Encode method to provide a signature
// corresponding to the node's pubkey.
type MessageSigner struct {
// SignCompact signs the passed hash with the node's privkey. The
// returned signature should be 65 bytes, where the last 64 are the
// compact signature, and the first one is a header byte. This is the
// format returned by ecdsa.SignCompact.
SignCompact func(hash []byte) ([]byte, error)
}
// Invoice represents a decoded invoice, or to-be-encoded invoice. Some of the
// fields are optional, and will only be non-nil if the invoice this was parsed
// from contains that field. When encoding, only the non-nil fields will be
// added to the encoded invoice.
type Invoice struct {
// Net specifies what network this Lightning invoice is meant for.
Net *chaincfg.Params
// MilliSat specifies the amount of this invoice in millisatoshi.
// Optional.
MilliSat *lnwire.MilliSatoshi
// Timestamp specifies the time this invoice was created.
// Mandatory
Timestamp time.Time
// PaymentHash is the payment hash to be used for a payment to this
// invoice.
PaymentHash *[32]byte
// PaymentAddr is the payment address to be used by payments to prevent
// probing of the destination.
PaymentAddr *[32]byte
// Destination is the public key of the target node. This will always
// be set after decoding, and can optionally be set before encoding to
// include the pubkey as an 'n' field. If this is not set before
// encoding then the destination pubkey won't be added as an 'n' field,
// and the pubkey will be extracted from the signature during decoding.
Destination *btcec.PublicKey
// minFinalCLTVExpiry is the value that the creator of the invoice
// expects to be used for the CLTV expiry of the HTLC extended to it in
// the last hop.
//
// NOTE: This value is optional, and should be set to nil if the
// invoice creator doesn't have a strong requirement on the CLTV expiry
// of the final HTLC extended to it.
//
// This field is un-exported and can only be read by the
// MinFinalCLTVExpiry() method. By forcing callers to read via this
// method, we can easily enforce the default if not specified.
minFinalCLTVExpiry *uint64
// Description is a short description of the purpose of this invoice.
// Optional. Non-nil iff DescriptionHash is nil.
Description *string
// DescriptionHash is the SHA256 hash of a description of the purpose of
// this invoice.
// Optional. Non-nil iff Description is nil.
DescriptionHash *[32]byte
// expiry specifies the timespan this invoice will be valid.
// Optional. If not set, a default expiry of 60 min will be implied.
//
// This field is unexported and can be read by the Expiry() method. This
// method makes sure the default expiry time is returned in case the
// field is not set.
expiry *time.Duration
// FallbackAddr is an on-chain address that can be used for payment in
// case the Lightning payment fails.
// Optional.
FallbackAddr btcutil.Address
// RouteHints represents one or more different route hints. Each route
// hint can be individually used to reach the destination. These usually
// represent private routes.
//
// NOTE: This is optional.
RouteHints [][]HopHint
// Features represents an optional field used to signal optional or
// required support for features by the receiver.
Features *lnwire.FeatureVector
}
// Amount is a functional option that allows callers of NewInvoice to set the
// amount in millisatoshis that the Invoice should encode.
func Amount(milliSat lnwire.MilliSatoshi) func(*Invoice) {
return func(i *Invoice) {
i.MilliSat = &milliSat
}
}
// Destination is a functional option that allows callers of NewInvoice to
// explicitly set the pubkey of the Invoice's destination node.
func Destination(destination *btcec.PublicKey) func(*Invoice) {
return func(i *Invoice) {
i.Destination = destination
}
}
// Description is a functional option that allows callers of NewInvoice to set
// the payment description of the created Invoice.
//
// NOTE: Must be used if and only if DescriptionHash is not used.
func Description(description string) func(*Invoice) {
return func(i *Invoice) {
i.Description = &description
}
}
// CLTVExpiry is an optional value which allows the receiver of the payment to
// specify the delta between the current height and the HTLC extended to the
// receiver.
func CLTVExpiry(delta uint64) func(*Invoice) {
return func(i *Invoice) {
i.minFinalCLTVExpiry = &delta
}
}
// DescriptionHash is a functional option that allows callers of NewInvoice to
// set the payment description hash of the created Invoice.
//
// NOTE: Must be used if and only if Description is not used.
func DescriptionHash(descriptionHash [32]byte) func(*Invoice) {
return func(i *Invoice) {
i.DescriptionHash = &descriptionHash
}
}
// Expiry is a functional option that allows callers of NewInvoice to set the
// expiry of the created Invoice. If not set, a default expiry of 60 min will
// be implied.
func Expiry(expiry time.Duration) func(*Invoice) {
return func(i *Invoice) {
i.expiry = &expiry
}
}
// FallbackAddr is a functional option that allows callers of NewInvoice to set
// the Invoice's fallback on-chain address that can be used for payment in case
// the Lightning payment fails.
func FallbackAddr(fallbackAddr btcutil.Address) func(*Invoice) {
return func(i *Invoice) {
i.FallbackAddr = fallbackAddr
}
}
// RouteHint is a functional option that allows callers of NewInvoice to add
// one or more hop hints that represent a private route to the destination.
func RouteHint(routeHint []HopHint) func(*Invoice) {
return func(i *Invoice) {
i.RouteHints = append(i.RouteHints, routeHint)
}
}
// Features is a functional option that allows callers of NewInvoice to set the
// desired feature bits that are advertised on the invoice. If this option is
// not used, an empty feature vector will automatically be populated.
func Features(features *lnwire.FeatureVector) func(*Invoice) {
return func(i *Invoice) {
i.Features = features
}
}
// PaymentAddr is a functional option that allows callers of NewInvoice to set
// the desired payment address that is advertised on the invoice.
func PaymentAddr(addr [32]byte) func(*Invoice) {
return func(i *Invoice) {
i.PaymentAddr = &addr
}
}
// NewInvoice creates a new Invoice object. The last parameter is a set of
// variadic arguments for setting optional fields of the invoice.
//
// NOTE: Either Description or DescriptionHash must be provided for the Invoice
// to be considered valid.
func NewInvoice(net *chaincfg.Params, paymentHash [32]byte,
timestamp time.Time, options ...func(*Invoice)) (*Invoice, error) {
invoice := &Invoice{
Net: net,
PaymentHash: &paymentHash,
Timestamp: timestamp,
}
for _, option := range options {
option(invoice)
}
// If no features were set, we'll populate an empty feature vector.
if invoice.Features == nil {
invoice.Features = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
}
if err := validateInvoice(invoice); err != nil {
return nil, err
}
return invoice, nil
}
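// A hedged usage sketch of the functional options above; the payment
// hash is a zero-value placeholder here, whereas a real invoice commits
// to the SHA-256 hash of the payment preimage.
func exampleNewInvoice() (*Invoice, error) {
var hash [32]byte // placeholder; normally sha256(preimage)
return NewInvoice(
&chaincfg.MainNetParams, hash, time.Now(),
Amount(lnwire.MilliSatoshi(250000)),
Description("coffee"),
Expiry(30*time.Minute),
)
}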
// Expiry returns the expiry time for this invoice. If expiry time is not set
// explicitly, the default 3600 second expiry will be returned.
func (invoice *Invoice) Expiry() time.Duration {
if invoice.expiry != nil {
return *invoice.expiry
}
// If no expiry is set for this invoice, default is 3600 seconds.
return DefaultInvoiceExpiry
}
// MinFinalCLTVExpiry returns the minimum final CLTV expiry delta as specified
// by the creator of the invoice. This value specifies the delta between the
// current height and the expiry height of the HTLC extended in the last hop.
func (invoice *Invoice) MinFinalCLTVExpiry() uint64 {
if invoice.minFinalCLTVExpiry != nil {
return *invoice.minFinalCLTVExpiry
}
return DefaultFinalCLTVDelta
}
// validateInvoice does a sanity check of the provided Invoice, making sure it
// has all the necessary fields set for it to be considered valid by BOLT-0011.
func validateInvoice(invoice *Invoice) error {
// The net must be set.
if invoice.Net == nil {
return fmt.Errorf("net params not set")
}
// The invoice must contain a payment hash.
if invoice.PaymentHash == nil {
return fmt.Errorf("no payment hash found")
}
// Either Description or DescriptionHash must be set, not both.
if invoice.Description != nil && invoice.DescriptionHash != nil {
return fmt.Errorf("both description and description hash set")
}
if invoice.Description == nil && invoice.DescriptionHash == nil {
return fmt.Errorf("neither description nor description hash set")
}
// Check that we support the field lengths.
if len(invoice.PaymentHash) != 32 {
return fmt.Errorf("unsupported payment hash length: %d",
len(invoice.PaymentHash))
}
if invoice.DescriptionHash != nil && len(invoice.DescriptionHash) != 32 {
return fmt.Errorf("unsupported description hash length: %d",
len(invoice.DescriptionHash))
}
if invoice.Destination != nil &&
len(invoice.Destination.SerializeCompressed()) != 33 {
return fmt.Errorf("unsupported pubkey length: %d",
len(invoice.Destination.SerializeCompressed()))
}
// Ensure that all invoices have feature vectors.
if invoice.Features == nil {
return fmt.Errorf("missing feature vector")
}
return nil
}
package channeldb
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// HTLCAttemptInfo contains static information about a specific HTLC attempt
// for a payment. This information is used by the router to handle any errors
// coming back after an attempt is made, and to query the switch about the
// status of the attempt.
type HTLCAttemptInfo struct {
// AttemptID is the unique ID used for this attempt.
AttemptID uint64
// sessionKey is the raw bytes ephemeral key used for this attempt.
// These bytes are lazily read off disk to save ourselves the expensive
// EC operations used by btcec.PrivKeyFromBytes.
sessionKey [btcec.PrivKeyBytesLen]byte
// cachedSessionKey is our fully deserialized sessionKey. This value
// may be nil if the attempt has just been read from disk and its
// session key has not been used yet.
cachedSessionKey *btcec.PrivateKey
// Route is the route attempted to send the HTLC.
Route route.Route
// AttemptTime is the time at which this HTLC was attempted.
AttemptTime time.Time
// Hash is the hash used for this single HTLC attempt. For AMP payments
// this will differ across attempts, for non-AMP payments each attempt
// will use the same hash. This can be nil for older payment attempts,
// in which case the payment's PaymentHash in the PaymentCreationInfo
// should be used.
Hash *lntypes.Hash
// onionBlob is the cached value for onion blob created from the sphinx
// construction.
onionBlob [lnwire.OnionPacketSize]byte
// circuit is the cached value for sphinx circuit.
circuit *sphinx.Circuit
}
// NewHtlcAttempt creates a htlc attempt.
func NewHtlcAttempt(attemptID uint64, sessionKey *btcec.PrivateKey,
route route.Route, attemptTime time.Time,
hash *lntypes.Hash) (*HTLCAttempt, error) {
var scratch [btcec.PrivKeyBytesLen]byte
copy(scratch[:], sessionKey.Serialize())
info := HTLCAttemptInfo{
AttemptID: attemptID,
sessionKey: scratch,
cachedSessionKey: sessionKey,
Route: route,
AttemptTime: attemptTime,
Hash: hash,
}
if err := info.attachOnionBlobAndCircuit(); err != nil {
return nil, err
}
return &HTLCAttempt{HTLCAttemptInfo: info}, nil
}
// SessionKey returns the ephemeral key used for a htlc attempt. This function
// performs expensive ec-ops to obtain the session key if it is not cached.
func (h *HTLCAttemptInfo) SessionKey() *btcec.PrivateKey {
if h.cachedSessionKey == nil {
h.cachedSessionKey, _ = btcec.PrivKeyFromBytes(
h.sessionKey[:],
)
}
return h.cachedSessionKey
}
// OnionBlob returns the onion blob created from the sphinx construction.
func (h *HTLCAttemptInfo) OnionBlob() ([lnwire.OnionPacketSize]byte, error) {
var zeroBytes [lnwire.OnionPacketSize]byte
if h.onionBlob == zeroBytes {
if err := h.attachOnionBlobAndCircuit(); err != nil {
return zeroBytes, err
}
}
return h.onionBlob, nil
}
// Circuit returns the sphinx circuit for this attempt.
func (h *HTLCAttemptInfo) Circuit() (*sphinx.Circuit, error) {
if h.circuit == nil {
if err := h.attachOnionBlobAndCircuit(); err != nil {
return nil, err
}
}
return h.circuit, nil
}
// attachOnionBlobAndCircuit creates a sphinx packet and caches the onion blob
// and circuit for this attempt.
func (h *HTLCAttemptInfo) attachOnionBlobAndCircuit() error {
onionBlob, circuit, err := generateSphinxPacket(
&h.Route, h.Hash[:], h.SessionKey(),
)
if err != nil {
return err
}
copy(h.onionBlob[:], onionBlob)
h.circuit = circuit
return nil
}
// HTLCAttempt contains information about a specific HTLC attempt for a given
// payment. It contains the HTLCAttemptInfo used to send the HTLC, as well
// as a timestamp and any known outcome of the attempt.
type HTLCAttempt struct {
HTLCAttemptInfo
// Settle is the preimage of a successful payment. This serves as a
// proof of payment. It will only be non-nil for settled payments.
//
// NOTE: Can be nil if payment is not settled.
Settle *HTLCSettleInfo
// Failure is a failure reason code indicating the reason the payment
// failed. It is only non-nil for failed payments.
//
// NOTE: Can be nil if payment is not failed.
Failure *HTLCFailInfo
}
// HTLCSettleInfo encapsulates the information that augments an HTLCAttempt in
// the event that the HTLC is successful.
type HTLCSettleInfo struct {
// Preimage is the preimage of a successful HTLC. This serves as a proof
// of payment.
Preimage lntypes.Preimage
// SettleTime is the time at which this HTLC was settled.
SettleTime time.Time
}
// HTLCFailReason is the reason an htlc failed.
type HTLCFailReason byte
const (
// HTLCFailUnknown is recorded for htlcs that failed with an unknown
// reason.
HTLCFailUnknown HTLCFailReason = 0
// HTLCFailUnreadable is recorded for htlcs that had a failure message that
// couldn't be decrypted.
HTLCFailUnreadable HTLCFailReason = 1
// HTLCFailInternal is recorded for htlcs that failed because of an
// internal error.
HTLCFailInternal HTLCFailReason = 2
// HTLCFailMessage is recorded for htlcs that failed with a network
// failure message.
HTLCFailMessage HTLCFailReason = 3
)
// HTLCFailInfo encapsulates the information that augments an HTLCAttempt in the
// event that the HTLC fails.
type HTLCFailInfo struct {
// FailTime is the time at which this HTLC was failed.
FailTime time.Time
// Message is the wire message that failed this HTLC. This field will be
// populated when the failure reason is HTLCFailMessage.
Message lnwire.FailureMessage
// Reason is the failure reason for this HTLC.
Reason HTLCFailReason
// FailureSourceIndex is the position in the path of the intermediate
// or final node that generated the failure message. Position zero is
// the sender node. This field will be populated when the failure
// reason is either HTLCFailMessage or HTLCFailUnknown.
FailureSourceIndex uint32
}
// MPPaymentState wraps the state information needed for a given payment, which is
// used by both MPP and AMP. This is a memory representation of the payment's
// current state and is updated whenever the payment is read from disk.
type MPPaymentState struct {
// NumAttemptsInFlight specifies the number of HTLCs the payment is
// waiting results for.
NumAttemptsInFlight int
// RemainingAmt specifies how much more money to be sent.
RemainingAmt lnwire.MilliSatoshi
// FeesPaid specifies the total fees paid so far that can be used to
// calculate remaining fee budget.
FeesPaid lnwire.MilliSatoshi
// HasSettledHTLC is true if at least one of the payment's HTLCs is
// settled.
HasSettledHTLC bool
// PaymentFailed is true if the payment has been marked as failed with
// a reason.
PaymentFailed bool
}
// MPPayment is a wrapper around a payment's PaymentCreationInfo and
// HTLCAttempts. All payments will have the PaymentCreationInfo set; any
// HTLCs made in attempts to complete the payment will be populated in the
// HTLCs slice.
// Each populated HTLCAttempt represents an attempted HTLC, each of which may
// have the associated Settle or Fail struct populated if the HTLC is no longer
// in-flight.
type MPPayment struct {
// SequenceNum is a unique identifier used to sort the payments in
// order of creation.
SequenceNum uint64
// Info holds all static information about this payment, and is
// populated when the payment is initiated.
Info *PaymentCreationInfo
// HTLCs holds the information about individual HTLCs that we send in
// order to make the payment.
HTLCs []HTLCAttempt
// FailureReason is the failure reason code indicating the reason the
// payment failed.
//
// NOTE: Will only be set once the daemon has given up on the payment
// altogether.
FailureReason *FailureReason
// Status is the current PaymentStatus of this payment.
Status PaymentStatus
// State is the current state of the payment that holds a number of key
// insights and is used to determine what to do on each payment loop
// iteration.
State *MPPaymentState
}
// Terminated returns a bool to specify whether the payment is in a terminal
// state.
func (m *MPPayment) Terminated() bool {
// If the payment is in terminal state, it cannot be updated.
return m.Status.updatable() != nil
}
// TerminalInfo returns any HTLC settle info recorded. If no settle info is
// recorded, any payment level failure will be returned. If neither a settle
// nor a failure is recorded, both return values will be nil.
func (m *MPPayment) TerminalInfo() (*HTLCAttempt, *FailureReason) {
for _, h := range m.HTLCs {
if h.Settle != nil {
return &h, nil
}
}
return nil, m.FailureReason
}
// SentAmt returns the sum of sent amount and fees for HTLCs that are either
// settled or still in flight.
func (m *MPPayment) SentAmt() (lnwire.MilliSatoshi, lnwire.MilliSatoshi) {
var sent, fees lnwire.MilliSatoshi
for _, h := range m.HTLCs {
if h.Failure != nil {
continue
}
// The attempt was not failed, meaning the amount was
// potentially sent to the receiver.
sent += h.Route.ReceiverAmt()
fees += h.Route.TotalFees()
}
return sent, fees
}
// InFlightHTLCs returns the HTLCs that are still in-flight, meaning they have
// not been settled or failed.
func (m *MPPayment) InFlightHTLCs() []HTLCAttempt {
var inflights []HTLCAttempt
for _, h := range m.HTLCs {
if h.Settle != nil || h.Failure != nil {
continue
}
inflights = append(inflights, h)
}
return inflights
}
// GetAttempt returns the specified htlc attempt on the payment.
func (m *MPPayment) GetAttempt(id uint64) (*HTLCAttempt, error) {
// TODO(yy): iteration can be slow, make it into a tree or use BS.
for _, htlc := range m.HTLCs {
htlc := htlc
if htlc.AttemptID == id {
return &htlc, nil
}
}
return nil, errors.New("htlc attempt not found on payment")
}
// Registrable returns nil if adding more HTLCs to the payment is allowed
// given its current status, and an error otherwise. A payment can accept
// new HTLC registrations when it's newly created, or when none of its
// HTLCs is in a terminal state.
func (m *MPPayment) Registrable() error {
// If updating the payment is not allowed, we can't register new HTLCs.
// Otherwise, the status must be either `StatusInitiated` or
// `StatusInFlight`.
if err := m.Status.updatable(); err != nil {
return err
}
// Exit early if this is not inflight.
if m.Status != StatusInFlight {
return nil
}
// There are still inflight HTLCs and we need to check whether there
// are settled HTLCs or the payment is failed. If we already have
// settled HTLCs, we won't allow adding more HTLCs.
if m.State.HasSettledHTLC {
return ErrPaymentPendingSettled
}
// If the payment is already failed, we won't allow adding more HTLCs.
if m.State.PaymentFailed {
return ErrPaymentPendingFailed
}
// Otherwise we can add more HTLCs.
return nil
}
// setState creates and attaches a new MPPaymentState to the payment. It also
// updates the payment's status based on its current state.
func (m *MPPayment) setState() error {
// Fetch the total amount and fees that have already been sent in
// settled and still in-flight shards.
sentAmt, fees := m.SentAmt()
// Sanity check we haven't sent a value larger than the payment amount.
totalAmt := m.Info.Value
if sentAmt > totalAmt {
return fmt.Errorf("%w: sent=%v, total=%v", ErrSentExceedsTotal,
sentAmt, totalAmt)
}
// Get any terminal info for this payment.
settle, failure := m.TerminalInfo()
// Now determine the payment's status.
status, err := decidePaymentStatus(m.HTLCs, m.FailureReason)
if err != nil {
return err
}
// Update the payment state and status.
m.State = &MPPaymentState{
NumAttemptsInFlight: len(m.InFlightHTLCs()),
RemainingAmt: totalAmt - sentAmt,
FeesPaid: fees,
HasSettledHTLC: settle != nil,
PaymentFailed: failure != nil,
}
m.Status = status
return nil
}
// SetState calls the internal method setState. This is a temporary method
// to be used by the tests in routing. Once the tests are updated to use mocks,
// this method can be removed.
//
// TODO(yy): delete.
func (m *MPPayment) SetState() error {
return m.setState()
}
// NeedWaitAttempts decides whether we need to hold off creating more HTLC
// attempts and wait for the results of the payment's inflight HTLCs. It
// returns an error if the payment is in an unexpected state.
func (m *MPPayment) NeedWaitAttempts() (bool, error) {
// Check when the remainingAmt is not zero, which means we have more
// money to be sent.
if m.State.RemainingAmt != 0 {
switch m.Status {
// If the payment is newly created, no need to wait for HTLC
// results.
case StatusInitiated:
return false, nil
// If we have inflight HTLCs, we'll check if we have terminal
// states to decide if we need to wait.
case StatusInFlight:
// We still have money to send, and one of the HTLCs is
// settled. We'd stop sending money and wait for all
// inflight HTLC attempts to finish.
if m.State.HasSettledHTLC {
log.Warnf("payment=%v has remaining amount "+
"%v, yet at least one of its HTLCs is "+
"settled", m.Info.PaymentIdentifier,
m.State.RemainingAmt)
return true, nil
}
// The payment has a failure reason though we still
// have money to send, we'd stop sending money and wait
// for all inflight HTLC attempts to finish.
if m.State.PaymentFailed {
return true, nil
}
// Otherwise we don't need to wait for inflight HTLCs
// since we still have money to be sent.
return false, nil
// We need to send more money, yet the payment is already
// succeeded. Return an error in this case as the receiver is
// violating the protocol.
case StatusSucceeded:
return false, fmt.Errorf("%w: parts of the payment "+
"already succeeded but still have remaining "+
"amount %v", ErrPaymentInternal,
m.State.RemainingAmt)
// The payment is failed and we have no inflight HTLCs, no need
// to wait.
case StatusFailed:
return false, nil
// Unknown payment status.
default:
return false, fmt.Errorf("%w: %s",
ErrUnknownPaymentStatus, m.Status)
}
}
// Now we determine whether we need to wait when the remainingAmt is
// already zero.
switch m.Status {
// When the payment is newly created, yet the payment has no remaining
// amount, return an error.
case StatusInitiated:
return false, fmt.Errorf("%w: %v", ErrPaymentInternal, m.Status)
// If the payment is inflight, we must wait.
//
// NOTE: an edge case is when all HTLCs are failed while the payment is
// not failed, in which case we'd still be in this inflight state.
// However, since the
// remainingAmt is zero here, it means we cannot be in that state as
// otherwise the remainingAmt would not be zero.
case StatusInFlight:
return true, nil
// If the payment is already succeeded, no need to wait.
case StatusSucceeded:
return false, nil
// If the payment is already failed, yet the remaining amount is zero,
// return an error as this indicates an error state. We will only reach
// this status when there are no inflight HTLCs and the payment is
// marked as failed with a reason, which means the remainingAmt must
// not be zero because our sentAmt is zero.
case StatusFailed:
return false, fmt.Errorf("%w: %v", ErrPaymentInternal, m.Status)
// Unknown payment status.
default:
return false, fmt.Errorf("%w: %s", ErrUnknownPaymentStatus,
m.Status)
}
}
// GetState returns the internal state of the payment.
func (m *MPPayment) GetState() *MPPaymentState {
return m.State
}
// GetStatus returns the current status of the payment.
func (m *MPPayment) GetStatus() PaymentStatus {
return m.Status
}
// GetHTLCs returns all the HTLCs for this payment.
func (m *MPPayment) GetHTLCs() []HTLCAttempt {
return m.HTLCs
}
// AllowMoreAttempts is used to decide whether we can safely attempt more HTLCs
// for a given payment state. It returns an error if the payment is in an
// unexpected state.
func (m *MPPayment) AllowMoreAttempts() (bool, error) {
// Now check whether the remainingAmt is zero or not. If we don't have
// any remainingAmt, no more HTLCs should be made.
if m.State.RemainingAmt == 0 {
// If the payment is newly created, yet we don't have any
// remainingAmt, return an error.
if m.Status == StatusInitiated {
return false, fmt.Errorf("%w: initiated payment has "+
"zero remainingAmt", ErrPaymentInternal)
}
// Otherwise, exit early since all other statuses with zero
// remainingAmt indicate no more HTLCs can be made.
return false, nil
}
// Otherwise, the remaining amount is not zero, we now decide whether
// to make more attempts based on the payment's current status.
//
// If at least one of the payment's attempts is settled, yet we haven't
// sent all the amount, it indicates something is wrong with the peer
// as the preimage is received. In this case, return an error state.
if m.Status == StatusSucceeded {
return false, fmt.Errorf("%w: payment already succeeded but "+
"still have remaining amount %v", ErrPaymentInternal,
m.State.RemainingAmt)
}
// Now check if we can register a new HTLC.
err := m.Registrable()
if err != nil {
log.Warnf("Payment(%v): cannot register HTLC attempt: %v, "+
"current status: %s", m.Info.PaymentIdentifier,
err, m.Status)
return false, nil
}
// Now we know we can register new HTLCs.
return true, nil
}
// serializeHTLCSettleInfo serializes the details of a settled htlc.
func serializeHTLCSettleInfo(w io.Writer, s *HTLCSettleInfo) error {
if _, err := w.Write(s.Preimage[:]); err != nil {
return err
}
if err := serializeTime(w, s.SettleTime); err != nil {
return err
}
return nil
}
// deserializeHTLCSettleInfo deserializes the details of a settled htlc.
func deserializeHTLCSettleInfo(r io.Reader) (*HTLCSettleInfo, error) {
s := &HTLCSettleInfo{}
if _, err := io.ReadFull(r, s.Preimage[:]); err != nil {
return nil, err
}
var err error
s.SettleTime, err = deserializeTime(r)
if err != nil {
return nil, err
}
return s, nil
}
// serializeHTLCFailInfo serializes the details of a failed htlc including the
// wire failure.
func serializeHTLCFailInfo(w io.Writer, f *HTLCFailInfo) error {
if err := serializeTime(w, f.FailTime); err != nil {
return err
}
// Write failure. If there is no failure message, write an empty
// byte slice.
var messageBytes bytes.Buffer
if f.Message != nil {
err := lnwire.EncodeFailureMessage(&messageBytes, f.Message, 0)
if err != nil {
return err
}
}
if err := wire.WriteVarBytes(w, 0, messageBytes.Bytes()); err != nil {
return err
}
return WriteElements(w, byte(f.Reason), f.FailureSourceIndex)
}
// deserializeHTLCFailInfo deserializes the details of a failed htlc including
// the wire failure.
func deserializeHTLCFailInfo(r io.Reader) (*HTLCFailInfo, error) {
f := &HTLCFailInfo{}
var err error
f.FailTime, err = deserializeTime(r)
if err != nil {
return nil, err
}
// Read failure.
failureBytes, err := wire.ReadVarBytes(
r, 0, math.MaxUint16, "failure",
)
if err != nil {
return nil, err
}
if len(failureBytes) > 0 {
f.Message, err = lnwire.DecodeFailureMessage(
bytes.NewReader(failureBytes), 0,
)
if err != nil {
return nil, err
}
}
var reason byte
err = ReadElements(r, &reason, &f.FailureSourceIndex)
if err != nil {
return nil, err
}
f.Reason = HTLCFailReason(reason)
return f, nil
}
// deserializeTime deserializes time as unix nanoseconds.
func deserializeTime(r io.Reader) (time.Time, error) {
var scratch [8]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return time.Time{}, err
}
// Convert to time.Time. Interpret unix nano time zero as a zero
// time.Time value.
unixNano := byteOrder.Uint64(scratch[:])
if unixNano == 0 {
return time.Time{}, nil
}
return time.Unix(0, int64(unixNano)), nil
}
// serializeTime serializes time as unix nanoseconds.
func serializeTime(w io.Writer, t time.Time) error {
var scratch [8]byte
// Convert to unix nano seconds, but only if time is non-zero. Calling
// UnixNano() on a zero time yields an undefined result.
var unixNano int64
if !t.IsZero() {
unixNano = t.UnixNano()
}
byteOrder.PutUint64(scratch[:], uint64(unixNano))
_, err := w.Write(scratch[:])
return err
}
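// A round-trip sketch (illustrative): the zero time serializes as eight
// zero bytes and deserializes back to a zero time.Time, so "never
// settled/failed" survives the disk encoding unambiguously.
func exampleTimeRoundTrip() (time.Time, error) {
var buf bytes.Buffer
if err := serializeTime(&buf, time.Time{}); err != nil {
return time.Time{}, err
}
// Returns the zero time; IsZero() reports true on the result.
return deserializeTime(&buf)
}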
// generateSphinxPacket generates then encodes a sphinx packet which encodes
// the onion route specified by the passed layer 3 route. The blob returned
// from this function can immediately be included within an HTLC add packet to
// be sent to the first hop within the route.
func generateSphinxPacket(rt *route.Route, paymentHash []byte,
sessionKey *btcec.PrivateKey) ([]byte, *sphinx.Circuit, error) {
// Now that we know we have an actual route, we'll map the route into a
// sphinx payment path which includes per-hop payloads for each hop
// that give each node within the route the necessary information
// (fees, CLTV value, etc.) to properly forward the payment.
sphinxPath, err := rt.ToSphinxPath()
if err != nil {
return nil, nil, err
}
log.Tracef("Constructed per-hop payloads for payment_hash=%x: %v",
paymentHash, lnutils.NewLogClosure(func() string {
path := make(
[]sphinx.OnionHop, sphinxPath.TrueRouteLength(),
)
for i := range path {
hopCopy := sphinxPath[i]
path[i] = hopCopy
}
return spew.Sdump(path)
}),
)
// Next generate the onion routing packet which allows us to perform
// privacy preserving source routing across the network.
sphinxPacket, err := sphinx.NewOnionPacket(
sphinxPath, sessionKey, paymentHash,
sphinx.DeterministicPacketFiller,
)
if err != nil {
return nil, nil, err
}
// Finally, encode Sphinx packet using its wire representation to be
// included within the HTLC add packet.
var onionBlob bytes.Buffer
if err := sphinxPacket.Encode(&onionBlob); err != nil {
return nil, nil, err
}
log.Tracef("Generated sphinx packet: %v",
lnutils.NewLogClosure(func() string {
// We make a copy of the ephemeral key and unset the
// internal curve here in order to keep the logs from
// getting noisy.
key := *sphinxPacket.EphemeralKey
packetCopy := *sphinxPacket
packetCopy.EphemeralKey = &key
return spew.Sdump(packetCopy)
}),
)
return onionBlob.Bytes(), &sphinx.Circuit{
SessionKey: sessionKey,
PaymentPath: sphinxPath.NodeKeys(),
}, nil
}
package channeldb
import (
"bytes"
"io"
"net"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/kvdb"
)
var (
// nodeInfoBucket stores metadata pertaining to nodes that we've had
// direct channel-based correspondence with. This bucket allows one to
// query for all open channels pertaining to the node by exploring each
// node's sub-bucket within the openChanBucket.
nodeInfoBucket = []byte("nib")
)
// LinkNode stores metadata related to nodes that we have/had a direct
// channel open with. Information such as the Bitcoin network the node
// advertised, and its identity public key are also stored. Additionally,
// this struct and the bucket it's stored within store data similar to
// that of Bitcoin's addrmanager. The TCP address information stored
// within the struct can be used to establish persistent connections with
// all channel counterparties on daemon startup.
//
// TODO(roasbeef): also add current OnionKey plus rotation schedule?
// TODO(roasbeef): add bitfield for supported services
// - possibly add a wire.NetAddress type, type
type LinkNode struct {
// Network indicates the Bitcoin network that the LinkNode advertises
// for incoming channel creation.
Network wire.BitcoinNet
// IdentityPub is the node's current identity public key. Any
// channel/topology related information received by this node MUST be
// signed by this public key.
IdentityPub *btcec.PublicKey
// LastSeen tracks the last time this node was seen within the network.
// A node should be marked as seen if the daemon either is able to
// establish an outgoing connection to the node or receives a new
// incoming connection from the node. This timestamp (stored in unix
// epoch) may be used within a heuristic which aims to determine when a
// channel should be unilaterally closed due to inactivity.
//
// TODO(roasbeef): replace with block hash/height?
// * possibly add a time-value metric into the heuristic?
LastSeen time.Time
// Addresses is a list of IP addresses through which we either were able
// to reach the node in the past, OR received an incoming authenticated
// connection for the stored identity public key.
Addresses []net.Addr
// db is the database instance this node was fetched from. This is used
// to sync back the node's state if it is updated.
db *LinkNodeDB
}
// NewLinkNode creates a new LinkNode from the provided parameters, which is
// backed by an instance of a link node DB.
func NewLinkNode(db *LinkNodeDB, bitNet wire.BitcoinNet, pub *btcec.PublicKey,
addrs ...net.Addr) *LinkNode {
return &LinkNode{
Network: bitNet,
IdentityPub: pub,
LastSeen: time.Now(),
Addresses: addrs,
db: db,
}
}
// UpdateLastSeen updates the last time this node was directly encountered on
// the Lightning Network.
func (l *LinkNode) UpdateLastSeen(lastSeen time.Time) error {
l.LastSeen = lastSeen
return l.Sync()
}
// AddAddress appends the specified TCP address to the list of known addresses
// this node is/was known to be reachable at.
func (l *LinkNode) AddAddress(addr net.Addr) error {
for _, a := range l.Addresses {
if a.String() == addr.String() {
return nil
}
}
l.Addresses = append(l.Addresses, addr)
return l.Sync()
}
// Sync performs a full database sync which writes the current up-to-date data
// within the struct to the database.
func (l *LinkNode) Sync() error {
// Finally update the database by storing the link node and updating
// any relevant indexes.
return kvdb.Update(l.db.backend, func(tx kvdb.RwTx) error {
nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return ErrLinkNodesNotFound
}
return putLinkNode(nodeMetaBucket, l)
}, func() {})
}
// putLinkNode serializes then writes the encoded version of the passed link
// node into the nodeMetaBucket. This function is provided to allow re-use
// of a database transaction across many operations.
func putLinkNode(nodeMetaBucket kvdb.RwBucket, l *LinkNode) error {
// First serialize the LinkNode into its raw-bytes encoding.
var b bytes.Buffer
if err := serializeLinkNode(&b, l); err != nil {
return err
}
// Finally insert the link-node into the node metadata bucket keyed
// according to its pubkey serialized in compressed form.
nodePub := l.IdentityPub.SerializeCompressed()
return nodeMetaBucket.Put(nodePub, b.Bytes())
}
// LinkNodeDB is a database that keeps track of all link nodes.
type LinkNodeDB struct {
backend kvdb.Backend
}
// DeleteLinkNode removes the link node with the given identity from the
// database.
func (l *LinkNodeDB) DeleteLinkNode(identity *btcec.PublicKey) error {
return kvdb.Update(l.backend, func(tx kvdb.RwTx) error {
return deleteLinkNode(tx, identity)
}, func() {})
}
func deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error {
nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return ErrLinkNodesNotFound
}
pubKey := identity.SerializeCompressed()
return nodeMetaBucket.Delete(pubKey)
}
// FetchLinkNode attempts to look up the data for a LinkNode based on a
// target identity public key. If a particular LinkNode for the passed
// identity public key cannot be found, then ErrNodeNotFound is returned.
func (l *LinkNodeDB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
var linkNode *LinkNode
err := kvdb.View(l.backend, func(tx kvdb.RTx) error {
node, err := fetchLinkNode(tx, identity)
if err != nil {
return err
}
linkNode = node
return nil
}, func() {
linkNode = nil
})
return linkNode, err
}
func fetchLinkNode(tx kvdb.RTx, targetPub *btcec.PublicKey) (*LinkNode, error) {
// First fetch the bucket for storing node metadata, bailing out early
// if it hasn't been created yet.
nodeMetaBucket := tx.ReadBucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return nil, ErrLinkNodesNotFound
}
// If a link node for that particular public key cannot be located,
// then exit early with an ErrNodeNotFound.
pubKey := targetPub.SerializeCompressed()
nodeBytes := nodeMetaBucket.Get(pubKey)
if nodeBytes == nil {
return nil, ErrNodeNotFound
}
// Finally, decode and allocate a fresh LinkNode object to be returned
// to the caller.
nodeReader := bytes.NewReader(nodeBytes)
return deserializeLinkNode(nodeReader)
}
// TODO(roasbeef): update link node addrs in server upon connection
// FetchAllLinkNodes starts a new database transaction to fetch all nodes
// with whom we have active channels.
func (l *LinkNodeDB) FetchAllLinkNodes() ([]*LinkNode, error) {
var linkNodes []*LinkNode
err := kvdb.View(l.backend, func(tx kvdb.RTx) error {
nodes, err := fetchAllLinkNodes(tx)
if err != nil {
return err
}
linkNodes = nodes
return nil
}, func() {
linkNodes = nil
})
if err != nil {
return nil, err
}
return linkNodes, nil
}
// fetchAllLinkNodes uses an existing database transaction to fetch all nodes
// with whom we have active channels.
func fetchAllLinkNodes(tx kvdb.RTx) ([]*LinkNode, error) {
nodeMetaBucket := tx.ReadBucket(nodeInfoBucket)
if nodeMetaBucket == nil {
return nil, ErrLinkNodesNotFound
}
var linkNodes []*LinkNode
err := nodeMetaBucket.ForEach(func(k, v []byte) error {
if v == nil {
return nil
}
nodeReader := bytes.NewReader(v)
linkNode, err := deserializeLinkNode(nodeReader)
if err != nil {
return err
}
linkNodes = append(linkNodes, linkNode)
return nil
})
if err != nil {
return nil, err
}
return linkNodes, nil
}
// serializeLinkNode writes the passed LinkNode to the given writer: the
// network (4 bytes), the compressed identity pubkey (33 bytes), the
// last-seen Unix timestamp (8 bytes), and the length-prefixed address list.
func serializeLinkNode(w io.Writer, l *LinkNode) error {
var buf [8]byte
byteOrder.PutUint32(buf[:4], uint32(l.Network))
if _, err := w.Write(buf[:4]); err != nil {
return err
}
serializedID := l.IdentityPub.SerializeCompressed()
if _, err := w.Write(serializedID); err != nil {
return err
}
seenUnix := uint64(l.LastSeen.Unix())
byteOrder.PutUint64(buf[:], seenUnix)
if _, err := w.Write(buf[:]); err != nil {
return err
}
numAddrs := uint32(len(l.Addresses))
byteOrder.PutUint32(buf[:4], numAddrs)
if _, err := w.Write(buf[:4]); err != nil {
return err
}
for _, addr := range l.Addresses {
if err := graphdb.SerializeAddr(w, addr); err != nil {
return err
}
}
return nil
}
// deserializeLinkNode reads a LinkNode from the given reader using the
// encoding produced by serializeLinkNode.
func deserializeLinkNode(r io.Reader) (*LinkNode, error) {
var (
err error
buf [8]byte
)
node := &LinkNode{}
if _, err := io.ReadFull(r, buf[:4]); err != nil {
return nil, err
}
node.Network = wire.BitcoinNet(byteOrder.Uint32(buf[:4]))
var pub [33]byte
if _, err := io.ReadFull(r, pub[:]); err != nil {
return nil, err
}
node.IdentityPub, err = btcec.ParsePubKey(pub[:])
if err != nil {
return nil, err
}
if _, err := io.ReadFull(r, buf[:]); err != nil {
return nil, err
}
node.LastSeen = time.Unix(int64(byteOrder.Uint64(buf[:])), 0)
if _, err := io.ReadFull(r, buf[:4]); err != nil {
return nil, err
}
numAddrs := byteOrder.Uint32(buf[:4])
node.Addresses = make([]net.Addr, numAddrs)
for i := uint32(0); i < numAddrs; i++ {
addr, err := graphdb.DeserializeAddr(r)
if err != nil {
return nil, err
}
node.Addresses[i] = addr
}
return node, nil
}
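// linkNodeRoundTrip is a minimal illustrative sketch (a hypothetical helper,
// not referenced elsewhere) showing that serializeLinkNode and
// deserializeLinkNode are inverses: encoding a LinkNode and decoding the
// resulting bytes yields an equivalent node. Note that LastSeen round-trips
// at whole-second precision only, since the encoding stores a Unix timestamp.
func linkNodeRoundTrip(l *LinkNode) (*LinkNode, error) {
	var b bytes.Buffer
	if err := serializeLinkNode(&b, l); err != nil {
		return nil, err
	}
	return deserializeLinkNode(bytes.NewReader(b.Bytes()))
}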
package channeldb
import (
"github.com/lightningnetwork/lnd/clock"
)
const (
// DefaultRejectCacheSize is the default number of rejectCacheEntries to
// cache for use in the rejection cache of incoming gossip traffic. This
// produces a cache size of around 1MB.
DefaultRejectCacheSize = 50000
// DefaultChannelCacheSize is the default number of ChannelEdges cached
// in order to reply to gossip queries. This produces a cache size of
// around 40MB.
DefaultChannelCacheSize = 20000
// DefaultPreAllocCacheNumNodes is the default number of nodes we
// assume for mainnet for pre-allocating the graph cache. As of
// September 2021, there currently are 14k nodes in a strictly pruned
// graph, so we choose a number that is slightly higher.
DefaultPreAllocCacheNumNodes = 15000
)
// OptionalMiragtionConfig defines the flags used to signal whether a
// particular migration needs to be applied.
type OptionalMiragtionConfig struct {
// PruneRevocationLog specifies that the revocation log migration needs
// to be applied.
PruneRevocationLog bool
}
// Options holds parameters for tuning and customizing a channeldb.DB.
type Options struct {
OptionalMiragtionConfig
// NoMigration specifies that underlying backend was opened in read-only
// mode and migrations shouldn't be performed. This can be useful for
// applications that use the channeldb package as a library.
NoMigration bool
// NoRevLogAmtData when set to true, indicates that amount data should
// not be stored in the revocation log.
NoRevLogAmtData bool
// clock is the time source used by the database.
clock clock.Clock
// dryRun will fail to commit a successful migration when opening the
// database if set to true.
dryRun bool
// keepFailedPaymentAttempts determines whether failed htlc attempts
// are kept on disk or removed to save space.
keepFailedPaymentAttempts bool
// storeFinalHtlcResolutions determines whether to persistently store
// the final resolution of incoming htlcs.
storeFinalHtlcResolutions bool
}
// DefaultOptions returns an Options populated with default values.
func DefaultOptions() Options {
return Options{
OptionalMiragtionConfig: OptionalMiragtionConfig{},
NoMigration: false,
clock: clock.NewDefaultClock(),
}
}
// OptionModifier is a function signature for modifying the default Options.
type OptionModifier func(*Options)
// OptionNoRevLogAmtData sets the NoRevLogAmtData option to the given value. If
// it is set to true then amount data will not be stored in the revocation log.
func OptionNoRevLogAmtData(noAmtData bool) OptionModifier {
return func(o *Options) {
o.NoRevLogAmtData = noAmtData
}
}
// OptionNoMigration allows the database to be opened in read only mode by
// disabling migrations.
func OptionNoMigration(b bool) OptionModifier {
return func(o *Options) {
o.NoMigration = b
}
}
// OptionClock sets a non-default clock dependency.
func OptionClock(clock clock.Clock) OptionModifier {
return func(o *Options) {
o.clock = clock
}
}
// OptionDryRunMigration controls whether or not to intentionally fail to commit a
// successful migration that occurs when opening the database.
func OptionDryRunMigration(dryRun bool) OptionModifier {
return func(o *Options) {
o.dryRun = dryRun
}
}
// OptionKeepFailedPaymentAttempts controls whether failed payment attempts are
// kept on disk after a payment settles.
func OptionKeepFailedPaymentAttempts(keepFailedPaymentAttempts bool) OptionModifier {
return func(o *Options) {
o.keepFailedPaymentAttempts = keepFailedPaymentAttempts
}
}
// OptionStoreFinalHtlcResolutions controls whether to persistently store the
// final resolution of incoming htlcs.
func OptionStoreFinalHtlcResolutions(
storeFinalHtlcResolutions bool) OptionModifier {
return func(o *Options) {
o.storeFinalHtlcResolutions = storeFinalHtlcResolutions
}
}
// OptionPruneRevocationLog specifies whether the migration for pruning
// revocation logs needs to be applied or not.
func OptionPruneRevocationLog(prune bool) OptionModifier {
return func(o *Options) {
o.OptionalMiragtionConfig.PruneRevocationLog = prune
}
}
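// applyOptionModifiers is a minimal sketch (a hypothetical helper, mirroring
// how callers typically consume functional options) showing how the
// OptionModifier values above compose: start from DefaultOptions and apply
// each modifier in turn, e.g.
// applyOptionModifiers(OptionDryRunMigration(true), OptionNoMigration(false)).
func applyOptionModifiers(modifiers ...OptionModifier) Options {
	opts := DefaultOptions()
	for _, modifier := range modifiers {
		modifier(&opts)
	}
	return opts
}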
package channeldb
import "github.com/lightningnetwork/lnd/kvdb"
type paginator struct {
// cursor is the cursor which we are using to iterate through a bucket.
cursor kvdb.RCursor
// reversed indicates whether we are paginating forwards or backwards.
reversed bool
// indexOffset is the index from which we will begin querying.
indexOffset uint64
// totalItems is the total number of items we allow in our response.
totalItems uint64
}
// newPaginator returns a struct which can be used to query an indexed bucket
// in pages.
func newPaginator(c kvdb.RCursor, reversed bool,
indexOffset, totalItems uint64) paginator {
return paginator{
cursor: c,
reversed: reversed,
indexOffset: indexOffset,
totalItems: totalItems,
}
}
// keyValueForIndex seeks our cursor to a given index and returns the key and
// value at that position.
func (p paginator) keyValueForIndex(index uint64) ([]byte, []byte) {
var keyIndex [8]byte
byteOrder.PutUint64(keyIndex[:], index)
return p.cursor.Seek(keyIndex[:])
}
// lastIndex returns the last value in our index, if our index is empty it
// returns 0.
func (p paginator) lastIndex() uint64 {
keyIndex, _ := p.cursor.Last()
if keyIndex == nil {
return 0
}
return byteOrder.Uint64(keyIndex)
}
// nextKey is a helper closure to determine what key we should use next when
// we are iterating, depending on whether we are iterating forwards or in
// reverse.
func (p paginator) nextKey() ([]byte, []byte) {
if p.reversed {
return p.cursor.Prev()
}
return p.cursor.Next()
}
// cursorStart gets the index key and value for the first item we are looking
// up, taking into account that we may be paginating in reverse. The index
// offset provided is *exclusive* so we will start with the item after the
// offset for forwards queries, and the item before the offset for backwards
// queries.
func (p paginator) cursorStart() ([]byte, []byte) {
indexKey, indexValue := p.keyValueForIndex(p.indexOffset + 1)
// If the query is specifying reverse iteration, then we must
// handle a few offset cases.
if p.reversed {
switch {
// This indicates the default case, where no offset was
// specified. In that case we just start from the last
// entry.
case p.indexOffset == 0:
indexKey, indexValue = p.cursor.Last()
// This indicates the offset being set to the very
// first entry. Since there are no entries before
// this offset, and the direction is reversed, we can
// return without adding any items to the response.
case p.indexOffset == 1:
return nil, nil
// If we have been given an index offset that is beyond our last
// index value, we just return the last indexed value in our set
// since we are querying in reverse. We do not cover the case
// where our index offset equals our last index value, because
// index offset is exclusive, so we would want to start at the
// value before our last index.
case p.indexOffset > p.lastIndex():
return p.cursor.Last()
// Otherwise we have an index offset which is within our set of
// indexed keys, and we want to start at the item before our
// offset. We seek to our index offset, then return the element
// before it. We do this rather than p.indexOffset-1 to account
// for indexes that have gaps.
default:
p.keyValueForIndex(p.indexOffset)
indexKey, indexValue = p.cursor.Prev()
}
}
return indexKey, indexValue
}
// query gets the start point for our index offset and iterates through keys
// in our index until we reach the total number of items required for the query
// or we run out of cursor values. This function takes a fetchAndAppend function
// which is responsible for looking up the entry at that index, adding the entry
// to its set of return items (if desired) and returning a boolean which indicates
// whether the item was added. This is required to allow the paginator to
// determine when the response has the maximum number of required items.
func (p paginator) query(fetchAndAppend func(k, v []byte) (bool, error)) error {
indexKey, indexValue := p.cursorStart()
var totalItems int
for ; indexKey != nil; indexKey, indexValue = p.nextKey() {
// If our current return payload exceeds the max number
// of items, then we'll exit now.
if uint64(totalItems) >= p.totalItems {
break
}
added, err := fetchAndAppend(indexKey, indexValue)
if err != nil {
return err
}
// If we added an item to our set in the latest fetch and append
// we increment our total count.
if added {
totalItems++
}
}
return nil
}
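// collectFirstValues is an illustrative sketch (a hypothetical helper) of how
// a caller drives the paginator: the fetchAndAppend closure decides whether
// an indexed entry belongs in the result set, and query stops once n entries
// have been accepted or the cursor is exhausted.
func collectFirstValues(c kvdb.RCursor, n uint64) ([][]byte, error) {
	var values [][]byte
	p := newPaginator(c, false, 0, n)
	err := p.query(func(_, v []byte) (bool, error) {
		// Accept every entry; a real caller would deserialize v and
		// filter on its query criteria here.
		values = append(values, v)
		return true, nil
	})
	return values, err
}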
package channeldb
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
)
const (
// paymentSeqBlockSize is the block size used when we batch allocate
// payment sequences for future payments.
paymentSeqBlockSize = 1000
// paymentProgressLogInterval is the interval we use to rate-limit the
// logging output of payment processing.
paymentProgressLogInterval = 30 * time.Second
)
var (
// ErrAlreadyPaid signals we have already paid this payment hash.
ErrAlreadyPaid = errors.New("invoice is already paid")
// ErrPaymentInFlight signals that payment for this payment hash is
// already "in flight" on the network.
ErrPaymentInFlight = errors.New("payment is in transition")
// ErrPaymentExists is returned when we try to initialize an already
// existing payment that is not failed.
ErrPaymentExists = errors.New("payment already exists")
// ErrPaymentInternal is returned when the payment is in a conflicting
// state, such as:
// - payment has StatusSucceeded but remaining amount is not zero.
// - payment has StatusInitiated but remaining amount is zero.
// - payment has StatusFailed but remaining amount is zero.
ErrPaymentInternal = errors.New("internal error")
// ErrPaymentNotInitiated is returned if the payment wasn't initiated.
ErrPaymentNotInitiated = errors.New("payment isn't initiated")
// ErrPaymentAlreadySucceeded is returned in the event we attempt to
// change the status of a payment already succeeded.
ErrPaymentAlreadySucceeded = errors.New("payment is already succeeded")
// ErrPaymentAlreadyFailed is returned in the event we attempt to alter
// a failed payment.
ErrPaymentAlreadyFailed = errors.New("payment has already failed")
// ErrUnknownPaymentStatus is returned when we do not recognize the
// existing state of a payment.
ErrUnknownPaymentStatus = errors.New("unknown payment status")
// ErrPaymentTerminal is returned if we attempt to alter a payment that
// already has reached a terminal condition.
ErrPaymentTerminal = errors.New("payment has reached terminal " +
"condition")
// ErrAttemptAlreadySettled is returned if we try to alter an already
// settled HTLC attempt.
ErrAttemptAlreadySettled = errors.New("attempt already settled")
// ErrAttemptAlreadyFailed is returned if we try to alter an already
// failed HTLC attempt.
ErrAttemptAlreadyFailed = errors.New("attempt already failed")
// ErrValueMismatch is returned if we try to register a non-MPP attempt
// with an amount that doesn't match the payment amount.
ErrValueMismatch = errors.New("attempted value doesn't match payment " +
"amount")
// ErrValueExceedsAmt is returned if we try to register an attempt that
// would take the total sent amount above the payment amount.
ErrValueExceedsAmt = errors.New("attempted value exceeds payment " +
"amount")
// ErrNonMPPayment is returned if we try to register an MPP attempt for
// a payment that already has a non-MPP attempt registered.
ErrNonMPPayment = errors.New("payment has non-MPP attempts")
// ErrMPPayment is returned if we try to register a non-MPP attempt for
// a payment that already has an MPP attempt registered.
ErrMPPayment = errors.New("payment has MPP attempts")
// ErrMPPRecordInBlindedPayment is returned if we try to register an
// attempt with an MPP record for a payment to a blinded path.
ErrMPPRecordInBlindedPayment = errors.New("blinded payment cannot " +
"contain MPP records")
// ErrBlindedPaymentTotalAmountMismatch is returned if we try to
// register an HTLC shard to a blinded route where the total amount
// doesn't match existing shards.
ErrBlindedPaymentTotalAmountMismatch = errors.New("blinded path " +
"total amount mismatch")
// ErrMPPPaymentAddrMismatch is returned if we try to register an MPP
// shard where the payment address doesn't match existing shards.
ErrMPPPaymentAddrMismatch = errors.New("payment address mismatch")
// ErrMPPTotalAmountMismatch is returned if we try to register an MPP
// shard where the total amount doesn't match existing shards.
ErrMPPTotalAmountMismatch = errors.New("mp payment total amount " +
"mismatch")
// ErrPaymentPendingSettled is returned when we try to add a new
// attempt to a payment that has at least one of its HTLCs settled.
ErrPaymentPendingSettled = errors.New("payment has settled htlcs")
// ErrPaymentPendingFailed is returned when we try to add a new attempt
// to a payment that already has a failure reason.
ErrPaymentPendingFailed = errors.New("payment has failure reason")
// ErrSentExceedsTotal is returned if the payment's current total sent
// amount exceed the total amount.
ErrSentExceedsTotal = errors.New("total sent exceeds total amount")
// errNoAttemptInfo is returned when no attempt info is stored yet.
errNoAttemptInfo = errors.New("unable to find attempt info for " +
"inflight payment")
// errNoSequenceNrIndex is returned when an attempt to lookup a payment
// index is made for a sequence number that is not indexed.
errNoSequenceNrIndex = errors.New("payment sequence number index " +
"does not exist")
)
// PaymentControl implements persistence for payments and payment attempts.
type PaymentControl struct {
paymentSeqMx sync.Mutex
currPaymentSeq uint64
storedPaymentSeq uint64
db *DB
}
// NewPaymentControl creates a new instance of the PaymentControl.
func NewPaymentControl(db *DB) *PaymentControl {
return &PaymentControl{
db: db,
}
}
// InitPayment checks or records the given PaymentCreationInfo with the DB,
// making sure it does not already exist as an in-flight payment. When this
// method returns successfully, the payment is guaranteed to be in the InFlight
// state.
func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
info *PaymentCreationInfo) error {
// Obtain a new sequence number for this payment. This is used
// to sort the payments in order of creation, and also acts as
// a unique identifier for each payment.
sequenceNum, err := p.nextPaymentSequence()
if err != nil {
return err
}
var b bytes.Buffer
if err := serializePaymentCreationInfo(&b, info); err != nil {
return err
}
infoBytes := b.Bytes()
var updateErr error
err = kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
// Reset the update error, to avoid carrying over an error
// from a previous execution of the batched db transaction.
updateErr = nil
prefetchPayment(tx, paymentHash)
bucket, err := createPaymentBucket(tx, paymentHash)
if err != nil {
return err
}
// Get the existing status of this payment, if any.
paymentStatus, err := fetchPaymentStatus(bucket)
switch {
// If no error is returned, it means we already have this
// payment. We'll check the status to decide whether we allow
// retrying the payment or return a specific error.
case err == nil:
if err := paymentStatus.initializable(); err != nil {
updateErr = err
return nil
}
// Otherwise, if the error is not `ErrPaymentNotInitiated`,
// we'll return the error.
case !errors.Is(err, ErrPaymentNotInitiated):
return err
}
// Before we set our new sequence number, we check whether this
// payment has a previously set sequence number and remove its
// index entry if it exists. This happens in the case where we
// have a previously attempted payment which was left in a state
// where we can retry.
seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes != nil {
indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
if err := indexBucket.Delete(seqBytes); err != nil {
return err
}
}
// Once we have obtained a sequence number, we add an entry
// to our index bucket which will map the sequence number to
// our payment identifier.
err = createPaymentIndexEntry(
tx, sequenceNum, info.PaymentIdentifier,
)
if err != nil {
return err
}
err = bucket.Put(paymentSequenceKey, sequenceNum)
if err != nil {
return err
}
// Add the payment info to the bucket, which contains the
// static information for this payment
err = bucket.Put(paymentCreationInfoKey, infoBytes)
if err != nil {
return err
}
// We'll delete any lingering HTLCs to start with, in case we
// are initializing a payment that was attempted earlier, but
// left in a state where we could retry.
err = bucket.DeleteNestedBucket(paymentHtlcsBucket)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
// Also delete any lingering failure info now that we are
// re-attempting.
return bucket.Delete(paymentFailInfoKey)
})
if err != nil {
return fmt.Errorf("unable to init payment: %w", err)
}
return updateErr
}
// DeleteFailedAttempts deletes all failed htlcs for a payment if configured
// by the PaymentControl db.
func (p *PaymentControl) DeleteFailedAttempts(hash lntypes.Hash) error {
if !p.db.keepFailedPaymentAttempts {
const failedHtlcsOnly = true
err := p.db.DeletePayment(hash, failedHtlcsOnly)
if err != nil {
return err
}
}
return nil
}
// paymentIndexType indicates the type of index we have recorded for a
// payment.
type paymentIndexType uint8
// paymentIndexTypeHash is a payment index type which indicates that we have
// created an index of payment sequence number to payment hash.
const paymentIndexTypeHash paymentIndexType = 0
// createPaymentIndexEntry creates a payment hash typed index for a payment. The
// index produced contains a payment index type (which can be used in future to
// signal different payment index types) and the payment identifier.
func createPaymentIndexEntry(tx kvdb.RwTx, sequenceNumber []byte,
id lntypes.Hash) error {
var b bytes.Buffer
if err := WriteElements(&b, paymentIndexTypeHash, id[:]); err != nil {
return err
}
indexes := tx.ReadWriteBucket(paymentsIndexBucket)
return indexes.Put(sequenceNumber, b.Bytes())
}
// deserializePaymentIndex deserializes a payment index entry. This function
// currently only supports deserialization of payment hash indexes, and will
// fail for other types.
func deserializePaymentIndex(r io.Reader) (lntypes.Hash, error) {
var (
indexType paymentIndexType
paymentHash []byte
)
if err := ReadElements(r, &indexType, &paymentHash); err != nil {
return lntypes.Hash{}, err
}
// While we only have one payment index type, we do not need to use our
// index type to deserialize the index. However, we sanity check that
// this type is as expected, since we had to read it out anyway.
if indexType != paymentIndexTypeHash {
return lntypes.Hash{}, fmt.Errorf("unknown payment index "+
"type: %v", indexType)
}
hash, err := lntypes.MakeHash(paymentHash)
if err != nil {
return lntypes.Hash{}, err
}
return hash, nil
}
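// paymentIndexRoundTrip is an illustrative sketch (a hypothetical helper)
// showing that an index entry encoded with the same WriteElements call used
// by createPaymentIndexEntry decodes back to the original payment hash via
// deserializePaymentIndex.
func paymentIndexRoundTrip(id lntypes.Hash) (lntypes.Hash, error) {
	var b bytes.Buffer
	if err := WriteElements(&b, paymentIndexTypeHash, id[:]); err != nil {
		return lntypes.Hash{}, err
	}
	return deserializePaymentIndex(&b)
}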
// RegisterAttempt atomically records the provided HTLCAttemptInfo to the
// DB.
func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
attempt *HTLCAttemptInfo) (*MPPayment, error) {
// Serialize the information before opening the db transaction.
var a bytes.Buffer
err := serializeHTLCAttemptInfo(&a, attempt)
if err != nil {
return nil, err
}
htlcInfoBytes := a.Bytes()
htlcIDBytes := make([]byte, 8)
binary.BigEndian.PutUint64(htlcIDBytes, attempt.AttemptID)
var payment *MPPayment
err = kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
prefetchPayment(tx, paymentHash)
bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
if err != nil {
return err
}
payment, err = fetchPayment(bucket)
if err != nil {
return err
}
// Check if registering a new attempt is allowed.
if err := payment.Registrable(); err != nil {
return err
}
// If the final hop has encrypted data, then we know this is a
// blinded payment. In blinded payments, MPP records are not set
// for split payments and the recipient is responsible for using
// a consistent PathID across the various encrypted data
// payloads that we received from them for this payment. All we
// need to check is that the total amount field for each HTLC
// in the split payment is correct.
isBlinded := len(attempt.Route.FinalHop().EncryptedData) != 0
// Make sure any existing shards match the new one with regards
// to MPP options.
mpp := attempt.Route.FinalHop().MPP
// MPP records should not be set for attempts to blinded paths.
if isBlinded && mpp != nil {
return ErrMPPRecordInBlindedPayment
}
for _, h := range payment.InFlightHTLCs() {
hMpp := h.Route.FinalHop().MPP
// If this is a blinded payment, then no existing HTLCs
// should have MPP records.
if isBlinded && hMpp != nil {
return ErrMPPRecordInBlindedPayment
}
// If this is a blinded payment, then we just need to
// check that the TotalAmtMsat field for this shard
// is equal to that of any other shard in the same
// payment.
if isBlinded {
if attempt.Route.FinalHop().TotalAmtMsat !=
h.Route.FinalHop().TotalAmtMsat {
//nolint:ll
return ErrBlindedPaymentTotalAmountMismatch
}
continue
}
switch {
// We tried to register a non-MPP attempt for a MPP
// payment.
case mpp == nil && hMpp != nil:
return ErrMPPayment
// We tried to register a MPP shard for a non-MPP
// payment.
case mpp != nil && hMpp == nil:
return ErrNonMPPayment
// Non-MPP payment, nothing more to validate.
case mpp == nil:
continue
}
// Check that MPP options match.
if mpp.PaymentAddr() != hMpp.PaymentAddr() {
return ErrMPPPaymentAddrMismatch
}
if mpp.TotalMsat() != hMpp.TotalMsat() {
return ErrMPPTotalAmountMismatch
}
}
// If this is a non-MPP attempt, it must match the total amount
// exactly. Note that a blinded payment is considered an MPP
// attempt.
amt := attempt.Route.ReceiverAmt()
if !isBlinded && mpp == nil && amt != payment.Info.Value {
return ErrValueMismatch
}
// Ensure we aren't sending more than the total payment amount.
sentAmt, _ := payment.SentAmt()
if sentAmt+amt > payment.Info.Value {
return fmt.Errorf("%w: attempted=%v, payment amount="+
"%v", ErrValueExceedsAmt, sentAmt+amt,
payment.Info.Value)
}
htlcsBucket, err := bucket.CreateBucketIfNotExists(
paymentHtlcsBucket,
)
if err != nil {
return err
}
err = htlcsBucket.Put(
htlcBucketKey(htlcAttemptInfoKey, htlcIDBytes),
htlcInfoBytes,
)
if err != nil {
return err
}
// Retrieve attempt info for the notification.
payment, err = fetchPayment(bucket)
return err
})
if err != nil {
return nil, err
}
return payment, err
}
// SettleAttempt marks the given attempt settled with the preimage. If this is
// a multi shard payment, this might implicitly mean that the full payment
// succeeded.
//
// After invoking this method, InitPayment should always return an error to
// prevent us from making duplicate payments to the same payment hash. The
// provided preimage is atomically saved to the DB for record keeping.
func (p *PaymentControl) SettleAttempt(hash lntypes.Hash,
attemptID uint64, settleInfo *HTLCSettleInfo) (*MPPayment, error) {
var b bytes.Buffer
if err := serializeHTLCSettleInfo(&b, settleInfo); err != nil {
return nil, err
}
settleBytes := b.Bytes()
return p.updateHtlcKey(hash, attemptID, htlcSettleInfoKey, settleBytes)
}
// FailAttempt marks the given payment attempt failed.
func (p *PaymentControl) FailAttempt(hash lntypes.Hash,
attemptID uint64, failInfo *HTLCFailInfo) (*MPPayment, error) {
var b bytes.Buffer
if err := serializeHTLCFailInfo(&b, failInfo); err != nil {
return nil, err
}
failBytes := b.Bytes()
return p.updateHtlcKey(hash, attemptID, htlcFailInfoKey, failBytes)
}
// updateHtlcKey updates a database key for the specified htlc.
func (p *PaymentControl) updateHtlcKey(paymentHash lntypes.Hash,
attemptID uint64, key, value []byte) (*MPPayment, error) {
aid := make([]byte, 8)
binary.BigEndian.PutUint64(aid, attemptID)
var payment *MPPayment
err := kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
payment = nil
prefetchPayment(tx, paymentHash)
bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
if err != nil {
return err
}
p, err := fetchPayment(bucket)
if err != nil {
return err
}
// We can only update keys of in-flight payments. We allow
// updating keys even if the payment has reached a terminal
// condition, since the HTLC outcomes must still be updated.
if err := p.Status.updatable(); err != nil {
return err
}
htlcsBucket := bucket.NestedReadWriteBucket(paymentHtlcsBucket)
if htlcsBucket == nil {
return fmt.Errorf("htlcs bucket not found")
}
if htlcsBucket.Get(htlcBucketKey(htlcAttemptInfoKey, aid)) == nil {
return fmt.Errorf("HTLC with ID %v not registered",
attemptID)
}
// Make sure the shard is not already failed or settled.
if htlcsBucket.Get(htlcBucketKey(htlcFailInfoKey, aid)) != nil {
return ErrAttemptAlreadyFailed
}
if htlcsBucket.Get(htlcBucketKey(htlcSettleInfoKey, aid)) != nil {
return ErrAttemptAlreadySettled
}
// Add or update the key for this htlc.
err = htlcsBucket.Put(htlcBucketKey(key, aid), value)
if err != nil {
return err
}
// Retrieve attempt info for the notification.
payment, err = fetchPayment(bucket)
return err
})
if err != nil {
return nil, err
}
return payment, err
}
// Fail transitions a payment into the Failed state, and records the reason the
// payment failed. After invoking this method, InitPayment should return nil on
// its next call for this payment hash, allowing the switch to make a
// subsequent payment.
func (p *PaymentControl) Fail(paymentHash lntypes.Hash,
reason FailureReason) (*MPPayment, error) {
var (
updateErr error
payment *MPPayment
)
err := kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
// Reset the update error, to avoid carrying over an error
// from a previous execution of the batched db transaction.
updateErr = nil
payment = nil
prefetchPayment(tx, paymentHash)
bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
if err == ErrPaymentNotInitiated {
updateErr = ErrPaymentNotInitiated
return nil
} else if err != nil {
return err
}
// We mark the payment as failed as long as it is known. This
// lets the last attempt that fails with a terminal error write
// its failure to the PaymentControl without synchronizing with
// other attempts.
_, err = fetchPaymentStatus(bucket)
if errors.Is(err, ErrPaymentNotInitiated) {
updateErr = ErrPaymentNotInitiated
return nil
} else if err != nil {
return err
}
// Put the failure reason in the bucket for record keeping.
v := []byte{byte(reason)}
err = bucket.Put(paymentFailInfoKey, v)
if err != nil {
return err
}
// Retrieve attempt info for the notification, if available.
payment, err = fetchPayment(bucket)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return payment, updateErr
}
// FetchPayment returns information about a payment from the database.
func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
*MPPayment, error) {
var payment *MPPayment
err := kvdb.View(p.db, func(tx kvdb.RTx) error {
prefetchPayment(tx, paymentHash)
bucket, err := fetchPaymentBucket(tx, paymentHash)
if err != nil {
return err
}
payment, err = fetchPayment(bucket)
return err
}, func() {
payment = nil
})
if err != nil {
return nil, err
}
return payment, nil
}
// prefetchPayment attempts to prefetch as much of the payment as possible to
// reduce DB roundtrips.
func prefetchPayment(tx kvdb.RTx, paymentHash lntypes.Hash) {
rb := kvdb.RootBucket(tx)
kvdb.Prefetch(
rb,
[]string{
// Prefetch all keys in the payment's bucket.
string(paymentsRootBucket),
string(paymentHash[:]),
},
[]string{
// Prefetch all keys in the payment's htlc bucket.
string(paymentsRootBucket),
string(paymentHash[:]),
string(paymentHtlcsBucket),
},
)
}
// createPaymentBucket creates or fetches the sub-bucket assigned to this
// payment hash.
func createPaymentBucket(tx kvdb.RwTx, paymentHash lntypes.Hash) (
kvdb.RwBucket, error) {
payments, err := tx.CreateTopLevelBucket(paymentsRootBucket)
if err != nil {
return nil, err
}
return payments.CreateBucketIfNotExists(paymentHash[:])
}
// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If
// the bucket does not exist, it returns ErrPaymentNotInitiated.
func fetchPaymentBucket(tx kvdb.RTx, paymentHash lntypes.Hash) (
kvdb.RBucket, error) {
payments := tx.ReadBucket(paymentsRootBucket)
if payments == nil {
return nil, ErrPaymentNotInitiated
}
bucket := payments.NestedReadBucket(paymentHash[:])
if bucket == nil {
return nil, ErrPaymentNotInitiated
}
return bucket, nil
}
// fetchPaymentBucketUpdate is identical to fetchPaymentBucket, but it returns a
// bucket that can be written to.
func fetchPaymentBucketUpdate(tx kvdb.RwTx, paymentHash lntypes.Hash) (
kvdb.RwBucket, error) {
payments := tx.ReadWriteBucket(paymentsRootBucket)
if payments == nil {
return nil, ErrPaymentNotInitiated
}
bucket := payments.NestedReadWriteBucket(paymentHash[:])
if bucket == nil {
return nil, ErrPaymentNotInitiated
}
return bucket, nil
}
// nextPaymentSequence returns the next sequence number to store for a new
// payment.
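//
// To illustrate (hypothetical numbers): with paymentSeqBlockSize = 1000 and
// a fresh DB, the first call fetches the bucket sequence (0), persists a new
// upper bound of 1000, and returns 1. The next 999 calls increment the
// sequence purely in memory, and the call returning 1001 persists the next
// upper bound of 2000.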
func (p *PaymentControl) nextPaymentSequence() ([]byte, error) {
p.paymentSeqMx.Lock()
defer p.paymentSeqMx.Unlock()
// Set a new upper bound in the DB every 1000 payments to avoid
// conflicts on the sequence when using etcd.
if p.currPaymentSeq == p.storedPaymentSeq {
var currPaymentSeq, newUpperBound uint64
if err := kvdb.Update(p.db.Backend, func(tx kvdb.RwTx) error {
paymentsBucket, err := tx.CreateTopLevelBucket(
paymentsRootBucket,
)
if err != nil {
return err
}
currPaymentSeq = paymentsBucket.Sequence()
newUpperBound = currPaymentSeq + paymentSeqBlockSize
return paymentsBucket.SetSequence(newUpperBound)
}, func() {}); err != nil {
return nil, err
}
// We lazily initialize the cached currPaymentSeq here on the
// first nextPaymentSequence() call. Since both currPaymentSeq
// and storedPaymentSeq are zero by default, the equality check
// above triggers and we fetch the current values from the DB.
if p.currPaymentSeq == 0 {
p.currPaymentSeq = currPaymentSeq
}
p.storedPaymentSeq = newUpperBound
}
p.currPaymentSeq++
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, p.currPaymentSeq)
return b, nil
}
// fetchPaymentStatus fetches the payment status of the payment. If the payment
// isn't found, it will return error `ErrPaymentNotInitiated`.
func fetchPaymentStatus(bucket kvdb.RBucket) (PaymentStatus, error) {
// Creation info should be set for all payments, regardless of state.
// If not, it is unknown.
if bucket.Get(paymentCreationInfoKey) == nil {
return 0, ErrPaymentNotInitiated
}
payment, err := fetchPayment(bucket)
if err != nil {
return 0, err
}
return payment.Status, nil
}
// FetchInFlightPayments returns all payments with status InFlight.
func (p *PaymentControl) FetchInFlightPayments() ([]*MPPayment, error) {
var (
inFlights []*MPPayment
start = time.Now()
lastLogTime = time.Now()
processedCount int
)
err := kvdb.View(p.db, func(tx kvdb.RTx) error {
payments := tx.ReadBucket(paymentsRootBucket)
if payments == nil {
return nil
}
return payments.ForEach(func(k, _ []byte) error {
bucket := payments.NestedReadBucket(k)
if bucket == nil {
return fmt.Errorf("non bucket element")
}
p, err := fetchPayment(bucket)
if err != nil {
return err
}
processedCount++
if time.Since(lastLogTime) >=
paymentProgressLogInterval {
log.Debugf("Scanning inflight payments "+
"(in progress), processed %d, last "+
"processed payment: %v", processedCount,
p.Info)
lastLogTime = time.Now()
}
// Skip the payment if it's terminated.
if p.Terminated() {
return nil
}
inFlights = append(inFlights, p)
return nil
})
}, func() {
inFlights = nil
})
if err != nil {
return nil, err
}
elapsed := time.Since(start)
log.Debugf("Completed scanning for inflight payments: "+
"total_processed=%d, found_inflight=%d, elapsed=%v",
processedCount, len(inFlights),
elapsed.Round(time.Millisecond))
return inFlights, nil
}
package channeldb
import "fmt"
// PaymentStatus represent current status of payment.
type PaymentStatus byte
const (
// NOTE: PaymentStatus = 0 was previously used for status unknown and
// is now deprecated.
// StatusInitiated is the status where a payment has just been
// initiated.
StatusInitiated PaymentStatus = 1
// StatusInFlight is the status where a payment has been initiated, but
// a response has not been received.
StatusInFlight PaymentStatus = 2
// StatusSucceeded is the status where a payment has been initiated and
// the payment was completed successfully.
StatusSucceeded PaymentStatus = 3
// StatusFailed is the status where a payment has been initiated and a
// failure result has come back.
StatusFailed PaymentStatus = 4
)
// errPaymentStatusUnknown is returned when a payment has an unknown status.
var errPaymentStatusUnknown = fmt.Errorf("unknown payment status")
// String returns readable representation of payment status.
func (ps PaymentStatus) String() string {
switch ps {
case StatusInitiated:
return "Initiated"
case StatusInFlight:
return "In Flight"
case StatusSucceeded:
return "Succeeded"
case StatusFailed:
return "Failed"
default:
return "Unknown"
}
}
// initializable returns an error to specify whether initiating the payment
// with its current status is allowed. A payment can only be initialized if it
// hasn't been created yet or already failed.
func (ps PaymentStatus) initializable() error {
switch ps {
// The payment has been created already. We will disallow creating it
// again in case other goroutines have already been creating HTLCs for
// it.
case StatusInitiated:
return ErrPaymentExists
// We already have an InFlight payment on the network. We will disallow
// any new payments.
case StatusInFlight:
return ErrPaymentInFlight
// The payment has been attempted and has succeeded, so we won't allow
// creating it again.
case StatusSucceeded:
return ErrAlreadyPaid
// We allow retrying failed payments.
case StatusFailed:
return nil
default:
return fmt.Errorf("%w: %v", ErrUnknownPaymentStatus, ps)
}
}
// removable returns an error to specify whether deleting the payment with its
// current status is allowed. A payment cannot be safely deleted if it has
// inflight HTLCs.
func (ps PaymentStatus) removable() error {
switch ps {
// The payment has been created but has no HTLCs and can be removed.
case StatusInitiated:
return nil
// There are still inflight HTLCs and the payment needs to wait for the
// final outcomes.
case StatusInFlight:
return ErrPaymentInFlight
// The payment has been attempted and has succeeded, so it is allowed to
// be removed.
case StatusSucceeded:
return nil
// Failed payments are allowed to be removed.
case StatusFailed:
return nil
default:
return fmt.Errorf("%w: %v", ErrUnknownPaymentStatus, ps)
}
}
// updatable returns an error to specify whether the payment's HTLCs can be
// updated. A payment can update its HTLCs when it has inflight HTLCs.
func (ps PaymentStatus) updatable() error {
switch ps {
// Newly created payments can be updated.
case StatusInitiated:
return nil
// Inflight payments can be updated.
case StatusInFlight:
return nil
// If the payment has a terminal condition, we won't allow any updates.
case StatusSucceeded:
return ErrPaymentAlreadySucceeded
case StatusFailed:
return ErrPaymentAlreadyFailed
default:
return fmt.Errorf("%w: %v", ErrUnknownPaymentStatus, ps)
}
}
// decidePaymentStatus uses the payment's DB state to determine a memory status
// that's used by the payment router to decide on subsequent actions.
// Together, we use four variables to determine the payment's status,
// - inflight: whether there are any pending HTLCs.
// - settled: whether any of the HTLCs has been settled.
// - htlc failed: whether any of the HTLCs has been failed.
// - payment failed: whether the payment has been marked as failed.
//
// Based on the above variables, we derive the status using the following
// table,
// | inflight | settled | htlc failed | payment failed | status |
// |:--------:|:-------:|:-----------:|:--------------:|:--------------------:|
// | true | true | true | true | StatusInFlight |
// | true | true | true | false | StatusInFlight |
// | true | true | false | true | StatusInFlight |
// | true | true | false | false | StatusInFlight |
// | true | false | true | true | StatusInFlight |
// | true | false | true | false | StatusInFlight |
// | true | false | false | true | StatusInFlight |
// | true | false | false | false | StatusInFlight |
// | false | true | true | true | StatusSucceeded |
// | false | true | true | false | StatusSucceeded |
// | false | true | false | true | StatusSucceeded |
// | false | true | false | false | StatusSucceeded |
// | false | false | true | true | StatusFailed |
// | false | false | true | false | StatusInFlight |
// | false | false | false | true | StatusFailed |
// | false | false | false | false | StatusInitiated |
//
// When `inflight`, `settled`, `htlc failed`, and `payment failed` are false,
// this indicates the payment is newly created and hasn't made any HTLCs yet.
// When `inflight` and `settled` are false, `htlc failed` is true yet `payment
// failed` is false, this indicates that all of the payment's HTLCs failed
// temporarily and the payment is still in-flight.
func decidePaymentStatus(htlcs []HTLCAttempt,
reason *FailureReason) (PaymentStatus, error) {
var (
inflight bool
htlcSettled bool
htlcFailed bool
paymentFailed bool
)
// If we have a failure reason, the payment is failed.
if reason != nil {
paymentFailed = true
}
// Go through all HTLCs for this payment, check whether we have any
// settled HTLC, and any still in-flight.
for _, h := range htlcs {
if h.Failure != nil {
htlcFailed = true
continue
}
if h.Settle != nil {
htlcSettled = true
continue
}
// If any of the HTLCs are not failed nor settled, we
// still have inflight HTLCs.
inflight = true
}
// Use the DB state to determine the status of the payment.
switch {
// If we have inflight HTLCs, then no matter whether we have settled or
// failed HTLCs, or the payment itself failed, we still consider it
// inflight so that upper systems wait for the results.
case inflight:
return StatusInFlight, nil
// If we have no in-flight HTLCs, and at least one of the HTLCs is
// settled, the payment succeeded.
//
// NOTE: when reaching this case, paymentFailed could be true, which
// means we have a conflicting state for this payment. We choose to
// mark the payment as succeeded because it's the receiver's
// responsibility to only settle the payment iff all HTLCs are
// received.
case htlcSettled:
return StatusSucceeded, nil
// If we have no in-flight HTLCs, and the payment failure is set, the
// payment is considered failed.
//
// NOTE: when reaching this case, settled must be false.
case paymentFailed:
return StatusFailed, nil
// If we have no in-flight HTLCs, yet the payment is NOT failed, it
// means all the HTLCs are failed. In this case we can attempt more
// HTLCs.
//
// NOTE: when reaching this case, both settled and paymentFailed must
// be false.
case htlcFailed:
return StatusInFlight, nil
// If none of the HTLCs is either settled or failed, and we have no
// inflight HTLCs, this means the payment has no HTLCs created yet.
//
// NOTE: when reaching this case, both settled and paymentFailed must
// be false.
case !htlcFailed:
return StatusInitiated, nil
// Otherwise an impossible state is reached.
//
// NOTE: we should never end up here.
default:
log.Error("Impossible payment state reached")
return 0, fmt.Errorf("%w: payment is corrupted",
errPaymentStatusUnknown)
}
}
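// decidePaymentStatusExample is a hypothetical sketch walking one row of the
// table above: no inflight HTLCs, at least one settled HTLC, and no
// payment-level failure reason yields StatusSucceeded, even though another
// HTLC failed along the way.
func decidePaymentStatusExample() (PaymentStatus, error) {
	htlcs := []HTLCAttempt{
		{Settle: &HTLCSettleInfo{}},
		{Failure: &HTLCFailInfo{}},
	}
	// No failure reason has been recorded for the payment itself.
	return decidePaymentStatus(htlcs, nil)
}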
package channeldb
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sort"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// paymentsRootBucket is the name of the top-level bucket within the
// database that stores all data related to payments. Within this
// bucket, each payment has its own sub-bucket keyed by its payment
// hash.
//
// Bucket hierarchy:
//
// root-bucket
// |
// |-- <paymenthash>
// | |--sequence-key: <sequence number>
// | |--creation-info-key: <creation info>
// | |--fail-info-key: <(optional) fail info>
// | |
// | |--payment-htlcs-bucket (shard-bucket)
// | | |
// | | |-- ai<htlc attempt ID>: <htlc attempt info>
// | | |-- si<htlc attempt ID>: <(optional) settle info>
// | | |-- fi<htlc attempt ID>: <(optional) fail info>
// | | |
// | | ...
// | |
// | |
// | |--duplicate-bucket (only for old, completed payments)
// | |
// | |-- <seq-num>
// | | |--sequence-key: <sequence number>
// | | |--creation-info-key: <creation info>
// | | |--ai: <attempt info>
// | | |--si: <settle info>
// | | |--fi: <fail info>
// | |
// | |-- <seq-num>
// | | |
// | ... ...
// |
// |-- <paymenthash>
// | |
// | ...
// ...
//
paymentsRootBucket = []byte("payments-root-bucket")
// paymentSequenceKey is a key used in the payment's sub-bucket to
// store the sequence number of the payment.
paymentSequenceKey = []byte("payment-sequence-key")
// paymentCreationInfoKey is a key used in the payment's sub-bucket to
// store the creation info of the payment.
paymentCreationInfoKey = []byte("payment-creation-info")
// paymentHtlcsBucket is a bucket where we'll store the information
// about the HTLCs that were attempted for a payment.
paymentHtlcsBucket = []byte("payment-htlcs-bucket")
// htlcAttemptInfoKey is the key used as the prefix of an HTLC attempt
// to store the info about the attempt that was done for the HTLC in
// question. The HTLC attempt ID is concatenated at the end.
htlcAttemptInfoKey = []byte("ai")
// htlcSettleInfoKey is the key used as the prefix of an HTLC attempt
// settle info, if any. The HTLC attempt ID is concatenated at the end.
htlcSettleInfoKey = []byte("si")
// htlcFailInfoKey is the key used as the prefix of an HTLC attempt
// failure information, if any. The HTLC attempt ID is concatenated at
// the end.
htlcFailInfoKey = []byte("fi")
// paymentFailInfoKey is a key used in the payment's sub-bucket to
// store information about the reason a payment failed.
paymentFailInfoKey = []byte("payment-fail-info")
// paymentsIndexBucket is the name of the top-level bucket within the
// database that stores an index from payment sequence number to payment
// hash.
// payments-sequence-index-bucket
// |--<sequence-number>: <payment hash>
// |--...
// |--<sequence-number>: <payment hash>
paymentsIndexBucket = []byte("payments-index-bucket")
)
var (
// ErrNoSequenceNumber is returned if we look up a payment which does
// not have a sequence number.
ErrNoSequenceNumber = errors.New("sequence number not found")
// ErrDuplicateNotFound is returned when we lookup a payment by its
// index and cannot find a payment with a matching sequence number.
ErrDuplicateNotFound = errors.New("duplicate payment not found")
// ErrNoDuplicateBucket is returned when we expect to find duplicates
// when looking up a payment from its index, but the payment does not
// have any.
ErrNoDuplicateBucket = errors.New("expected duplicate bucket")
// ErrNoDuplicateNestedBucket is returned if we do not find duplicate
// payments in their own sub-bucket.
ErrNoDuplicateNestedBucket = errors.New("nested duplicate bucket not " +
"found")
)
// FailureReason encodes the reason a payment ultimately failed.
type FailureReason byte
const (
// FailureReasonTimeout indicates that the payment did timeout before a
// successful payment attempt was made.
FailureReasonTimeout FailureReason = 0
// FailureReasonNoRoute indicates no successful route to the
// destination was found during path finding.
FailureReasonNoRoute FailureReason = 1
// FailureReasonError indicates that an unexpected error happened during
// payment.
FailureReasonError FailureReason = 2
// FailureReasonPaymentDetails indicates that either the hash is unknown
// or the final cltv delta or amount is incorrect.
FailureReasonPaymentDetails FailureReason = 3
// FailureReasonInsufficientBalance indicates that we didn't have enough
// balance to complete the payment.
FailureReasonInsufficientBalance FailureReason = 4
// FailureReasonCanceled indicates that the payment was canceled by the
// user.
FailureReasonCanceled FailureReason = 5
// TODO(joostjager): Add failure reasons for:
// LocalLiquidityInsufficient, RemoteCapacityInsufficient.
)
// Error returns a human-readable error string for the FailureReason.
func (r FailureReason) Error() string {
return r.String()
}
// String returns a human-readable FailureReason.
func (r FailureReason) String() string {
switch r {
case FailureReasonTimeout:
return "timeout"
case FailureReasonNoRoute:
return "no_route"
case FailureReasonError:
return "error"
case FailureReasonPaymentDetails:
return "incorrect_payment_details"
case FailureReasonInsufficientBalance:
return "insufficient_balance"
case FailureReasonCanceled:
return "canceled"
}
return "unknown"
}
// PaymentCreationInfo is the information necessary to have ready when
// initiating a payment, moving it into state InFlight.
type PaymentCreationInfo struct {
// PaymentIdentifier is the hash this payment is paying to in case of
// non-AMP payments, and the SetID for AMP payments.
PaymentIdentifier lntypes.Hash
// Value is the amount we are paying.
Value lnwire.MilliSatoshi
// CreationTime is the time when this payment was initiated.
CreationTime time.Time
// PaymentRequest is the full payment request, if any.
PaymentRequest []byte
// FirstHopCustomRecords are the TLV records that are to be sent to the
// first hop of this payment. These records will be transmitted via the
// wire message only and therefore do not affect the onion payload size.
FirstHopCustomRecords lnwire.CustomRecords
}
// String returns a human-readable description of the payment creation info.
func (p *PaymentCreationInfo) String() string {
return fmt.Sprintf("payment_id=%v, amount=%v, created_at=%v",
p.PaymentIdentifier, p.Value, p.CreationTime)
}
// htlcBucketKey creates a composite key from prefix and id where the result is
// simply the two concatenated.
func htlcBucketKey(prefix, id []byte) []byte {
key := make([]byte, len(prefix)+len(id))
copy(key, prefix)
copy(key[len(prefix):], id)
return key
}
// FetchPayments returns all sent payments found in the DB.
//
// nolint: dupl
func (d *DB) FetchPayments() ([]*MPPayment, error) {
var payments []*MPPayment
err := kvdb.View(d, func(tx kvdb.RTx) error {
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
if paymentsBucket == nil {
return nil
}
return paymentsBucket.ForEach(func(k, v []byte) error {
bucket := paymentsBucket.NestedReadBucket(k)
if bucket == nil {
// We only expect sub-buckets to be found in
// this top-level bucket.
return fmt.Errorf("non bucket element in " +
"payments bucket")
}
p, err := fetchPayment(bucket)
if err != nil {
return err
}
payments = append(payments, p)
// For older versions of lnd, duplicate payments to a
// payment hash were possible. These will be found in a
// sub-bucket indexed by their sequence number if
// available.
duplicatePayments, err := fetchDuplicatePayments(bucket)
if err != nil {
return err
}
payments = append(payments, duplicatePayments...)
return nil
})
}, func() {
payments = nil
})
if err != nil {
return nil, err
}
// Before returning, sort the payments by their sequence number.
sort.Slice(payments, func(i, j int) bool {
return payments[i].SequenceNum < payments[j].SequenceNum
})
return payments, nil
}
func fetchCreationInfo(bucket kvdb.RBucket) (*PaymentCreationInfo, error) {
b := bucket.Get(paymentCreationInfoKey)
if b == nil {
return nil, fmt.Errorf("creation info not found")
}
r := bytes.NewReader(b)
return deserializePaymentCreationInfo(r)
}
func fetchPayment(bucket kvdb.RBucket) (*MPPayment, error) {
seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes == nil {
return nil, fmt.Errorf("sequence number not found")
}
sequenceNum := binary.BigEndian.Uint64(seqBytes)
// Get the PaymentCreationInfo.
creationInfo, err := fetchCreationInfo(bucket)
if err != nil {
return nil, err
}
var htlcs []HTLCAttempt
htlcsBucket := bucket.NestedReadBucket(paymentHtlcsBucket)
if htlcsBucket != nil {
// Get the payment attempts. This can be empty.
htlcs, err = fetchHtlcAttempts(htlcsBucket)
if err != nil {
return nil, err
}
}
// Get failure reason if available.
var failureReason *FailureReason
b := bucket.Get(paymentFailInfoKey)
if b != nil {
reason := FailureReason(b[0])
failureReason = &reason
}
// Create a new payment.
payment := &MPPayment{
SequenceNum: sequenceNum,
Info: creationInfo,
HTLCs: htlcs,
FailureReason: failureReason,
}
// Set its state and status.
if err := payment.setState(); err != nil {
return nil, err
}
return payment, nil
}
// fetchHtlcAttempts retrieves all htlc attempts made for the payment found in
// the given bucket.
func fetchHtlcAttempts(bucket kvdb.RBucket) ([]HTLCAttempt, error) {
htlcsMap := make(map[uint64]*HTLCAttempt)
attemptInfoCount := 0
err := bucket.ForEach(func(k, v []byte) error {
aid := byteOrder.Uint64(k[len(k)-8:])
if _, ok := htlcsMap[aid]; !ok {
htlcsMap[aid] = &HTLCAttempt{}
}
var err error
switch {
case bytes.HasPrefix(k, htlcAttemptInfoKey):
attemptInfo, err := readHtlcAttemptInfo(v)
if err != nil {
return err
}
attemptInfo.AttemptID = aid
htlcsMap[aid].HTLCAttemptInfo = *attemptInfo
attemptInfoCount++
case bytes.HasPrefix(k, htlcSettleInfoKey):
htlcsMap[aid].Settle, err = readHtlcSettleInfo(v)
if err != nil {
return err
}
case bytes.HasPrefix(k, htlcFailInfoKey):
htlcsMap[aid].Failure, err = readHtlcFailInfo(v)
if err != nil {
return err
}
default:
return fmt.Errorf("unknown htlc attempt key")
}
return nil
})
if err != nil {
return nil, err
}
// Sanity check that all htlcs have an attempt info.
if attemptInfoCount != len(htlcsMap) {
return nil, errNoAttemptInfo
}
keys := make([]uint64, len(htlcsMap))
i := 0
for k := range htlcsMap {
keys[i] = k
i++
}
// Sort HTLC attempts by their attempt ID. This is needed because in the
// DB we store the attempts with keys prefixed by their status, which
// groups them together by status rather than by attempt ID.
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
htlcs := make([]HTLCAttempt, len(htlcsMap))
for i, key := range keys {
htlcs[i] = *htlcsMap[key]
}
return htlcs, nil
}
// readHtlcAttemptInfo reads the payment attempt info for this htlc.
func readHtlcAttemptInfo(b []byte) (*HTLCAttemptInfo, error) {
r := bytes.NewReader(b)
return deserializeHTLCAttemptInfo(r)
}
// readHtlcSettleInfo reads the settle info for the htlc. If the htlc isn't
// settled, nil is returned.
func readHtlcSettleInfo(b []byte) (*HTLCSettleInfo, error) {
r := bytes.NewReader(b)
return deserializeHTLCSettleInfo(r)
}
// readHtlcFailInfo reads the failure info for the htlc. If the htlc hasn't
// failed, nil is returned.
func readHtlcFailInfo(b []byte) (*HTLCFailInfo, error) {
r := bytes.NewReader(b)
return deserializeHTLCFailInfo(r)
}
// fetchFailedHtlcKeys retrieves the bucket keys of all failed HTLCs of a
// payment bucket.
func fetchFailedHtlcKeys(bucket kvdb.RBucket) ([][]byte, error) {
htlcsBucket := bucket.NestedReadBucket(paymentHtlcsBucket)
var htlcs []HTLCAttempt
var err error
if htlcsBucket != nil {
htlcs, err = fetchHtlcAttempts(htlcsBucket)
if err != nil {
return nil, err
}
}
// Now iterate through them and save the bucket keys for the failed
// HTLCs.
var htlcKeys [][]byte
for _, h := range htlcs {
if h.Failure == nil {
continue
}
htlcKeyBytes := make([]byte, 8)
binary.BigEndian.PutUint64(htlcKeyBytes, h.AttemptID)
htlcKeys = append(htlcKeys, htlcKeyBytes)
}
return htlcKeys, nil
}
// PaymentsQuery represents a query to the payments database starting or ending
// at a certain offset index. The number of retrieved records can be limited.
type PaymentsQuery struct {
// IndexOffset determines the starting point of the payments query and
// is always exclusive. In normal order, the query starts at the next
// higher (available) index compared to IndexOffset. In reversed order,
// the query ends at the next lower (available) index compared to the
// IndexOffset. In the case of a zero index_offset, the query will start
// with the oldest payment when paginating forwards, or will end with
// the most recent payment when paginating backwards.
IndexOffset uint64
// MaxPayments is the maximal number of payments returned in the
// payments query.
MaxPayments uint64
// Reversed gives a meaning to the IndexOffset. If reversed is set to
// true, the query will fetch payments with indices lower than the
// IndexOffset, otherwise, it will return payments with indices greater
// than the IndexOffset.
Reversed bool
// If IncludeIncomplete is true, then return payments that have not yet
// fully completed. This means that pending payments, as well as failed
// payments will show up if this field is set to true.
IncludeIncomplete bool
// CountTotal indicates that all payments currently present in the
// payment index (complete and incomplete) should be counted.
CountTotal bool
// CreationDateStart, expressed in Unix seconds, if set, filters out all
// payments with a creation date earlier than it.
CreationDateStart int64
// CreationDateEnd, expressed in Unix seconds, if set, filters out all
// payments with a creation date later than it.
CreationDateEnd int64
}
// PaymentsResponse contains the result of a query to the payments database.
// It includes the set of payments that match the query and integers which
// represent the index of the first and last item returned in the series of
// payments. These integers allow callers to resume their query in the event
// that the query's response exceeds the max number of returnable events.
type PaymentsResponse struct {
// Payments is the set of payments returned from the database for the
// PaymentsQuery.
Payments []*MPPayment
// FirstIndexOffset is the index of the first element in the set of
// returned MPPayments. Callers can use this to resume their query
// in the event that the slice has too many events to fit into a single
// response. The offset can be used to continue reverse pagination.
FirstIndexOffset uint64
// LastIndexOffset is the index of the last element in the set of
// returned MPPayments. Callers can use this to resume their query
// in the event that the slice has too many events to fit into a single
// response. The offset can be used to continue forward pagination.
LastIndexOffset uint64
// TotalCount represents the total number of payments that are currently
// stored in the payment database. This will only be set if the
// CountTotal field in the query was set to true.
TotalCount uint64
}
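// fetchAllPaymentsPaged is a hypothetical sketch of the pagination contract
// described above: feed each response's LastIndexOffset back in as the next
// query's IndexOffset until a page comes back shorter than MaxPayments.
func fetchAllPaymentsPaged(d *DB) ([]*MPPayment, error) {
	var all []*MPPayment
	query := PaymentsQuery{
		MaxPayments:       100,
		IncludeIncomplete: true,
	}
	for {
		resp, err := d.QueryPayments(query)
		if err != nil {
			return nil, err
		}
		all = append(all, resp.Payments...)
		if uint64(len(resp.Payments)) < query.MaxPayments {
			return all, nil
		}
		query.IndexOffset = resp.LastIndexOffset
	}
}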
// QueryPayments queries the payments database, restricting the result
// set to the subset selected by the given query's offset index and
// maximum number of returned payments.
func (d *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
var resp PaymentsResponse
if err := kvdb.View(d, func(tx kvdb.RTx) error {
// Get the root payments bucket.
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
if paymentsBucket == nil {
return nil
}
// Get the index bucket which maps sequence number -> payment
// hash and duplicate bool. If we have a payments bucket, we
// should have an indexes bucket as well.
indexes := tx.ReadBucket(paymentsIndexBucket)
if indexes == nil {
return fmt.Errorf("index bucket does not exist")
}
// accumulatePayments gets payments with the sequence number
// and hash provided and adds them to our list of payments if
// they meet the criteria of our query. It returns the number
// of payments that were added.
accumulatePayments := func(sequenceKey, hash []byte) (bool,
error) {
r := bytes.NewReader(hash)
paymentHash, err := deserializePaymentIndex(r)
if err != nil {
return false, err
}
payment, err := fetchPaymentWithSequenceNumber(
tx, paymentHash, sequenceKey,
)
if err != nil {
return false, err
}
// To keep compatibility with the old API, we only
// return non-succeeded payments if requested.
if payment.Status != StatusSucceeded &&
!query.IncludeIncomplete {
return false, nil
}
// Get the creation time in Unix seconds, this always
// rounds down the nanoseconds to full seconds.
createTime := payment.Info.CreationTime.Unix()
// Skip any payments that were created before the
// specified time.
if createTime < query.CreationDateStart {
return false, nil
}
// Skip any payments that were created after the
// specified time.
if query.CreationDateEnd != 0 &&
createTime > query.CreationDateEnd {
return false, nil
}
// At this point, the payment has passed all filters, so
// we'll add it to our response set.
resp.Payments = append(resp.Payments, payment)
return true, nil
}
// Create a paginator which reads from our sequence index bucket
// with the parameters provided by the payments query.
paginator := newPaginator(
indexes.ReadCursor(), query.Reversed, query.IndexOffset,
query.MaxPayments,
)
// Run a paginated query, adding payments to our response.
if err := paginator.query(accumulatePayments); err != nil {
return err
}
// Counting the total number of payments is expensive, since we
// literally have to traverse the cursor linearly, which can
// take quite a while. So it's an optional query parameter.
if query.CountTotal {
var (
totalPayments uint64
err error
)
countFn := func(_, _ []byte) error {
totalPayments++
return nil
}
// In non-boltdb database backends, there's a faster
// ForAll query that allows for batch fetching items.
if fastBucket, ok := indexes.(kvdb.ExtendedRBucket); ok {
err = fastBucket.ForAll(countFn)
} else {
err = indexes.ForEach(countFn)
}
if err != nil {
return fmt.Errorf("error counting payments: %w",
err)
}
resp.TotalCount = totalPayments
}
return nil
}, func() {
resp = PaymentsResponse{}
}); err != nil {
return resp, err
}
// We need to reverse the payments slice if the query was made in
// reversed order.
if query.Reversed {
for l, r := 0, len(resp.Payments)-1; l < r; l, r = l+1, r-1 {
resp.Payments[l], resp.Payments[r] =
resp.Payments[r], resp.Payments[l]
}
}
// Set the first and last index of the returned payments so that the
// caller can resume from this point later on.
if len(resp.Payments) > 0 {
resp.FirstIndexOffset = resp.Payments[0].SequenceNum
resp.LastIndexOffset =
resp.Payments[len(resp.Payments)-1].SequenceNum
}
return resp, nil
}
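// The pagination contract above is easiest to see in a short sketch. The
// helper below is purely illustrative and not part of the original file:
// it pages backwards through the payment database in fixed-size chunks,
// feeding each response's FirstIndexOffset back in as the next
// IndexOffset. A zero IndexOffset with Reversed set starts at the most
// recent payment.
func examplePaginateBackwards(db *DB, pageSize uint64) ([]*MPPayment, error) {
var all []*MPPayment
// A zero offset begins at the newest payment when paginating backwards.
offset := uint64(0)
for {
resp, err := db.QueryPayments(PaymentsQuery{
IndexOffset: offset,
MaxPayments: pageSize,
Reversed: true,
IncludeIncomplete: true,
})
if err != nil {
return nil, err
}
// An empty page means we've walked past the oldest payment.
if len(resp.Payments) == 0 {
return all, nil
}
all = append(all, resp.Payments...)
// Resume strictly below the oldest index we've seen so far.
offset = resp.FirstIndexOffset
}
}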
// fetchPaymentWithSequenceNumber gets the payment which matches the payment hash
// *and* sequence number provided from the database. This is required because
// we previously had more than one payment per hash, so we have multiple indexes
// pointing to a single payment; we want to retrieve the correct one.
func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash,
sequenceNumber []byte) (*MPPayment, error) {
// We can now lookup the payment keyed by its hash in
// the payments root bucket.
bucket, err := fetchPaymentBucket(tx, paymentHash)
if err != nil {
return nil, err
}
// A single payment hash can have multiple payments associated with it.
// We lookup our sequence number first, to determine whether this is
// the payment we are actually looking for.
seqBytes := bucket.Get(paymentSequenceKey)
if seqBytes == nil {
return nil, ErrNoSequenceNumber
}
// If this top level payment has the sequence number we are looking for,
// return it.
if bytes.Equal(seqBytes, sequenceNumber) {
return fetchPayment(bucket)
}
// If we were not looking for the top level payment, we are looking for
// one of our duplicate payments. We need to iterate through the seq
// numbers in this bucket to find the correct payments. If we do not
// find a duplicate payments bucket here, something is wrong.
dup := bucket.NestedReadBucket(duplicatePaymentsBucket)
if dup == nil {
return nil, ErrNoDuplicateBucket
}
var duplicatePayment *MPPayment
err = dup.ForEach(func(k, v []byte) error {
subBucket := dup.NestedReadBucket(k)
if subBucket == nil {
// We expect one bucket for each duplicate to be found.
return ErrNoDuplicateNestedBucket
}
seqBytes := subBucket.Get(duplicatePaymentSequenceKey)
if seqBytes == nil {
return ErrNoSequenceNumber
}
// If this duplicate payment is not the sequence number we are
// looking for, we can continue.
if !bytes.Equal(seqBytes, sequenceNumber) {
return nil
}
duplicatePayment, err = fetchDuplicatePayment(subBucket)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// If none of the duplicate payments matched our sequence number, we
// failed to find the payment with this sequence number; something is
// wrong.
if duplicatePayment == nil {
return nil, ErrDuplicateNotFound
}
return duplicatePayment, nil
}
// DeletePayment deletes a payment from the DB given its payment hash. If
// failedHtlcsOnly is set, only failed HTLC attempts of the payment will be
// deleted.
func (d *DB) DeletePayment(paymentHash lntypes.Hash,
failedHtlcsOnly bool) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error {
payments := tx.ReadWriteBucket(paymentsRootBucket)
if payments == nil {
return nil
}
bucket := payments.NestedReadWriteBucket(paymentHash[:])
if bucket == nil {
return fmt.Errorf("non bucket element in payments " +
"bucket")
}
// Fetch the payment status so we can determine whether the
// payment can be safely deleted.
paymentStatus, err := fetchPaymentStatus(bucket)
if err != nil {
return err
}
// If the payment has inflight HTLCs, we cannot safely delete
// the payment information, so we return an error.
if err := paymentStatus.removable(); err != nil {
return fmt.Errorf("payment '%v' has inflight HTLCs"+
"and therefore cannot be deleted: %w",
paymentHash.String(), err)
}
// Delete the failed HTLC attempts we found.
if failedHtlcsOnly {
toDelete, err := fetchFailedHtlcKeys(bucket)
if err != nil {
return err
}
htlcsBucket := bucket.NestedReadWriteBucket(
paymentHtlcsBucket,
)
for _, htlcID := range toDelete {
err = htlcsBucket.Delete(
htlcBucketKey(htlcAttemptInfoKey, htlcID),
)
if err != nil {
return err
}
err = htlcsBucket.Delete(
htlcBucketKey(htlcFailInfoKey, htlcID),
)
if err != nil {
return err
}
err = htlcsBucket.Delete(
htlcBucketKey(htlcSettleInfoKey, htlcID),
)
if err != nil {
return err
}
}
return nil
}
seqNrs, err := fetchSequenceNumbers(bucket)
if err != nil {
return err
}
if err := payments.DeleteNestedBucket(paymentHash[:]); err != nil {
return err
}
indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
for _, k := range seqNrs {
if err := indexBucket.Delete(k); err != nil {
return err
}
}
return nil
}, func() {})
}
// DeletePayments deletes all completed and failed payments from the DB. If
// failedOnly is set, only failed payments will be considered for deletion. If
// failedHtlcsOnly is set, the payment itself won't be deleted, only failed HTLC
// attempts. The method returns the number of deleted payments, which is always
// 0 if failedHtlcsOnly is set.
func (d *DB) DeletePayments(failedOnly, failedHtlcsOnly bool) (int, error) {
var numPayments int
err := kvdb.Update(d, func(tx kvdb.RwTx) error {
payments := tx.ReadWriteBucket(paymentsRootBucket)
if payments == nil {
return nil
}
var (
// deleteBuckets is the set of payment buckets we need
// to delete.
deleteBuckets [][]byte
// deleteIndexes is the set of indexes pointing to these
// payments that need to be deleted.
deleteIndexes [][]byte
// deleteHtlcs maps a payment hash to the HTLC IDs we
// want to delete for that payment.
deleteHtlcs = make(map[lntypes.Hash][][]byte)
)
err := payments.ForEach(func(k, _ []byte) error {
bucket := payments.NestedReadBucket(k)
if bucket == nil {
// We only expect sub-buckets to be found in
// this top-level bucket.
return fmt.Errorf("non bucket element in " +
"payments bucket")
}
// Fetch the payment status so we can determine whether this
// payment can be safely deleted.
paymentStatus, err := fetchPaymentStatus(bucket)
if err != nil {
return err
}
// If the payment has inflight HTLCs, we cannot safely
// delete the payment information, so we return nil to skip
// it.
if err := paymentStatus.removable(); err != nil {
return nil
}
// If we requested to only delete failed payments, we
// can return if this one is not.
if failedOnly && paymentStatus != StatusFailed {
return nil
}
// If we are only deleting failed HTLCs, fetch them.
if failedHtlcsOnly {
toDelete, err := fetchFailedHtlcKeys(bucket)
if err != nil {
return err
}
hash, err := lntypes.MakeHash(k)
if err != nil {
return err
}
deleteHtlcs[hash] = toDelete
// We return early, since we are only deleting attempts.
return nil
}
// Add the bucket to the set of buckets we can delete.
deleteBuckets = append(deleteBuckets, k)
// Get all the sequence numbers associated with the
// payment, including duplicates.
seqNrs, err := fetchSequenceNumbers(bucket)
if err != nil {
return err
}
deleteIndexes = append(deleteIndexes, seqNrs...)
numPayments++
return nil
})
if err != nil {
return err
}
// Delete the failed HTLC attempts we found.
for hash, htlcIDs := range deleteHtlcs {
bucket := payments.NestedReadWriteBucket(hash[:])
htlcsBucket := bucket.NestedReadWriteBucket(
paymentHtlcsBucket,
)
for _, aid := range htlcIDs {
if err := htlcsBucket.Delete(
htlcBucketKey(htlcAttemptInfoKey, aid),
); err != nil {
return err
}
if err := htlcsBucket.Delete(
htlcBucketKey(htlcFailInfoKey, aid),
); err != nil {
return err
}
if err := htlcsBucket.Delete(
htlcBucketKey(htlcSettleInfoKey, aid),
); err != nil {
return err
}
}
}
for _, k := range deleteBuckets {
if err := payments.DeleteNestedBucket(k); err != nil {
return err
}
}
// Get our index bucket and delete all indexes pointing to the
// payments we are deleting.
indexBucket := tx.ReadWriteBucket(paymentsIndexBucket)
for _, k := range deleteIndexes {
if err := indexBucket.Delete(k); err != nil {
return err
}
}
return nil
}, func() {
numPayments = 0
})
if err != nil {
return 0, err
}
return numPayments, nil
}
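// Usage sketch (illustrative only, not part of the original file): first
// prune the failed HTLC attempts from every payment while keeping the
// payment records, then delete all failed payments outright. The returned
// count only reflects deleted payments, so it is always zero for the
// attempts-only pass.
func examplePrunePayments(db *DB) (int, error) {
if _, err := db.DeletePayments(false, true); err != nil {
return 0, err
}
return db.DeletePayments(true, false)
}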
// fetchSequenceNumbers fetches all the sequence numbers associated with a
// payment, including those belonging to any duplicate payments.
func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) {
seqNum := paymentBucket.Get(paymentSequenceKey)
if seqNum == nil {
return nil, errors.New("expected sequence number")
}
sequenceNumbers := [][]byte{seqNum}
// Get the duplicate payments bucket; if there are no duplicates, just
// return early with the payment's sequence number.
duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket)
if duplicates == nil {
return sequenceNumbers, nil
}
// If we do have duplicates, they are keyed by sequence number, so we
// iterate through the duplicates bucket and add them to our set of
// sequence numbers.
if err := duplicates.ForEach(func(k, v []byte) error {
sequenceNumbers = append(sequenceNumbers, k)
return nil
}); err != nil {
return nil, err
}
return sequenceNumbers, nil
}
// nolint: dupl
func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
var scratch [8]byte
if _, err := w.Write(c.PaymentIdentifier[:]); err != nil {
return err
}
byteOrder.PutUint64(scratch[:], uint64(c.Value))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
if err := serializeTime(w, c.CreationTime); err != nil {
return err
}
byteOrder.PutUint32(scratch[:4], uint32(len(c.PaymentRequest)))
if _, err := w.Write(scratch[:4]); err != nil {
return err
}
if _, err := w.Write(c.PaymentRequest[:]); err != nil {
return err
}
// Any remaining bytes are TLV encoded records. Currently, these are
// only the custom records provided by the user to be sent to the first
// hop. But this can easily be extended with further records by merging
// the records into a single TLV stream.
err := c.FirstHopCustomRecords.SerializeTo(w)
if err != nil {
return err
}
return nil
}
func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo,
error) {
var scratch [8]byte
c := &PaymentCreationInfo{}
if _, err := io.ReadFull(r, c.PaymentIdentifier[:]); err != nil {
return nil, err
}
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return nil, err
}
c.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
creationTime, err := deserializeTime(r)
if err != nil {
return nil, err
}
c.CreationTime = creationTime
if _, err := io.ReadFull(r, scratch[:4]); err != nil {
return nil, err
}
reqLen := byteOrder.Uint32(scratch[:4])
payReq := make([]byte, reqLen)
if reqLen > 0 {
if _, err := io.ReadFull(r, payReq); err != nil {
return nil, err
}
}
c.PaymentRequest = payReq
// Any remaining bytes are TLV encoded records. Currently, these are
// only the custom records provided by the user to be sent to the first
// hop. But this can easily be extended with further records by merging
// the records into a single TLV stream.
c.FirstHopCustomRecords, err = lnwire.ParseCustomRecordsFrom(r)
if err != nil {
return nil, err
}
return c, nil
}
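// A quick round-trip sketch (illustrative only): the fixed fields are
// written in order as the 32-byte payment identifier, the 8-byte value,
// the creation time, and the length-prefixed payment request, with the
// TLV-encoded custom records occupying the remainder of the stream.
func examplePaymentCreationInfoRoundTrip(c *PaymentCreationInfo) error {
var b bytes.Buffer
if err := serializePaymentCreationInfo(&b, c); err != nil {
return err
}
got, err := deserializePaymentCreationInfo(&b)
if err != nil {
return err
}
if got.Value != c.Value {
return fmt.Errorf("round trip mismatch: got %v, want %v",
got.Value, c.Value)
}
return nil
}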
func serializeHTLCAttemptInfo(w io.Writer, a *HTLCAttemptInfo) error {
if err := WriteElements(w, a.sessionKey); err != nil {
return err
}
if err := SerializeRoute(w, a.Route); err != nil {
return err
}
if err := serializeTime(w, a.AttemptTime); err != nil {
return err
}
// If the hash is nil we can just return.
if a.Hash == nil {
return nil
}
if _, err := w.Write(a.Hash[:]); err != nil {
return err
}
// Merge the fixed/known records together with the custom records to
// serialize them as a single blob. We can't do this in SerializeRoute
// because we're in the middle of the byte stream there. We can only do
// TLV serialization at the end of the stream, since EOF is allowed for
// a stream if no more data is expected.
producers := []tlv.RecordProducer{
&a.Route.FirstHopAmount,
}
tlvData, err := lnwire.MergeAndEncode(
producers, nil, a.Route.FirstHopWireCustomRecords,
)
if err != nil {
return err
}
if _, err := w.Write(tlvData); err != nil {
return err
}
return nil
}
func deserializeHTLCAttemptInfo(r io.Reader) (*HTLCAttemptInfo, error) {
a := &HTLCAttemptInfo{}
err := ReadElements(r, &a.sessionKey)
if err != nil {
return nil, err
}
a.Route, err = DeserializeRoute(r)
if err != nil {
return nil, err
}
a.AttemptTime, err = deserializeTime(r)
if err != nil {
return nil, err
}
hash := lntypes.Hash{}
_, err = io.ReadFull(r, hash[:])
switch {
// Older payment attempts wouldn't have the hash set, in which case we
// can just return.
case err == io.EOF, err == io.ErrUnexpectedEOF:
return a, nil
case err != nil:
return nil, err
default:
}
a.Hash = &hash
// Read any remaining data (if any) and parse it into the known records
// and custom records.
extraData, err := io.ReadAll(r)
if err != nil {
return nil, err
}
customRecords, _, _, err := lnwire.ParseAndExtractCustomRecords(
extraData, &a.Route.FirstHopAmount,
)
if err != nil {
return nil, err
}
a.Route.FirstHopWireCustomRecords = customRecords
return a, nil
}
func serializeHop(w io.Writer, h *route.Hop) error {
if err := WriteElements(w,
h.PubKeyBytes[:],
h.ChannelID,
h.OutgoingTimeLock,
h.AmtToForward,
); err != nil {
return err
}
if err := binary.Write(w, byteOrder, h.LegacyPayload); err != nil {
return err
}
// For legacy payloads, we don't need to write any TLV records, so
// we'll write a zero indicating that our serialized TLV map has no
// records.
if h.LegacyPayload {
return WriteElements(w, uint32(0))
}
// Gather all non-primitive TLV records so that they can be serialized
// as a single blob.
//
// TODO(conner): add migration to unify all fields in a single TLV
// blobs. The split approach will cause headaches down the road as more
// fields are added, which we can avoid by having a single TLV stream
// for all payload fields.
var records []tlv.Record
if h.MPP != nil {
records = append(records, h.MPP.Record())
}
// Add blinding point and encrypted data if present.
if h.EncryptedData != nil {
records = append(records, record.NewEncryptedDataRecord(
&h.EncryptedData,
))
}
if h.BlindingPoint != nil {
records = append(records, record.NewBlindingPointRecord(
&h.BlindingPoint,
))
}
if h.AMP != nil {
records = append(records, h.AMP.Record())
}
if h.Metadata != nil {
records = append(records, record.NewMetadataRecord(&h.Metadata))
}
if h.TotalAmtMsat != 0 {
totalMsatInt := uint64(h.TotalAmtMsat)
records = append(
records, record.NewTotalAmtMsatBlinded(&totalMsatInt),
)
}
// Final sanity check to absolutely rule out custom records that are not
// actually custom and would write into the standard range.
if err := h.CustomRecords.Validate(); err != nil {
return err
}
// Convert custom records to tlv and add to the record list.
// MapToRecords sorts the list, so adding it here will keep the list
// canonical.
tlvRecords := tlv.MapToRecords(h.CustomRecords)
records = append(records, tlvRecords...)
// Finally, we'll transform our slice of records into a map from record
// type to raw bytes, then serialize them in-line with a length (number
// of elements) prefix.
mapRecords, err := tlv.RecordsToMap(records)
if err != nil {
return err
}
numRecords := uint32(len(mapRecords))
if err := WriteElements(w, numRecords); err != nil {
return err
}
for recordType, rawBytes := range mapRecords {
if err := WriteElements(w, recordType); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, rawBytes); err != nil {
return err
}
}
return nil
}
// maxOnionPayloadSize is the largest Sphinx payload possible, so we don't need
// to read/write a TLV stream larger than this.
const maxOnionPayloadSize = 1300
func deserializeHop(r io.Reader) (*route.Hop, error) {
h := &route.Hop{}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return nil, err
}
copy(h.PubKeyBytes[:], pub)
if err := ReadElements(r,
&h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward,
); err != nil {
return nil, err
}
// TODO(roasbeef): change field to allow LegacyPayload false to be the
// legacy default?
err := binary.Read(r, byteOrder, &h.LegacyPayload)
if err != nil {
return nil, err
}
var numElements uint32
if err := ReadElements(r, &numElements); err != nil {
return nil, err
}
// If there are no elements, then we can return early.
if numElements == 0 {
return h, nil
}
tlvMap := make(map[uint64][]byte)
for i := uint32(0); i < numElements; i++ {
var tlvType uint64
if err := ReadElements(r, &tlvType); err != nil {
return nil, err
}
rawRecordBytes, err := wire.ReadVarBytes(
r, 0, maxOnionPayloadSize, "tlv",
)
if err != nil {
return nil, err
}
tlvMap[tlvType] = rawRecordBytes
}
// If the MPP type is present, remove it from the generic TLV map and
// parse it back into a proper MPP struct.
//
// TODO(conner): add migration to unify all fields in a single TLV
// blobs. The split approach will cause headaches down the road as more
// fields are added, which we can avoid by having a single TLV stream
// for all payload fields.
mppType := uint64(record.MPPOnionType)
if mppBytes, ok := tlvMap[mppType]; ok {
delete(tlvMap, mppType)
var (
mpp = &record.MPP{}
mppRec = mpp.Record()
r = bytes.NewReader(mppBytes)
)
err := mppRec.Decode(r, uint64(len(mppBytes)))
if err != nil {
return nil, err
}
h.MPP = mpp
}
// If encrypted data or a blinding point are present, remove them from
// the TLV map and parse them into proper types.
encryptedDataType := uint64(record.EncryptedDataOnionType)
if data, ok := tlvMap[encryptedDataType]; ok {
delete(tlvMap, encryptedDataType)
h.EncryptedData = data
}
blindingType := uint64(record.BlindingPointOnionType)
if blindingPoint, ok := tlvMap[blindingType]; ok {
delete(tlvMap, blindingType)
h.BlindingPoint, err = btcec.ParsePubKey(blindingPoint)
if err != nil {
return nil, fmt.Errorf("invalid blinding point: %w",
err)
}
}
ampType := uint64(record.AMPOnionType)
if ampBytes, ok := tlvMap[ampType]; ok {
delete(tlvMap, ampType)
var (
amp = &record.AMP{}
ampRec = amp.Record()
r = bytes.NewReader(ampBytes)
)
err := ampRec.Decode(r, uint64(len(ampBytes)))
if err != nil {
return nil, err
}
h.AMP = amp
}
// If the metadata type is present, remove it from the tlv map and
// populate directly on the hop.
metadataType := uint64(record.MetadataOnionType)
if metadata, ok := tlvMap[metadataType]; ok {
delete(tlvMap, metadataType)
h.Metadata = metadata
}
totalAmtMsatType := uint64(record.TotalAmtMsatBlindedType)
if totalAmtMsat, ok := tlvMap[totalAmtMsatType]; ok {
delete(tlvMap, totalAmtMsatType)
var (
totalAmtMsatInt uint64
buf [8]byte
)
if err := tlv.DTUint64(
bytes.NewReader(totalAmtMsat),
&totalAmtMsatInt,
&buf,
uint64(len(totalAmtMsat)),
); err != nil {
return nil, err
}
h.TotalAmtMsat = lnwire.MilliSatoshi(totalAmtMsatInt)
}
h.CustomRecords = tlvMap
return h, nil
}
// SerializeRoute serializes a route.
func SerializeRoute(w io.Writer, r route.Route) error {
if err := WriteElements(w,
r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:],
); err != nil {
return err
}
if err := WriteElements(w, uint32(len(r.Hops))); err != nil {
return err
}
for _, h := range r.Hops {
if err := serializeHop(w, h); err != nil {
return err
}
}
// Any new/extra TLV data is encoded in serializeHTLCAttemptInfo!
return nil
}
// DeserializeRoute deserializes a route.
func DeserializeRoute(r io.Reader) (route.Route, error) {
rt := route.Route{}
if err := ReadElements(r,
&rt.TotalTimeLock, &rt.TotalAmount,
); err != nil {
return rt, err
}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return rt, err
}
copy(rt.SourcePubKey[:], pub)
var numHops uint32
if err := ReadElements(r, &numHops); err != nil {
return rt, err
}
var hops []*route.Hop
for i := uint32(0); i < numHops; i++ {
hop, err := deserializeHop(r)
if err != nil {
return rt, err
}
hops = append(hops, hop)
}
rt.Hops = hops
// Any new/extra TLV data is decoded in deserializeHTLCAttemptInfo!
return rt, nil
}
package channeldb
import (
"bytes"
"errors"
"fmt"
"time"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/routing/route"
)
var (
// peersBucket is the name of a top level bucket in which we store
// information about our peers. Information for different peers is
// stored in buckets keyed by their public key.
//
// peers-bucket
// |
// |-- <peer-pubkey>
// | |--flap-count-key: <ts><flap count>
// |
// |-- <peer-pubkey>
// | |--flap-count-key: <ts><flap count>
peersBucket = []byte("peers-bucket")
// flapCountKey is a key used in the peer pubkey sub-bucket that stores
// the timestamp of a peer's last flap count and its all time flap
// count.
flapCountKey = []byte("flap-count")
)
var (
// ErrNoPeerBucket is returned when we try to read entries for a peer
// that is not tracked.
ErrNoPeerBucket = errors.New("peer bucket not found")
)
// FlapCount contains information about a peer's flap count.
type FlapCount struct {
// Count provides the total flap count for a peer.
Count uint32
// LastFlap is the timestamp of the last flap recorded for a peer.
LastFlap time.Time
}
// WriteFlapCounts writes the flap count for a set of peers to disk, creating a
// bucket for the peer's pubkey if necessary. Note that this function overwrites
// the current value.
func (d *DB) WriteFlapCounts(flapCounts map[route.Vertex]*FlapCount) error {
// Exit early if there are no updates.
if len(flapCounts) == 0 {
log.Debugf("No flap counts to write, skipped db update")
return nil
}
return kvdb.Update(d, func(tx kvdb.RwTx) error {
// Run through our set of flap counts and record them for
// each peer, creating a bucket for the peer pubkey if required.
for peer, flapCount := range flapCounts {
peers := tx.ReadWriteBucket(peersBucket)
peerBucket, err := peers.CreateBucketIfNotExists(
peer[:],
)
if err != nil {
return err
}
var b bytes.Buffer
err = serializeTime(&b, flapCount.LastFlap)
if err != nil {
return err
}
if err = WriteElement(&b, flapCount.Count); err != nil {
return err
}
err = peerBucket.Put(flapCountKey, b.Bytes())
if err != nil {
return err
}
}
return nil
}, func() {})
}
// ReadFlapCount attempts to read the flap count for a peer, failing if the
// peer is not found or we do not have flap count stored.
func (d *DB) ReadFlapCount(pubkey route.Vertex) (*FlapCount, error) {
var flapCount FlapCount
if err := kvdb.View(d, func(tx kvdb.RTx) error {
peers := tx.ReadBucket(peersBucket)
peerBucket := peers.NestedReadBucket(pubkey[:])
if peerBucket == nil {
return ErrNoPeerBucket
}
flapBytes := peerBucket.Get(flapCountKey)
if flapBytes == nil {
return fmt.Errorf("flap count not recorded for: %v",
pubkey)
}
var (
err error
r = bytes.NewReader(flapBytes)
)
flapCount.LastFlap, err = deserializeTime(r)
if err != nil {
return err
}
return ReadElements(r, &flapCount.Count)
}, func() {
flapCount = FlapCount{}
}); err != nil {
return nil, err
}
return &flapCount, nil
}
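// Illustrative usage sketch (not part of the original file): record a
// single flap for a peer and read the stored count back. Note that
// WriteFlapCounts overwrites any previous value for the peer.
func exampleFlapCount(db *DB, peer route.Vertex) (*FlapCount, error) {
err := db.WriteFlapCounts(map[route.Vertex]*FlapCount{
peer: {
Count: 1,
LastFlap: time.Now(),
},
})
if err != nil {
return nil, err
}
return db.ReadFlapCount(peer)
}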
package channeldb
import (
"bytes"
"errors"
"io"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// closeSummaryBucket is a top level bucket which holds additional
// information about channel closes. It nests channels by chainhash
// and channel point.
// [closeSummaryBucket]
// [chainHashBucket]
// [channelBucket]
// [resolversBucket]
closeSummaryBucket = []byte("close-summaries")
// resolversBucket holds the outcome of a channel's resolvers. It is
// nested under a channel and chainhash bucket in the close summaries
// bucket.
resolversBucket = []byte("resolvers-bucket")
)
var (
// ErrNoChainHashBucket is returned when we have not created a bucket
// for the current chain hash.
ErrNoChainHashBucket = errors.New("no chain hash bucket")
// ErrNoChannelSummaries is returned when a channel is not found in the
// chain hash bucket.
ErrNoChannelSummaries = errors.New("channel bucket not found")
amountType tlv.Type = 1
resolverType tlv.Type = 2
outcomeType tlv.Type = 3
spendTxIDType tlv.Type = 4
)
// ResolverType indicates the type of resolver that was resolved on chain.
type ResolverType uint8
const (
// ResolverTypeAnchor represents a resolver for an anchor output.
ResolverTypeAnchor ResolverType = 0
// ResolverTypeIncomingHtlc represents resolution of an incoming htlc.
ResolverTypeIncomingHtlc ResolverType = 1
// ResolverTypeOutgoingHtlc represents resolution of an outgoing htlc.
ResolverTypeOutgoingHtlc ResolverType = 2
// ResolverTypeCommit represents resolution of our time locked commit
// when we force close.
ResolverTypeCommit ResolverType = 3
)
// ResolverOutcome indicates the outcome for the resolver that the contract
// court reached. This state is not necessarily final, since htlcs on our own
// commitment are resolved across two resolvers.
type ResolverOutcome uint8
const (
// ResolverOutcomeClaimed indicates that funds were claimed on chain.
ResolverOutcomeClaimed ResolverOutcome = 0
// ResolverOutcomeUnclaimed indicates that we did not claim our funds on
// chain. This may be the case for anchors that we did not sweep, or
// outputs that were not economical to sweep.
ResolverOutcomeUnclaimed ResolverOutcome = 1
// ResolverOutcomeAbandoned indicates that we did not attempt to claim
// an output on chain. This is the case for htlcs that we could not
// decode to claim, or invoices which we fail when an attempt is made
// to settle them on chain.
ResolverOutcomeAbandoned ResolverOutcome = 2
// ResolverOutcomeTimeout indicates that a contract was timed out on
// chain.
ResolverOutcomeTimeout ResolverOutcome = 3
// ResolverOutcomeFirstStage indicates that a htlc had to be claimed
// over two stages, with this outcome representing the confirmation
// of our success/timeout tx.
ResolverOutcomeFirstStage ResolverOutcome = 4
)
// ResolverReport provides an account of the outcome of a resolver. This differs
// from a ContractReport because it does not necessarily fully resolve the
// contract; each step of two stage htlc resolution is included.
type ResolverReport struct {
// OutPoint is the on chain outpoint that was spent as a result of this
// resolution. When an output is directly resolved (eg, commitment
// sweeps and single stage htlcs on the remote party's output) this
// is an output on the commitment tx that was broadcast. When we resolve
// across two stages (eg, htlcs on our own force close commit), the
// first stage outpoint is the output on our commitment and the second
// stage output is the spend from our htlc success/timeout tx.
OutPoint wire.OutPoint
// Amount is the value of the output referenced above.
Amount btcutil.Amount
// ResolverType indicates the type of resolution that occurred.
ResolverType
// ResolverOutcome indicates the outcome of the resolver.
ResolverOutcome
// SpendTxID is the transaction ID of the spending transaction that
// claimed the outpoint. This may be a sweep transaction, or a first
// stage success/timeout transaction.
SpendTxID *chainhash.Hash
}
// PutResolverReport creates and commits a transaction that is used to write a
// resolver report to disk.
func (d *DB) PutResolverReport(tx kvdb.RwTx, chainHash chainhash.Hash,
channelOutpoint *wire.OutPoint, report *ResolverReport) error {
putReportFunc := func(tx kvdb.RwTx) error {
return putReport(tx, chainHash, channelOutpoint, report)
}
// If the transaction is nil, we'll create a new one.
if tx == nil {
return kvdb.Update(d, putReportFunc, func() {})
}
// Otherwise, we can write the report to disk using the existing
// transaction.
return putReportFunc(tx)
}
// putReport puts a report in the bucket provided, with its outpoint as its key.
func putReport(tx kvdb.RwTx, chainHash chainhash.Hash,
channelOutpoint *wire.OutPoint, report *ResolverReport) error {
channelBucket, err := fetchReportWriteBucket(
tx, chainHash, channelOutpoint,
)
if err != nil {
return err
}
// If the resolvers bucket does not exist yet, create it.
resolvers, err := channelBucket.CreateBucketIfNotExists(
resolversBucket,
)
if err != nil {
return err
}
var valueBuf bytes.Buffer
if err := serializeReport(&valueBuf, report); err != nil {
return err
}
// Finally write our outpoint to be used as the key for this record.
var keyBuf bytes.Buffer
if err := graphdb.WriteOutpoint(&keyBuf, &report.OutPoint); err != nil {
return err
}
return resolvers.Put(keyBuf.Bytes(), valueBuf.Bytes())
}
// serializeReport serializes a report using a TLV stream to allow for optional
// fields.
func serializeReport(w io.Writer, report *ResolverReport) error {
amt := uint64(report.Amount)
resolver := uint8(report.ResolverType)
outcome := uint8(report.ResolverOutcome)
// Create a set of TLV records for the values we know to be present.
records := []tlv.Record{
tlv.MakePrimitiveRecord(amountType, &amt),
tlv.MakePrimitiveRecord(resolverType, &resolver),
tlv.MakePrimitiveRecord(outcomeType, &outcome),
}
// If our spend txid is non-nil, we add a tlv entry for it.
if report.SpendTxID != nil {
var spendBuf bytes.Buffer
err := WriteElement(&spendBuf, *report.SpendTxID)
if err != nil {
return err
}
spendBytes := spendBuf.Bytes()
records = append(records, tlv.MakePrimitiveRecord(
spendTxIDType, &spendBytes,
))
}
// Create our stream and encode it.
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
// FetchChannelReports fetches the set of reports for a channel.
func (d *DB) FetchChannelReports(chainHash chainhash.Hash,
outPoint *wire.OutPoint) ([]*ResolverReport, error) {
var reports []*ResolverReport
if err := kvdb.View(d.Backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchReportReadBucket(
tx, chainHash, outPoint,
)
if err != nil {
return err
}
// If there are no resolvers for this channel, we simply
// return nil, because nothing has been persisted yet.
resolvers := chanBucket.NestedReadBucket(resolversBucket)
if resolvers == nil {
return nil
}
// Run through each resolution and add it to our set of
// resolutions.
return resolvers.ForEach(func(k, v []byte) error {
// Deserialize the contents of our field.
r := bytes.NewReader(v)
report, err := deserializeReport(r)
if err != nil {
return err
}
// Once we have read our values out, set the outpoint
// on the report using the key.
r = bytes.NewReader(k)
if err := ReadElement(r, &report.OutPoint); err != nil {
return err
}
reports = append(reports, report)
return nil
})
}, func() {
reports = nil
}); err != nil {
return nil, err
}
return reports, nil
}
// deserializeReport gets a resolver report from a tlv stream. The outpoint on
// the resolver will not be set because we key reports by their outpoint, and
// this function reads only the values saved in the stream.
func deserializeReport(r io.Reader) (*ResolverReport, error) {
var (
resolver, outcome uint8
amt uint64
spentTx []byte
)
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(amountType, &amt),
tlv.MakePrimitiveRecord(resolverType, &resolver),
tlv.MakePrimitiveRecord(outcomeType, &outcome),
tlv.MakePrimitiveRecord(spendTxIDType, &spentTx),
)
if err != nil {
return nil, err
}
if err := tlvStream.Decode(r); err != nil {
return nil, err
}
report := &ResolverReport{
Amount: btcutil.Amount(amt),
ResolverOutcome: ResolverOutcome(outcome),
ResolverType: ResolverType(resolver),
}
// If our spend tx is set, we set it on our report.
if len(spentTx) != 0 {
spendTx, err := chainhash.NewHash(spentTx)
if err != nil {
return nil, err
}
report.SpendTxID = spendTx
}
return report, nil
}
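// Round-trip sketch (illustrative only): encode a report into its TLV
// representation and decode it back. Note that OutPoint never enters the
// value stream; it is stored as the bucket key, so the decoder leaves it
// unset and FetchChannelReports restores it from the key.
func exampleReportRoundTrip(report *ResolverReport) (*ResolverReport, error) {
var b bytes.Buffer
if err := serializeReport(&b, report); err != nil {
return nil, err
}
return deserializeReport(&b)
}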
// fetchReportWriteBucket returns a write channel bucket within the reports
// top level bucket. If the channel's bucket does not yet exist, it will be
// created.
func fetchReportWriteBucket(tx kvdb.RwTx, chainHash chainhash.Hash,
outPoint *wire.OutPoint) (kvdb.RwBucket, error) {
// Get the channel close summary bucket.
closedBucket := tx.ReadWriteBucket(closeSummaryBucket)
// Create the chain hash bucket if it does not exist.
chainHashBkt, err := closedBucket.CreateBucketIfNotExists(chainHash[:])
if err != nil {
return nil, err
}
var chanPointBuf bytes.Buffer
if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil {
return nil, err
}
return chainHashBkt.CreateBucketIfNotExists(chanPointBuf.Bytes())
}
// fetchReportReadBucket returns a read channel bucket within the reports
// top level bucket. If any bucket along the way does not exist, it will error.
func fetchReportReadBucket(tx kvdb.RTx, chainHash chainhash.Hash,
outPoint *wire.OutPoint) (kvdb.RBucket, error) {
// First fetch the top level channel close summary bucket.
closeBucket := tx.ReadBucket(closeSummaryBucket)
// Next we get the chain hash bucket for our current chain.
chainHashBucket := closeBucket.NestedReadBucket(chainHash[:])
if chainHashBucket == nil {
return nil, ErrNoChainHashBucket
}
// With the bucket for the node and chain fetched, we can now go down
// another level, for the channel itself.
var chanPointBuf bytes.Buffer
if err := graphdb.WriteOutpoint(&chanPointBuf, outPoint); err != nil {
return nil, err
}
chanBucket := chainHashBucket.NestedReadBucket(chanPointBuf.Bytes())
if chanBucket == nil {
return nil, ErrNoChannelSummaries
}
return chanBucket, nil
}
package channeldb
import (
"bytes"
"errors"
"io"
"math"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// OutputIndexEmpty is used when the output index doesn't exist.
OutputIndexEmpty = math.MaxUint16
)
type (
// BigSizeAmount is a type alias for a TLV record of a btcutil.Amount.
BigSizeAmount = tlv.BigSizeT[btcutil.Amount]
// BigSizeMilliSatoshi is a type alias for a TLV record of a
// lnwire.MilliSatoshi.
BigSizeMilliSatoshi = tlv.BigSizeT[lnwire.MilliSatoshi]
)
var (
// revocationLogBucketDeprecated is dedicated for storing the necessary
// delta state between channel updates required to re-construct a past
// state in order to punish a counterparty attempting a non-cooperative
// channel closure. This key should be accessed from within the
// sub-bucket of a target channel, identified by its channel point.
//
// Deprecated: This bucket is kept read-only in case the user
// chooses not to migrate the old data.
revocationLogBucketDeprecated = []byte("revocation-log-key")
// revocationLogBucket is a sub-bucket under openChannelBucket. This
// sub-bucket is dedicated for storing the minimal info required to
// re-construct a past state in order to punish a counterparty
// attempting a non-cooperative channel closure.
revocationLogBucket = []byte("revocation-log")
// ErrLogEntryNotFound is returned when we cannot find a log entry at
// the height requested in the revocation log.
ErrLogEntryNotFound = errors.New("log entry not found")
// ErrOutputIndexTooBig is returned when the output index is greater
// than uint16.
ErrOutputIndexTooBig = errors.New("output index is over uint16")
)
// SparsePayHash is a type alias for a 32 byte array, which when serialized is
// able to save some space by not including an empty payment hash on disk.
type SparsePayHash [32]byte
// NewSparsePayHash creates a new SparsePayHash from a 32 byte array.
func NewSparsePayHash(rHash [32]byte) SparsePayHash {
return SparsePayHash(rHash)
}
// Record returns a tlv record for the SparsePayHash.
func (s *SparsePayHash) Record() tlv.Record {
// We use a zero for the type here, as this'll be used along with the
// RecordT type.
return tlv.MakeDynamicRecord(
0, s, s.hashLen,
sparseHashEncoder, sparseHashDecoder,
)
}
// hashLen is used by MakeDynamicRecord to return the size of the RHash.
//
// NOTE: for zero hash, we return a length 0.
func (s *SparsePayHash) hashLen() uint64 {
if bytes.Equal(s[:], lntypes.ZeroHash[:]) {
return 0
}
return 32
}
// sparseHashEncoder is the customized encoder which skips encoding the empty
// hash.
func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
v, ok := val.(*SparsePayHash)
if !ok {
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
}
// If the value is an empty hash, we will skip encoding it.
if bytes.Equal(v[:], lntypes.ZeroHash[:]) {
return nil
}
vArray := (*[32]byte)(v)
return tlv.EBytes32(w, vArray, buf)
}
// sparseHashDecoder is the customized decoder which skips decoding the empty
// hash.
func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
v, ok := val.(*SparsePayHash)
if !ok {
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
}
// If the length is zero, we will skip decoding the empty hash.
if l == 0 {
return nil
}
vArray := (*[32]byte)(v)
return tlv.DBytes32(r, vArray, buf, 32)
}
// HTLCEntry specifies the minimal info needed to be stored on disk for ALL the
// historical HTLCs, which is useful for constructing RevocationLog when a
// breach is detected.
// The actual size of each HTLCEntry varies based on its RHash and Amt(sat),
// summarized as follows,
//
// | RHash empty | Amt<=252 | Amt<=65,535 | Amt<=4,294,967,295 | otherwise |
// |:-----------:|:--------:|:-----------:|:------------------:|:---------:|
// | true | 19 | 21 | 23 | 26 |
// | false | 51 | 53 | 55 | 58 |
//
// So the size varies from 19 bytes to 58 bytes, with 23 or 55 bytes being
// the most likely.
//
// NOTE: all the fields saved to disk use the primitive go types so they can be
// made into tlv records without further conversion.
type HTLCEntry struct {
// RHash is the payment hash of the HTLC.
RHash tlv.RecordT[tlv.TlvType0, SparsePayHash]
// RefundTimeout is the absolute timeout on the HTLC that the sender
// must wait before reclaiming the funds in limbo.
RefundTimeout tlv.RecordT[tlv.TlvType1, uint32]
// OutputIndex is the output index for this particular HTLC output
// within the commitment transaction.
//
// NOTE: we use uint16 instead of int32 here to save us 2 bytes, which
// gives us a max number of HTLCs of 65K.
OutputIndex tlv.RecordT[tlv.TlvType2, uint16]
// Incoming denotes whether we're the receiver or the sender of this
// HTLC.
Incoming tlv.RecordT[tlv.TlvType3, bool]
// Amt is the amount of satoshis this HTLC escrows.
Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]]
// CustomBlob is an optional blob that can be used to store information
// specific to revocation handling for a custom channel type.
CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob]
// HtlcIndex is the index of the HTLC in the channel.
HtlcIndex tlv.OptionalRecordT[tlv.TlvType6, uint16]
}
// toTlvStream converts an HTLCEntry record into a tlv representation.
func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) {
records := []tlv.Record{
h.RHash.Record(),
h.RefundTimeout.Record(),
h.OutputIndex.Record(),
h.Incoming.Record(),
h.Amt.Record(),
}
h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) {
records = append(records, r.Record())
})
h.HtlcIndex.WhenSome(func(r tlv.RecordT[tlv.TlvType6, uint16]) {
records = append(records, r.Record())
})
tlv.SortRecords(records)
return tlv.NewStream(records...)
}
// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC.
func NewHTLCEntryFromHTLC(htlc HTLC) (*HTLCEntry, error) {
h := &HTLCEntry{
RHash: tlv.NewRecordT[tlv.TlvType0](
NewSparsePayHash(htlc.RHash),
),
RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1](
htlc.RefundTimeout,
),
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2](
uint16(htlc.OutputIndex),
),
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming),
Amt: tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(htlc.Amt.ToSatoshis()),
),
HtlcIndex: tlv.SomeRecordT(tlv.NewPrimitiveRecord[tlv.TlvType6](
uint16(htlc.HtlcIndex),
)),
}
if len(htlc.CustomRecords) != 0 {
blob, err := htlc.CustomRecords.Serialize()
if err != nil {
return nil, err
}
h.CustomBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob),
)
}
return h, nil
}
// RevocationLog stores the info needed to construct a breach retribution. Its
// fields can be viewed as a subset of a ChannelCommitment's. In the database,
// all historical versions of the RevocationLog are saved using the
// CommitHeight as the key.
type RevocationLog struct {
// OurOutputIndex specifies our output index in this commitment. In a
// remote commitment transaction, this is the to remote output index.
OurOutputIndex tlv.RecordT[tlv.TlvType0, uint16]
// TheirOutputIndex specifies their output index in this commitment. In
// a remote commitment transaction, this is the to local output index.
TheirOutputIndex tlv.RecordT[tlv.TlvType1, uint16]
// CommitTxHash is the hash of the latest version of the commitment
// state, broadcastable by us.
CommitTxHash tlv.RecordT[tlv.TlvType2, [32]byte]
// HTLCEntries is the set of HTLCEntry's that are pending at this
// particular commitment height.
HTLCEntries []*HTLCEntry
// OurBalance is the current available balance within the channel
// directly spendable by us. In other words, it is the value of the
// to_remote output on the remote parties' commitment transaction.
//
// NOTE: this is an option so that it is clear if the value is zero or
// nil. Since migration 30 of the channeldb initially did not include
// this field, it could be the case that the field is not present for
// all revocation logs.
OurBalance tlv.OptionalRecordT[tlv.TlvType3, BigSizeMilliSatoshi]
// TheirBalance is the current available balance within the channel
// directly spendable by the remote node. In other words, it is the
// value of the to_local output on the remote parties' commitment.
//
// NOTE: this is an option so that it is clear if the value is zero or
// nil. Since migration 30 of the channeldb initially did not include
// this field, it could be the case that the field is not present for
// all revocation logs.
TheirBalance tlv.OptionalRecordT[tlv.TlvType4, BigSizeMilliSatoshi]
// CustomBlob is an optional blob that can be used to store information
// specific to a custom channel type. This information is only created
// at channel funding time, and afterwards is to be considered
// immutable.
CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob]
}
// NewRevocationLog creates a new RevocationLog from the given parameters.
func NewRevocationLog(ourOutputIndex uint16, theirOutputIndex uint16,
commitHash [32]byte, ourBalance,
theirBalance fn.Option[lnwire.MilliSatoshi], htlcs []*HTLCEntry,
customBlob fn.Option[tlv.Blob]) RevocationLog {
rl := RevocationLog{
OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0](
ourOutputIndex,
),
TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1](
theirOutputIndex,
),
CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2](commitHash),
HTLCEntries: htlcs,
}
ourBalance.WhenSome(func(balance lnwire.MilliSatoshi) {
rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3](
tlv.NewBigSizeT(balance),
))
})
theirBalance.WhenSome(func(balance lnwire.MilliSatoshi) {
rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(balance),
))
})
customBlob.WhenSome(func(blob tlv.Blob) {
rl.CustomBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob),
)
})
return rl
}
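// Construction sketch (illustrative only, with made-up output indexes and
// balances): build a log entry with both balances present, using fn.Some
// to populate the optional amount fields and fn.None to omit the custom
// blob.
func exampleNewRevocationLog(commitHash [32]byte,
htlcs []*HTLCEntry) RevocationLog {
return NewRevocationLog(
0, 1, commitHash,
fn.Some(lnwire.MilliSatoshi(500_000)),
fn.Some(lnwire.MilliSatoshi(250_000)),
htlcs, fn.None[tlv.Blob](),
)
}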
// putRevocationLog uses the fields `CommitTx` and `Htlcs` from a
// ChannelCommitment to construct a revocation log entry and saves them to
// disk. It also saves our output index and their output index, which are
// useful when creating a breach retribution.
func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment,
ourOutputIndex, theirOutputIndex uint32, noAmtData bool) error {
// Sanity check that the output indexes can be safely converted.
if ourOutputIndex > math.MaxUint16 {
return ErrOutputIndexTooBig
}
if theirOutputIndex > math.MaxUint16 {
return ErrOutputIndexTooBig
}
rl := &RevocationLog{
OurOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint16(ourOutputIndex),
),
TheirOutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType1](
uint16(theirOutputIndex),
),
CommitTxHash: tlv.NewPrimitiveRecord[tlv.TlvType2, [32]byte](
commit.CommitTx.TxHash(),
),
HTLCEntries: make([]*HTLCEntry, 0, len(commit.Htlcs)),
}
commit.CustomBlob.WhenSome(func(blob tlv.Blob) {
rl.CustomBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5, tlv.Blob](blob),
)
})
if !noAmtData {
rl.OurBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType3](
tlv.NewBigSizeT(commit.LocalBalance),
))
rl.TheirBalance = tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType4](
tlv.NewBigSizeT(commit.RemoteBalance),
))
}
for _, htlc := range commit.Htlcs {
// Skip dust HTLCs.
if htlc.OutputIndex < 0 {
continue
}
// Sanity check that the output indexes can be safely
// converted.
if htlc.OutputIndex > math.MaxUint16 {
return ErrOutputIndexTooBig
}
entry, err := NewHTLCEntryFromHTLC(htlc)
if err != nil {
return err
}
rl.HTLCEntries = append(rl.HTLCEntries, entry)
}
var b bytes.Buffer
err := serializeRevocationLog(&b, rl)
if err != nil {
return err
}
logEntrykey := makeLogKey(commit.CommitHeight)
return bucket.Put(logEntrykey[:], b.Bytes())
}
// fetchRevocationLog queries the revocation log bucket to find a log
// entry. It returns an error if the entry is not found.
func fetchRevocationLog(log kvdb.RBucket,
updateNum uint64) (RevocationLog, error) {
logEntrykey := makeLogKey(updateNum)
commitBytes := log.Get(logEntrykey[:])
if commitBytes == nil {
return RevocationLog{}, ErrLogEntryNotFound
}
commitReader := bytes.NewReader(commitBytes)
return deserializeRevocationLog(commitReader)
}
// serializeRevocationLog serializes a RevocationLog record based on tlv
// format.
func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
// Add the tlv records for all non-optional fields.
records := []tlv.Record{
rl.OurOutputIndex.Record(),
rl.TheirOutputIndex.Record(),
rl.CommitTxHash.Record(),
}
// Now we add any optional fields that are non-nil.
rl.OurBalance.WhenSome(
func(r tlv.RecordT[tlv.TlvType3, BigSizeMilliSatoshi]) {
records = append(records, r.Record())
},
)
rl.TheirBalance.WhenSome(
func(r tlv.RecordT[tlv.TlvType4, BigSizeMilliSatoshi]) {
records = append(records, r.Record())
},
)
rl.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) {
records = append(records, r.Record())
})
// Create the tlv stream.
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
// Write the tlv stream.
if err := writeTlvStream(w, tlvStream); err != nil {
return err
}
// Write the HTLCs.
return serializeHTLCEntries(w, rl.HTLCEntries)
}
// serializeHTLCEntries serializes a list of HTLCEntry records based on tlv
// format.
func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error {
for _, htlc := range htlcs {
// Create the tlv stream.
tlvStream, err := htlc.toTlvStream()
if err != nil {
return err
}
// Write the tlv stream.
if err := writeTlvStream(w, tlvStream); err != nil {
return err
}
}
return nil
}
// deserializeRevocationLog deserializes a RevocationLog based on tlv format.
func deserializeRevocationLog(r io.Reader) (RevocationLog, error) {
var rl RevocationLog
ourBalance := rl.OurBalance.Zero()
theirBalance := rl.TheirBalance.Zero()
customBlob := rl.CustomBlob.Zero()
// Create the tlv stream.
tlvStream, err := tlv.NewStream(
rl.OurOutputIndex.Record(),
rl.TheirOutputIndex.Record(),
rl.CommitTxHash.Record(),
ourBalance.Record(),
theirBalance.Record(),
customBlob.Record(),
)
if err != nil {
return rl, err
}
// Read the tlv stream.
parsedTypes, err := readTlvStream(r, tlvStream)
if err != nil {
return rl, err
}
if t, ok := parsedTypes[ourBalance.TlvType()]; ok && t == nil {
rl.OurBalance = tlv.SomeRecordT(ourBalance)
}
if t, ok := parsedTypes[theirBalance.TlvType()]; ok && t == nil {
rl.TheirBalance = tlv.SomeRecordT(theirBalance)
}
if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil {
rl.CustomBlob = tlv.SomeRecordT(customBlob)
}
// Read the HTLC entries.
rl.HTLCEntries, err = deserializeHTLCEntries(r)
return rl, err
}
// deserializeHTLCEntries deserializes a list of HTLC entries based on tlv
// format.
func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
var htlcs []*HTLCEntry
for {
var htlc HTLCEntry
customBlob := htlc.CustomBlob.Zero()
htlcIndex := htlc.HtlcIndex.Zero()
// Create the tlv stream.
records := []tlv.Record{
htlc.RHash.Record(),
htlc.RefundTimeout.Record(),
htlc.OutputIndex.Record(),
htlc.Incoming.Record(),
htlc.Amt.Record(),
customBlob.Record(),
htlcIndex.Record(),
}
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}
// Read the HTLC entry.
parsedTypes, err := readTlvStream(r, tlvStream)
if err != nil {
// We've reached the end when hitting an EOF.
if err == io.ErrUnexpectedEOF {
break
}
return nil, err
}
if t, ok := parsedTypes[customBlob.TlvType()]; ok && t == nil {
htlc.CustomBlob = tlv.SomeRecordT(customBlob)
}
if t, ok := parsedTypes[htlcIndex.TlvType()]; ok && t == nil {
htlc.HtlcIndex = tlv.SomeRecordT(htlcIndex)
}
// Append the entry.
htlcs = append(htlcs, &htlc)
}
return htlcs, nil
}
// writeTlvStream is a helper function that encodes the tlv stream into the
// writer.
func writeTlvStream(w io.Writer, s *tlv.Stream) error {
var b bytes.Buffer
if err := s.Encode(&b); err != nil {
return err
}
// Write the stream's length as a varint.
err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{})
if err != nil {
return err
}
if _, err = w.Write(b.Bytes()); err != nil {
return err
}
return nil
}
// readTlvStream is a helper function that decodes the tlv stream from the
// reader.
func readTlvStream(r io.Reader, s *tlv.Stream) (tlv.TypeMap, error) {
var bodyLen uint64
// Read the stream's length.
bodyLen, err := tlv.ReadVarInt(r, &[8]byte{})
switch {
// We'll convert any EOFs to ErrUnexpectedEOF, since this results in an
// invalid record.
case err == io.EOF:
return nil, io.ErrUnexpectedEOF
// Other unexpected errors.
case err != nil:
return nil, err
}
// TODO(yy): add overflow check.
lr := io.LimitReader(r, int64(bodyLen))
return s.DecodeWithParsedTypes(lr)
}
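// Framing sketch (illustrative only): every TLV stream in this file is
// written with a varint length prefix so the reader can frame it with a
// LimitReader. A RevocationLog round trip exercises both helpers, first
// for the log body and then once per HTLC entry.
func exampleRevocationLogRoundTrip(rl *RevocationLog) (RevocationLog, error) {
var b bytes.Buffer
if err := serializeRevocationLog(&b, rl); err != nil {
return RevocationLog{}, err
}
return deserializeRevocationLog(&b)
}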
// fetchOldRevocationLog finds the revocation log from the deprecated
// sub-bucket.
func fetchOldRevocationLog(log kvdb.RBucket,
updateNum uint64) (ChannelCommitment, error) {
logEntrykey := makeLogKey(updateNum)
commitBytes := log.Get(logEntrykey[:])
if commitBytes == nil {
return ChannelCommitment{}, ErrLogEntryNotFound
}
commitReader := bytes.NewReader(commitBytes)
return deserializeChanCommit(commitReader)
}
// fetchRevocationLogCompatible finds the revocation log from both the
// revocationLogBucket and revocationLogBucketDeprecated for compatibility
// concerns. It returns three values:
// - RevocationLog, if this is non-nil, it means we've found the log in the
// new bucket.
// - ChannelCommitment, if this is non-nil, it means we've found the log in the
// old bucket.
// - error, which is non-nil if the log cannot be found in either bucket.
func fetchRevocationLogCompatible(chanBucket kvdb.RBucket,
updateNum uint64) (*RevocationLog, *ChannelCommitment, error) {
// Look into the new bucket first.
logBucket := chanBucket.NestedReadBucket(revocationLogBucket)
if logBucket != nil {
rl, err := fetchRevocationLog(logBucket, updateNum)
// We've found the record, no need to visit the old bucket.
if err == nil {
return &rl, nil, nil
}
// Return the error if it doesn't say the log cannot be found.
if err != ErrLogEntryNotFound {
return nil, nil, err
}
}
// Otherwise, look into the old bucket and try to find the log there.
oldBucket := chanBucket.NestedReadBucket(revocationLogBucketDeprecated)
if oldBucket != nil {
c, err := fetchOldRevocationLog(oldBucket, updateNum)
if err != nil {
return nil, nil, err
}
// Found an old record and return it.
return nil, &c, nil
}
// If both the buckets are nil, then the sub-buckets haven't been
// created yet.
if logBucket == nil && oldBucket == nil {
return nil, nil, ErrNoPastDeltas
}
// Otherwise, we've tried to query the new bucket but the log cannot be
// found.
return nil, nil, ErrLogEntryNotFound
}
// fetchLogBucket returns a read bucket by visiting both the new and the old
// bucket.
func fetchLogBucket(chanBucket kvdb.RBucket) (kvdb.RBucket, error) {
logBucket := chanBucket.NestedReadBucket(revocationLogBucket)
if logBucket == nil {
logBucket = chanBucket.NestedReadBucket(
revocationLogBucketDeprecated,
)
if logBucket == nil {
return nil, ErrNoPastDeltas
}
}
return logBucket, nil
}
// deleteLogBucket deletes both the new and the old revocation log buckets.
func deleteLogBucket(chanBucket kvdb.RwBucket) error {
// Check if the bucket exists and delete it.
logBucket := chanBucket.NestedReadWriteBucket(
revocationLogBucket,
)
if logBucket != nil {
err := chanBucket.DeleteNestedBucket(revocationLogBucket)
if err != nil {
return err
}
}
// We also check whether the old revocation log bucket exists
// and delete it if so.
oldLogBucket := chanBucket.NestedReadWriteBucket(
revocationLogBucketDeprecated,
)
if oldLogBucket != nil {
err := chanBucket.DeleteNestedBucket(
revocationLogBucketDeprecated,
)
if err != nil {
return err
}
}
return nil
}
package channeldb
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"sync"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// waitingProofsBucketKey byte string name of the waiting proofs store.
waitingProofsBucketKey = []byte("waitingproofs")
// ErrWaitingProofNotFound is returned if waiting proofs haven't been
// found by db.
ErrWaitingProofNotFound = errors.New("waiting proofs haven't been " +
"found")
// ErrWaitingProofAlreadyExist is returned if a waiting proof with the
// same key already exists in the db.
ErrWaitingProofAlreadyExist = errors.New("waiting proof with such " +
"key already exist")
)
// WaitingProofStore is the boltdb map-like storage for half announcement
// signatures. The sole responsibility of this storage is to be able to
// retrieve waiting proofs after client restart.
type WaitingProofStore struct {
// cache is used in order to reduce the number of redundant get
// calls when an object isn't stored in it.
cache map[WaitingProofKey]struct{}
db kvdb.Backend
mu sync.RWMutex
}
// NewWaitingProofStore creates a new instance of the proof storage.
func NewWaitingProofStore(db kvdb.Backend) (*WaitingProofStore, error) {
s := &WaitingProofStore{
db: db,
}
if err := s.ForAll(func(proof *WaitingProof) error {
s.cache[proof.Key()] = struct{}{}
return nil
}, func() {
s.cache = make(map[WaitingProofKey]struct{})
}); err != nil && err != ErrWaitingProofNotFound {
return nil, err
}
return s, nil
}
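// An illustrative usage sketch (db is an assumed, already-opened
// kvdb.Backend and annSigs an assumed *lnwire.AnnounceSignatures1):
//
//	store, err := NewWaitingProofStore(db)
//	if err != nil {
//		return err
//	}
//	proof := NewWaitingProof(true, annSigs)
//	if err := store.Add(proof); err != nil {
//		return err
//	}
//	fetched, err := store.Get(proof.Key())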
// Add adds a new waiting proof to the storage.
func (s *WaitingProofStore) Add(proof *WaitingProof) error {
s.mu.Lock()
defer s.mu.Unlock()
err := kvdb.Update(s.db, func(tx kvdb.RwTx) error {
var err error
var b bytes.Buffer
// Get or create the bucket.
bucket, err := tx.CreateTopLevelBucket(waitingProofsBucketKey)
if err != nil {
return err
}
// Encode the object and place it in the bucket.
if err := proof.Encode(&b); err != nil {
return err
}
key := proof.Key()
return bucket.Put(key[:], b.Bytes())
}, func() {})
if err != nil {
return err
}
// Knowing that the write succeeded, we can now update the in-memory
// cache with the proof's key.
s.cache[proof.Key()] = struct{}{}
return nil
}
// Remove removes the proof from storage by its key.
func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.cache[key]; !ok {
return ErrWaitingProofNotFound
}
err := kvdb.Update(s.db, func(tx kvdb.RwTx) error {
// Fetch the top bucket. If it doesn't exist, the proof cannot be found.
bucket := tx.ReadWriteBucket(waitingProofsBucketKey)
if bucket == nil {
return ErrWaitingProofNotFound
}
return bucket.Delete(key[:])
}, func() {})
if err != nil {
return err
}
// Since the proof was successfully deleted from the store, we can now
// remove it from the in-memory cache.
delete(s.cache, key)
return nil
}
// ForAll iterates through all waiting proofs, passing each waiting proof
// to the given callback.
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error,
reset func()) error {
return kvdb.View(s.db, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(waitingProofsBucketKey)
if bucket == nil {
return ErrWaitingProofNotFound
}
// Iterate over objects buckets.
return bucket.ForEach(func(k, v []byte) error {
// Skip nested buckets, which have nil values.
if v == nil {
return nil
}
r := bytes.NewReader(v)
proof := &WaitingProof{}
if err := proof.Decode(r); err != nil {
return err
}
return cb(proof)
})
}, reset)
}
// Get returns the object which corresponds to the given key.
func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
var proof *WaitingProof
s.mu.RLock()
defer s.mu.RUnlock()
if _, ok := s.cache[key]; !ok {
return nil, ErrWaitingProofNotFound
}
err := kvdb.View(s.db, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(waitingProofsBucketKey)
if bucket == nil {
return ErrWaitingProofNotFound
}
// Fetch the serialized proof directly by its key.
v := bucket.Get(key[:])
if v == nil {
return ErrWaitingProofNotFound
}
r := bytes.NewReader(v)
return proof.Decode(r)
}, func() {
proof = &WaitingProof{}
})
return proof, err
}
// WaitingProofKey is the proof key which uniquely identifies the waiting
// proof object. The goal of this key is to distinguish the local and
// remote proofs for the same channel id.
type WaitingProofKey [9]byte
// WaitingProof is the storable object, which encapsulates the half proof
// and the information about the side from which this proof came. This
// structure is needed to make the channel proof exchange persistent, so
// that after a client restart we may receive the remote/local half proof
// and process it.
type WaitingProof struct {
*lnwire.AnnounceSignatures1
isRemote bool
}
// NewWaitingProof constructs a new waiting proof instance.
func NewWaitingProof(isRemote bool,
proof *lnwire.AnnounceSignatures1) *WaitingProof {
return &WaitingProof{
AnnounceSignatures1: proof,
isRemote: isRemote,
}
}
// OppositeKey returns the key which uniquely identifies the opposite
// waiting proof.
func (p *WaitingProof) OppositeKey() WaitingProofKey {
var key [9]byte
binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64())
if !p.isRemote {
key[8] = 1
}
return key
}
// Key returns the key which uniquely identifies the waiting proof.
func (p *WaitingProof) Key() WaitingProofKey {
var key [9]byte
binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64())
if p.isRemote {
key[8] = 1
}
return key
}
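// To illustrate the key layout: the first 8 bytes hold the big-endian
// short channel ID and the 9th byte flags the remote side, so for a given
// channel the local and remote proofs differ only in the last byte
// (annSigs is an assumed *lnwire.AnnounceSignatures1):
//
//	local := NewWaitingProof(false, annSigs)
//	remote := NewWaitingProof(true, annSigs)
//	// local.Key()[8] == 0, remote.Key()[8] == 1, and
//	// local.OppositeKey() == remote.Key().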
// Encode writes the internal representation of the waiting proof to the
// given byte stream.
func (p *WaitingProof) Encode(w io.Writer) error {
if err := binary.Write(w, byteOrder, p.isRemote); err != nil {
return err
}
// TODO(yy): remove the type assertion when we finished refactoring db
// into using write buffer.
buf, ok := w.(*bytes.Buffer)
if !ok {
return fmt.Errorf("expect io.Writer to be *bytes.Buffer")
}
if err := p.AnnounceSignatures1.Encode(buf, 0); err != nil {
return err
}
return nil
}
// Decode reads the data from the byte stream and initializes the
// waiting proof object with it.
func (p *WaitingProof) Decode(r io.Reader) error {
if err := binary.Read(r, byteOrder, &p.isRemote); err != nil {
return err
}
msg := &lnwire.AnnounceSignatures1{}
if err := msg.Decode(r, 0); err != nil {
return err
}
p.AnnounceSignatures1 = msg
return nil
}
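// A minimal encode/decode round-trip sketch. Encode currently requires a
// *bytes.Buffer (see the TODO above), which Decode can then read back
// from:
//
//	var b bytes.Buffer
//	if err := proof.Encode(&b); err != nil {
//		return err
//	}
//	decoded := &WaitingProof{}
//	if err := decoded.Decode(&b); err != nil {
//		return err
//	}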
package channeldb
import (
"fmt"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
)
var (
// ErrNoWitnesses is an error that's returned when no new witnesses have
// been added to the WitnessCache.
ErrNoWitnesses = fmt.Errorf("no witnesses")
// ErrUnknownWitnessType is returned if a caller attempts to store or
// look up a witness of an unknown type.
ErrUnknownWitnessType = fmt.Errorf("unknown witness type")
)
// WitnessType is an enum that denotes what "type" of witness is being
// stored/retrieved. As the WitnessCache itself is agnostic and doesn't
// enforce any structure on added witnesses, we use this type to partition
// the witnesses on disk, and also to know how to map a witness to its
// lookup key.
type WitnessType uint8
var (
// Sha256HashWitness is a witness that is simply the preimage to a sha256
// hash. In order to map a witness to its key, we'll use sha256.
Sha256HashWitness WitnessType = 1
)
// toDBKey is a helper method that maps a witness type to the key that we'll
// use to store it within the database.
func (w WitnessType) toDBKey() ([]byte, error) {
switch w {
case Sha256HashWitness:
return []byte{byte(w)}, nil
default:
return nil, ErrUnknownWitnessType
}
}
var (
// witnessBucketKey is the name of the bucket that we use to store all
// witnesses encountered. Within this bucket, we'll create a sub-bucket for
// each witness type.
witnessBucketKey = []byte("byte")
)
// WitnessCache is a persistent cache of all witnesses we've encountered on the
// network. In the case of multi-hop, multi-step contracts, a cache of all
// witnesses can be useful in the case of partial contract resolution. If
// negotiations break down, we may be forced to locate the witness for a
// portion of the contract on-chain. In this case, we'll then add that
// witness to the cache so the incoming contract can be fully resolved.
// Additionally, as one MUST always use a unique witness on the network, we may
// use this cache to detect duplicate witnesses.
//
// TODO(roasbeef): need expiry policy?
// - encrypt?
type WitnessCache struct {
db *DB
}
// NewWitnessCache returns a new instance of the witness cache.
func (d *DB) NewWitnessCache() *WitnessCache {
return &WitnessCache{
db: d,
}
}
// witnessEntry is a key-value struct that holds each key -> witness pair, used
// when inserting records into the cache.
type witnessEntry struct {
key []byte
witness []byte
}
// AddSha256Witnesses adds a batch of new sha256 preimages into the witness
// cache. This is an alias for AddWitnesses that uses Sha256HashWitness as the
// preimages' witness type.
func (w *WitnessCache) AddSha256Witnesses(preimages ...lntypes.Preimage) error {
// Optimistically compute the preimages' hashes before attempting to
// start the db transaction.
entries := make([]witnessEntry, 0, len(preimages))
for i := range preimages {
hash := preimages[i].Hash()
entries = append(entries, witnessEntry{
key: hash[:],
witness: preimages[i][:],
})
}
return w.addWitnessEntries(Sha256HashWitness, entries)
}
// addWitnessEntries inserts the witnessEntry key-value pairs into the cache,
// using the appropriate witness type to segment the namespace of possible
// witness types.
func (w *WitnessCache) addWitnessEntries(wType WitnessType,
entries []witnessEntry) error {
// Exit early if there are no witnesses to add.
if len(entries) == 0 {
return nil
}
return kvdb.Batch(w.db.Backend, func(tx kvdb.RwTx) error {
witnessBucket, err := tx.CreateTopLevelBucket(witnessBucketKey)
if err != nil {
return err
}
witnessTypeBucketKey, err := wType.toDBKey()
if err != nil {
return err
}
witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists(
witnessTypeBucketKey,
)
if err != nil {
return err
}
for _, entry := range entries {
err = witnessTypeBucket.Put(entry.key, entry.witness)
if err != nil {
return err
}
}
return nil
})
}
// LookupSha256Witness attempts to look up the preimage for a sha256 hash. If
// the witness isn't found, ErrNoWitnesses will be returned.
func (w *WitnessCache) LookupSha256Witness(hash lntypes.Hash) (lntypes.Preimage, error) {
witness, err := w.lookupWitness(Sha256HashWitness, hash[:])
if err != nil {
return lntypes.Preimage{}, err
}
return lntypes.MakePreimage(witness)
}
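// An illustrative cache round trip (db is an assumed *DB instance and
// preimage an assumed lntypes.Preimage):
//
//	cache := db.NewWitnessCache()
//	if err := cache.AddSha256Witnesses(preimage); err != nil {
//		return err
//	}
//	got, err := cache.LookupSha256Witness(preimage.Hash())
//	// got equals preimage on success; ErrNoWitnesses is returned if
//	// the hash is unknown.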
// lookupWitness attempts to look up a witness according to its type and also
// its witness key. In the case that the witness isn't found, ErrNoWitnesses
// will be returned.
func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) {
var witness []byte
err := kvdb.View(w.db, func(tx kvdb.RTx) error {
witnessBucket := tx.ReadBucket(witnessBucketKey)
if witnessBucket == nil {
return ErrNoWitnesses
}
witnessTypeBucketKey, err := wType.toDBKey()
if err != nil {
return err
}
witnessTypeBucket := witnessBucket.NestedReadBucket(witnessTypeBucketKey)
if witnessTypeBucket == nil {
return ErrNoWitnesses
}
dbWitness := witnessTypeBucket.Get(witnessKey)
if dbWitness == nil {
return ErrNoWitnesses
}
witness = make([]byte, len(dbWitness))
copy(witness[:], dbWitness)
return nil
}, func() {
witness = nil
})
if err != nil {
return nil, err
}
return witness, nil
}
// DeleteSha256Witness attempts to delete a sha256 preimage identified by hash.
func (w *WitnessCache) DeleteSha256Witness(hash lntypes.Hash) error {
return w.deleteWitness(Sha256HashWitness, hash[:])
}
// deleteWitness attempts to delete a particular witness from the database.
func (w *WitnessCache) deleteWitness(wType WitnessType, witnessKey []byte) error {
return kvdb.Batch(w.db.Backend, func(tx kvdb.RwTx) error {
witnessBucket, err := tx.CreateTopLevelBucket(witnessBucketKey)
if err != nil {
return err
}
witnessTypeBucketKey, err := wType.toDBKey()
if err != nil {
return err
}
witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists(
witnessTypeBucketKey,
)
if err != nil {
return err
}
return witnessTypeBucket.Delete(witnessKey)
})
}
// DeleteWitnessClass attempts to delete an *entire* class of witnesses. If
// this function returns with a nil error, all witnesses of the given type
// will have been deleted from the database.
func (w *WitnessCache) DeleteWitnessClass(wType WitnessType) error {
return kvdb.Batch(w.db.Backend, func(tx kvdb.RwTx) error {
witnessBucket, err := tx.CreateTopLevelBucket(witnessBucketKey)
if err != nil {
return err
}
witnessTypeBucketKey, err := wType.toDBKey()
if err != nil {
return err
}
return witnessBucket.DeleteNestedBucket(witnessTypeBucketKey)
})
}
package contractcourt
import (
"errors"
"fmt"
"io"
"sync"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/sweep"
)
// anchorResolver is a resolver that will attempt to sweep our anchor output.
type anchorResolver struct {
// anchorSignDescriptor contains the information that is required to
// sweep the anchor.
anchorSignDescriptor input.SignDescriptor
// anchor is the outpoint on the commitment transaction.
anchor wire.OutPoint
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
broadcastHeight uint32
// chanPoint is the channel point of the original contract.
chanPoint wire.OutPoint
// chanType denotes the type of channel the contract belongs to.
chanType channeldb.ChannelType
// currentReport stores the current state of the resolver for reporting
// over the rpc interface.
currentReport ContractReport
// reportLock prevents concurrent access to the resolver report.
reportLock sync.Mutex
contractResolverKit
}
// newAnchorResolver instantiates a new anchor resolver.
func newAnchorResolver(anchorSignDescriptor input.SignDescriptor,
anchor wire.OutPoint, broadcastHeight uint32,
chanPoint wire.OutPoint, resCfg ResolverConfig) *anchorResolver {
amt := btcutil.Amount(anchorSignDescriptor.Output.Value)
report := ContractReport{
Outpoint: anchor,
Type: ReportOutputAnchor,
Amount: amt,
LimboBalance: amt,
RecoveredBalance: 0,
}
r := &anchorResolver{
contractResolverKit: *newContractResolverKit(resCfg),
anchorSignDescriptor: anchorSignDescriptor,
anchor: anchor,
broadcastHeight: broadcastHeight,
chanPoint: chanPoint,
currentReport: report,
}
r.initLogger(fmt.Sprintf("%T(%v)", r, r.anchor))
return r
}
// ResolverKey returns an identifier which should be globally unique for this
// particular resolver within the chain the original contract resides within.
func (c *anchorResolver) ResolverKey() []byte {
// The anchor resolver is stateless and doesn't need a database key.
return nil
}
// Resolve waits for the output to be swept.
//
// NOTE: Part of the ContractResolver interface.
func (c *anchorResolver) Resolve() (ContractResolver, error) {
// If we're already resolved, then we can exit early.
if c.IsResolved() {
c.log.Errorf("already resolved")
return nil, nil
}
var (
outcome channeldb.ResolverOutcome
spendTx *chainhash.Hash
)
select {
case sweepRes := <-c.sweepResultChan:
err := sweepRes.Err
switch {
// Anchor was swept successfully.
case err == nil:
sweepTxID := sweepRes.Tx.TxHash()
spendTx = &sweepTxID
outcome = channeldb.ResolverOutcomeClaimed
// Anchor was swept by someone else. This is possible after the
// 16 block csv lock.
case errors.Is(err, sweep.ErrRemoteSpend),
errors.Is(err, sweep.ErrInputMissing):
c.log.Warnf("our anchor spent by someone else")
outcome = channeldb.ResolverOutcomeUnclaimed
// An unexpected error occurred.
default:
c.log.Errorf("unable to sweep anchor: %v", sweepRes.Err)
return nil, sweepRes.Err
}
case <-c.quit:
return nil, errResolverShuttingDown
}
c.log.Infof("resolved in tx %v", spendTx)
// Update report to reflect that funds are no longer in limbo.
c.reportLock.Lock()
if outcome == channeldb.ResolverOutcomeClaimed {
c.currentReport.RecoveredBalance = c.currentReport.LimboBalance
}
c.currentReport.LimboBalance = 0
report := c.currentReport.resolverReport(
spendTx, channeldb.ResolverTypeAnchor, outcome,
)
c.reportLock.Unlock()
c.markResolved()
return nil, c.PutResolverReport(nil, report)
}
// Stop signals the resolver to cancel any current resolution processes, and
// suspend.
//
// NOTE: Part of the ContractResolver interface.
func (c *anchorResolver) Stop() {
c.log.Debugf("stopping...")
defer c.log.Debugf("stopped")
close(c.quit)
}
// SupplementState allows the user of a ContractResolver to supplement it with
// state required for the proper resolution of a contract.
//
// NOTE: Part of the ContractResolver interface.
func (c *anchorResolver) SupplementState(state *channeldb.OpenChannel) {
c.chanType = state.ChanType
}
// report returns a report on the resolution state of the contract.
func (c *anchorResolver) report() *ContractReport {
c.reportLock.Lock()
defer c.reportLock.Unlock()
reportCopy := c.currentReport
return &reportCopy
}
func (c *anchorResolver) Encode(w io.Writer) error {
return errors.New("serialization not supported")
}
// A compile time assertion to ensure anchorResolver meets the
// ContractResolver interface.
var _ ContractResolver = (*anchorResolver)(nil)
// Launch offers the anchor output to the sweeper.
func (c *anchorResolver) Launch() error {
if c.isLaunched() {
c.log.Tracef("already launched")
return nil
}
c.log.Debugf("launching resolver...")
c.markLaunched()
// If we're already resolved, then we can exit early.
if c.IsResolved() {
c.log.Errorf("already resolved")
return nil
}
// Attempt to update the sweep parameters to the post-confirmation
// situation. We don't want to force sweep anymore, because the anchor
// lost its special purpose to get the commitment confirmed. It is just
// an output that we want to sweep only if it is economical to do so.
//
// An exclusive group is not necessary anymore, because we know that
// this is the only anchor that can be swept.
//
// We also clear the parent tx information for cpfp, because the
// commitment tx is confirmed.
//
// After a restart or when the remote force closes, the sweeper is not
// yet aware of the anchor. In that case, it will be added as new input
// to the sweeper.
witnessType := input.CommitmentAnchor
// For taproot channels, we need to use the proper witness type.
if c.chanType.IsTaproot() {
witnessType = input.TaprootAnchorSweepSpend
}
anchorInput := input.MakeBaseInput(
&c.anchor, witnessType, &c.anchorSignDescriptor,
c.broadcastHeight, nil,
)
resultChan, err := c.Sweeper.SweepInput(
&anchorInput,
sweep.Params{
// For normal anchor sweeping, the budget is 330 sats.
Budget: btcutil.Amount(
anchorInput.SignDesc().Output.Value,
),
// There's no rush to sweep the anchor, so we use a nil
// deadline here.
DeadlineHeight: fn.None[int32](),
},
)
if err != nil {
return err
}
c.sweepResultChan = resultChan
return nil
}
package contractcourt
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/sweep"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// justiceTxConfTarget is the number of blocks we'll use as a
// confirmation target when creating the justice transaction. We'll
// choose an aggressive target, since we want to be sure it confirms
// quickly.
justiceTxConfTarget = 2
// blocksPassedSplitPublish is the number of blocks without
// confirmation of the justice tx we'll wait before starting to publish
// smaller variants of the justice tx. We do this to mitigate an attack
// the channel peer can do by pinning the HTLC outputs of the
// commitment with low-fee HTLC transactions.
blocksPassedSplitPublish = 4
)
var (
// retributionBucket stores retribution state on disk between detecting
// a contract breach, broadcasting a justice transaction that sweeps the
// channel, and finally witnessing the justice transaction confirm on
// the blockchain. It is critical that such state is persisted on disk,
// so that if our node restarts at any point during the retribution
// procedure, we can recover and continue from the persisted state.
retributionBucket = []byte("retribution")
// taprootRetributionBucket stores the taproot specific retribution
// information. This includes things like the control blocks for both
// commitment outputs, and the taptweak needed to sweep each HTLC (one
// for the first and one for the second level).
taprootRetributionBucket = []byte("tap-retribution")
// errBrarShuttingDown is an error returned if the BreachArbitrator has
// been signalled to exit.
errBrarShuttingDown = errors.New("BreachArbitrator shutting down")
)
// ContractBreachEvent is an event the BreachArbitrator will receive in case a
// contract breach is observed on-chain. It contains the necessary information
// to handle the breach, and a ProcessACK closure we will use to ACK the event
// when we have safely stored all the necessary information.
type ContractBreachEvent struct {
// ChanPoint is the channel point of the breached channel.
ChanPoint wire.OutPoint
// ProcessACK is a closure that should be called with a nil error iff
// the breach retribution info is safely stored in the retribution
// store. In case storing the information to the store fails, a non-nil
// error should be used. When this closure returns, it means that the
// contract court has marked the channel pending close in the DB, and
// it is safe for the BreachArbitrator to carry on its duty.
ProcessACK func(error)
// BreachRetribution is the information needed to act on this contract
// breach.
BreachRetribution *lnwallet.BreachRetribution
}
// ChannelCloseType is an enum which signals the type of channel closure the
// peer should execute.
type ChannelCloseType uint8
const (
// CloseRegular indicates a regular cooperative channel closure
// should be attempted.
CloseRegular ChannelCloseType = iota
// CloseBreach indicates that a channel breach has been detected, and
// the link should immediately be marked as unavailable.
CloseBreach
)
// RetributionStorer provides an interface for managing a persistent map from
// wire.OutPoint -> retributionInfo. Upon learning of a breach, a
// BreachArbitrator should record the retributionInfo for the breached channel,
// which serves as a checkpoint in the event that retribution needs to be resumed
// after failure. A RetributionStore provides an interface for managing the
// persisted set, as well as mapping user defined functions over the entire
// on-disk contents.
//
// Calls to RetributionStore may occur concurrently. A concrete instance of
// RetributionStore should use appropriate synchronization primitives, or
// be otherwise safe for concurrent access.
type RetributionStorer interface {
// Add persists the retributionInfo to disk, using the information's
// chanPoint as the key. This method should overwrite any existing
// entries found under the same key, and an error should be raised if
// the addition fails.
Add(retInfo *retributionInfo) error
// IsBreached queries the retribution store to see if the breach arbiter
// is aware of any breaches for the provided channel point.
IsBreached(chanPoint *wire.OutPoint) (bool, error)
// Remove deletes the retributionInfo from disk, if any exists, under
// the given key. An error should be raised if the removal fails.
Remove(key *wire.OutPoint) error
// ForAll iterates over the existing on-disk contents and applies a
// chosen, read-only callback to each. This method should ensure that it
// immediately propagates any errors generated by the callback.
ForAll(cb func(*retributionInfo) error, reset func()) error
}
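// memRetributionStore is a minimal in-memory sketch of the
// RetributionStorer interface, e.g. for tests. It is illustrative only:
// the store used by the BreachArbitrator in production persists
// retribution state to disk so it survives restarts.
type memRetributionStore struct {
	mu   sync.Mutex
	rets map[wire.OutPoint]*retributionInfo
}

// Add records the retribution, overwriting any existing entry under the
// same channel point.
func (m *memRetributionStore) Add(retInfo *retributionInfo) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.rets == nil {
		m.rets = make(map[wire.OutPoint]*retributionInfo)
	}
	m.rets[retInfo.chanPoint] = retInfo
	return nil
}

// IsBreached reports whether a retribution exists for the channel point.
func (m *memRetributionStore) IsBreached(chanPoint *wire.OutPoint) (bool,
	error) {

	m.mu.Lock()
	defer m.mu.Unlock()
	_, ok := m.rets[*chanPoint]
	return ok, nil
}

// Remove deletes the retribution under the given key, if any exists.
func (m *memRetributionStore) Remove(key *wire.OutPoint) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.rets, *key)
	return nil
}

// ForAll calls reset, then applies cb to each stored retribution,
// propagating the first error encountered.
func (m *memRetributionStore) ForAll(cb func(*retributionInfo) error,
	reset func()) error {

	m.mu.Lock()
	defer m.mu.Unlock()
	reset()
	for _, ret := range m.rets {
		if err := cb(ret); err != nil {
			return err
		}
	}
	return nil
}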
// BreachConfig bundles the required subsystems used by the breach arbiter. An
// instance of BreachConfig is passed to NewBreachArbitrator during
// instantiation.
type BreachConfig struct {
// CloseLink allows the breach arbiter to shutdown any channel links for
// which it detects a breach, ensuring no further activity will
// continue across the link. The method accepts link's channel point and
// a close type to be included in the channel close summary.
CloseLink func(*wire.OutPoint, ChannelCloseType)
// DB provides access to the user's channels, allowing the breach
// arbiter to determine the current state of a user's channels, and how
// it should respond to channel closure.
DB *channeldb.ChannelStateDB
// Estimator is used by the breach arbiter to determine an appropriate
// fee level when generating, signing, and broadcasting sweep
// transactions.
Estimator chainfee.Estimator
// GenSweepScript generates the receiving scripts for swept outputs.
GenSweepScript func() fn.Result[lnwallet.AddrWithKey]
// Notifier provides a publish/subscribe interface for event driven
// notifications regarding the confirmation of txids.
Notifier chainntnfs.ChainNotifier
// PublishTransaction facilitates the process of broadcasting a
// transaction to the network.
PublishTransaction func(*wire.MsgTx, string) error
// ContractBreaches is a channel where the BreachArbitrator will receive
// notifications in the event of a contract breach being observed. A
// ContractBreachEvent must be ACKed by the BreachArbitrator, such that
// the sending subsystem knows that the event is properly handed off.
ContractBreaches <-chan *ContractBreachEvent
// Signer is used by the breach arbiter to generate sweep transactions,
// which move coins from previously open channels back to the user's
// wallet.
Signer input.Signer
// Store is a persistent resource that maintains information regarding
// breached channels. This is used in conjunction with DB to recover
// from crashes, restarts, or other failures.
Store RetributionStorer
// AuxSweeper is an optional interface that can be used to modify the
// way sweep transaction are generated.
AuxSweeper fn.Option[sweep.AuxSweeper]
}
// BreachArbitrator is a special subsystem which is responsible for watching and
// acting on the detection of any attempted uncooperative channel breaches by
// channel counterparties. This file essentially acts as deterrence code for
// those attempting to launch attacks against the daemon. In practice it's
// expected that the logic in this file never gets executed, but it is
// important to have it in place just in case we encounter cheating channel
// counterparties.
// TODO(roasbeef): closures in config for subsystem pointers to decouple?
type BreachArbitrator struct {
started sync.Once
stopped sync.Once
cfg *BreachConfig
subscriptions map[wire.OutPoint]chan struct{}
quit chan struct{}
wg sync.WaitGroup
sync.Mutex
}
// NewBreachArbitrator creates a new instance of a BreachArbitrator initialized
// with its dependent objects.
func NewBreachArbitrator(cfg *BreachConfig) *BreachArbitrator {
return &BreachArbitrator{
cfg: cfg,
subscriptions: make(map[wire.OutPoint]chan struct{}),
quit: make(chan struct{}),
}
}
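// An illustrative wiring sketch. All values below are hypothetical and
// assumed to come from the caller's subsystems; only the field names are
// real:
//
//	brar := NewBreachArbitrator(&BreachConfig{
//		CloseLink:          htlcSwitch.CloseLink,
//		DB:                 chanStateDB,
//		Notifier:           chainNotifier,
//		PublishTransaction: wallet.PublishTransaction,
//		ContractBreaches:   breachEventChan,
//		Store:              retStore,
//	})
//	if err := brar.Start(); err != nil {
//		return err
//	}
//	defer brar.Stop()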
// Start is an idempotent method that officially starts the BreachArbitrator
// along with all other goroutines it needs to perform its functions.
func (b *BreachArbitrator) Start() error {
var err error
b.started.Do(func() {
brarLog.Info("Breach arbiter starting")
err = b.start()
})
return err
}
func (b *BreachArbitrator) start() error {
// Load all retributions currently persisted in the retribution store.
var breachRetInfos map[wire.OutPoint]retributionInfo
if err := b.cfg.Store.ForAll(func(ret *retributionInfo) error {
breachRetInfos[ret.chanPoint] = *ret
return nil
}, func() {
breachRetInfos = make(map[wire.OutPoint]retributionInfo)
}); err != nil {
brarLog.Errorf("Unable to create retribution info: %v", err)
return err
}
// Load all currently closed channels from disk, we will use the
// channels that have been marked fully closed to filter the retribution
// information loaded from disk. This is necessary in the event that the
// channel was marked fully closed, but was not removed from the
// retribution store.
closedChans, err := b.cfg.DB.FetchClosedChannels(false)
if err != nil {
brarLog.Errorf("Unable to fetch closing channels: %v", err)
return err
}
brarLog.Debugf("Found %v closing channels, %v retribution records",
len(closedChans), len(breachRetInfos))
// Using the set of non-pending, closed channels, reconcile any
// discrepancies between the channeldb and the retribution store by
// removing any retribution information for which we have already
// finished our responsibilities. If the removal is successful, we also
// remove the entry from our in-memory map, to avoid any further action
// for this channel.
// TODO(halseth): no need to continue on IsPending once closed channels
// actually means the close transaction is confirmed.
for _, chanSummary := range closedChans {
brarLog.Debugf("Working on close channel: %v, is_pending: %v",
chanSummary.ChanPoint, chanSummary.IsPending)
if chanSummary.IsPending {
continue
}
chanPoint := &chanSummary.ChanPoint
if _, ok := breachRetInfos[*chanPoint]; ok {
if err := b.cfg.Store.Remove(chanPoint); err != nil {
brarLog.Errorf("Unable to remove closed "+
"chanid=%v from breach arbiter: %v",
chanPoint, err)
return err
}
delete(breachRetInfos, *chanPoint)
brarLog.Debugf("Skipped closed channel: %v",
chanSummary.ChanPoint)
}
}
// Spawn the exactRetribution tasks to monitor and resolve any breaches
// that were loaded from the retribution store.
for chanPoint := range breachRetInfos {
retInfo := breachRetInfos[chanPoint]
brarLog.Debugf("Handling breach handoff on startup "+
"for ChannelPoint(%v)", chanPoint)
// Register for a notification when the breach transaction is
// confirmed on chain.
breachTXID := retInfo.commitHash
breachScript := retInfo.breachedOutputs[0].signDesc.Output.PkScript
confChan, err := b.cfg.Notifier.RegisterConfirmationsNtfn(
&breachTXID, breachScript, 1, retInfo.breachHeight,
)
if err != nil {
brarLog.Errorf("Unable to register for conf updates "+
"for txid: %v, err: %v", breachTXID, err)
return err
}
// Launch a new goroutine to finalize the channel retribution
// after the breach transaction confirms.
b.wg.Add(1)
go b.exactRetribution(confChan, &retInfo)
}
// Start watching the remaining active channels!
b.wg.Add(1)
go b.contractObserver()
return nil
}
// Stop is an idempotent method that signals the BreachArbitrator to execute a
// graceful shutdown. This function will block until all goroutines spawned by
// the BreachArbitrator have gracefully exited.
func (b *BreachArbitrator) Stop() error {
b.stopped.Do(func() {
brarLog.Infof("Breach arbiter shutting down...")
defer brarLog.Debug("Breach arbiter shutdown complete")
close(b.quit)
b.wg.Wait()
})
return nil
}
// IsBreached queries the breach arbiter's retribution store to see if it is
// aware of any channel breaches for a particular channel point.
func (b *BreachArbitrator) IsBreached(chanPoint *wire.OutPoint) (bool, error) {
return b.cfg.Store.IsBreached(chanPoint)
}
// SubscribeBreachComplete is used by outside subsystems to be notified of a
// successful breach resolution.
func (b *BreachArbitrator) SubscribeBreachComplete(chanPoint *wire.OutPoint,
c chan struct{}) (bool, error) {
breached, err := b.cfg.Store.IsBreached(chanPoint)
if err != nil {
// If an error occurs, no subscription will be registered.
return false, err
}
if !breached {
// If chanPoint no longer exists in the Store, then the breach
// was cleaned up successfully. Any subscription that occurs
// happens after the breach information was persisted to the
// underlying store.
return true, nil
}
// Otherwise since the channel point is not resolved, add a
// subscription. There can only be one subscription per channel point.
b.Lock()
defer b.Unlock()
b.subscriptions[*chanPoint] = c
return false, nil
}
// notifyBreachComplete is used by the BreachArbitrator to notify outside
// subsystems that the breach resolution process is complete.
func (b *BreachArbitrator) notifyBreachComplete(chanPoint *wire.OutPoint) {
b.Lock()
defer b.Unlock()
if c, ok := b.subscriptions[*chanPoint]; ok {
close(c)
}
// Remove the subscription.
delete(b.subscriptions, *chanPoint)
}
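// An illustrative subscriber sketch: the caller passes in a channel that
// notifyBreachComplete will close once the breach is fully resolved. If
// the breach was already cleaned up, SubscribeBreachComplete reports that
// directly:
//
//	done := make(chan struct{})
//	resolved, err := b.SubscribeBreachComplete(&chanPoint, done)
//	if err != nil {
//		return err
//	}
//	if !resolved {
//		<-done // Closed by notifyBreachComplete.
//	}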
// contractObserver is the primary goroutine for the BreachArbitrator. This
// goroutine is responsible for handling breach events coming from the
// contractcourt on the ContractBreaches channel. If a channel breach is
// detected, then the contractObserver will execute the retribution logic
// required to sweep ALL outputs from a contested channel into the daemon's
// wallet.
//
// NOTE: This MUST be run as a goroutine.
func (b *BreachArbitrator) contractObserver() {
defer b.wg.Done()
brarLog.Infof("Starting contract observer, watching for breaches.")
for {
select {
case breachEvent := <-b.cfg.ContractBreaches:
// We have been notified about a contract breach!
// Handle the handoff, making sure we ACK the event
// after we have safely added it to the retribution
// store.
b.wg.Add(1)
go b.handleBreachHandoff(breachEvent)
case <-b.quit:
return
}
}
}
// spend is used to wrap the index of the retributionInfo output that gets
// spent together with the spend details.
type spend struct {
index int
detail *chainntnfs.SpendDetail
}
// waitForSpendEvent waits for any of the breached outputs to get spent, and
// returns the spend details for those outputs. The spendNtfns map is a cache
// used to store registered spend subscriptions, in case we must call this
// method multiple times.
func (b *BreachArbitrator) waitForSpendEvent(breachInfo *retributionInfo,
spendNtfns map[wire.OutPoint]*chainntnfs.SpendEvent) ([]spend, error) {
inputs := breachInfo.breachedOutputs
// We create a channel that the first goroutine to get a spend event can
// signal on. We make it buffered in case multiple spend events come in at
// the same time.
anySpend := make(chan struct{}, len(inputs))
// The allSpends channel will be used to pass spend events from all the
// goroutines that detect a spend before they are signalled to exit.
allSpends := make(chan spend, len(inputs))
// exit will be used to signal the goroutines that they can exit.
exit := make(chan struct{})
var wg sync.WaitGroup
// We'll now launch a goroutine for each of the breached outputs that will
// signal the moment they detect a spend event.
for i := range inputs {
breachedOutput := &inputs[i]
brarLog.Infof("Checking spend from %v(%v) for ChannelPoint(%v)",
breachedOutput.witnessType, breachedOutput.outpoint,
breachInfo.chanPoint)
// If we have already registered for a notification for this
// output, we'll reuse it.
spendNtfn, ok := spendNtfns[breachedOutput.outpoint]
if !ok {
var err error
spendNtfn, err = b.cfg.Notifier.RegisterSpendNtfn(
&breachedOutput.outpoint,
breachedOutput.signDesc.Output.PkScript,
breachInfo.breachHeight,
)
if err != nil {
brarLog.Errorf("Unable to check for spentness "+
"of outpoint=%v: %v",
breachedOutput.outpoint, err)
// Registration may have failed if we've been
// instructed to shutdown. If so, return here
// to avoid entering an infinite loop.
select {
case <-b.quit:
return nil, errBrarShuttingDown
default:
continue
}
}
spendNtfns[breachedOutput.outpoint] = spendNtfn
}
// Launch a goroutine waiting for a spend event.
b.wg.Add(1)
wg.Add(1)
go func(index int, spendEv *chainntnfs.SpendEvent) {
defer b.wg.Done()
defer wg.Done()
select {
// The output has been taken to the second level!
case sp, ok := <-spendEv.Spend:
if !ok {
return
}
brarLog.Infof("Detected spend on %s(%v) by "+
"txid(%v) for ChannelPoint(%v)",
inputs[index].witnessType,
inputs[index].outpoint,
sp.SpenderTxHash,
breachInfo.chanPoint)
// First we send the spend event on the
// allSpends channel, such that it can be
// handled after all go routines have exited.
allSpends <- spend{index, sp}
// Finally we'll signal the anySpend channel
// that a spend was detected, such that the
// other goroutines can be shut down.
anySpend <- struct{}{}
case <-exit:
return
case <-b.quit:
return
}
}(i, spendNtfn)
}
// We'll wait for any of the outputs to be spent, or that we are
// signalled to exit.
select {
// A goroutine has signalled that a spend occurred.
case <-anySpend:
// Signal for the remaining goroutines to exit.
close(exit)
wg.Wait()
// At this point all goroutines that can send on the allSpends
// channel have exited. We can therefore safely close the
// channel before ranging over its content.
close(allSpends)
// Gather all detected spends and return them.
var spends []spend
for s := range allSpends {
breachedOutput := &inputs[s.index]
delete(spendNtfns, breachedOutput.outpoint)
spends = append(spends, s)
}
return spends, nil
case <-b.quit:
return nil, errBrarShuttingDown
}
}
// convertToSecondLevelRevoke takes a breached output, and a transaction that
// spends it to the second level, and mutates the breach output into one that
// is able to properly sweep that second level output. We'll use this function
// when we go to sweep a breached commitment transaction, but the cheating
// party has already attempted to take it to the second level.
func convertToSecondLevelRevoke(bo *breachedOutput, breachInfo *retributionInfo,
spendDetails *chainntnfs.SpendDetail) {
// In this case, we'll modify the witness type of this output to
// actually prepare for a second level revoke.
isTaproot := txscript.IsPayToTaproot(bo.signDesc.Output.PkScript)
if isTaproot {
bo.witnessType = input.TaprootHtlcSecondLevelRevoke
} else {
bo.witnessType = input.HtlcSecondLevelRevoke
}
// We'll also redirect the outpoint to this second level output, so the
// spending transaction updates its inputs accordingly.
spendingTx := spendDetails.SpendingTx
spendInputIndex := spendDetails.SpenderInputIndex
oldOp := bo.outpoint
bo.outpoint = wire.OutPoint{
Hash: spendingTx.TxHash(),
Index: spendInputIndex,
}
// Next, we need to update the amount so we can do fee estimation
// properly, and also so we can generate a valid signature as we need
// to know the new input value (the second level transaction shaves
// off some funds for fees).
newAmt := spendingTx.TxOut[spendInputIndex].Value
bo.amt = btcutil.Amount(newAmt)
bo.signDesc.Output.Value = newAmt
bo.signDesc.Output.PkScript = spendingTx.TxOut[spendInputIndex].PkScript
// For taproot outputs, the taptweak also needs to be swapped out. We
// do this unconditionally as this field isn't used at all for segwit
// v0 outputs.
bo.signDesc.TapTweak = bo.secondLevelTapTweak[:]
// Finally, we'll need to adjust the witness program in the
// SignDescriptor.
bo.signDesc.WitnessScript = bo.secondLevelWitnessScript
brarLog.Warnf("HTLC(%v) for ChannelPoint(%v) has been spent to the "+
"second-level, adjusting -> %v", oldOp, breachInfo.chanPoint,
bo.outpoint)
}
// updateBreachInfo mutates the passed breachInfo by removing or converting any
// outputs among the spends. It also counts the total and revoked funds swept
// by our justice spends.
func updateBreachInfo(breachInfo *retributionInfo, spends []spend) (
btcutil.Amount, btcutil.Amount) {
inputs := breachInfo.breachedOutputs
doneOutputs := make(map[int]struct{})
var totalFunds, revokedFunds btcutil.Amount
for _, s := range spends {
breachedOutput := &inputs[s.index]
txIn := s.detail.SpendingTx.TxIn[s.detail.SpenderInputIndex]
switch breachedOutput.witnessType {
case input.TaprootHtlcAcceptedRevoke:
fallthrough
case input.TaprootHtlcOfferedRevoke:
fallthrough
case input.HtlcAcceptedRevoke:
fallthrough
case input.HtlcOfferedRevoke:
// If the HTLC output was spent using the revocation
// key, it is our own spend, and we can forget the
// output. Otherwise it has been taken to the second
// level.
signDesc := &breachedOutput.signDesc
ok, err := input.IsHtlcSpendRevoke(txIn, signDesc)
if err != nil {
brarLog.Errorf("Unable to determine if "+
"revoke spend: %v", err)
break
}
if ok {
brarLog.Debugf("HTLC spend was our own " +
"revocation spend")
break
}
brarLog.Infof("Spend on second-level "+
"%s(%v) for ChannelPoint(%v) "+
"transitions to second-level output",
breachedOutput.witnessType,
breachedOutput.outpoint, breachInfo.chanPoint)
// In this case we'll morph our initial revoke
// spend to instead point to the second level
// output, and update the sign descriptor in the
// process.
convertToSecondLevelRevoke(
breachedOutput, breachInfo, s.detail,
)
continue
}
// Now that we have determined the spend is done by us, we
// count the total and revoked funds swept depending on the
// input type.
switch breachedOutput.witnessType {
// If the output being revoked is the remote commitment output
// or an offered HTLC output, its amount contributes to the
// value of funds being revoked from the counter party.
case input.CommitmentRevoke, input.TaprootCommitmentRevoke,
input.HtlcSecondLevelRevoke,
input.TaprootHtlcSecondLevelRevoke,
input.TaprootHtlcOfferedRevoke, input.HtlcOfferedRevoke:
revokedFunds += breachedOutput.Amount()
}
totalFunds += breachedOutput.Amount()
brarLog.Infof("Spend on %s(%v) for ChannelPoint(%v) "+
"transitions output to terminal state, "+
"removing input from justice transaction",
breachedOutput.witnessType,
breachedOutput.outpoint, breachInfo.chanPoint)
doneOutputs[s.index] = struct{}{}
}
// Filter the inputs for which we can no longer proceed.
var nextIndex int
for i := range inputs {
if _, ok := doneOutputs[i]; ok {
continue
}
inputs[nextIndex] = inputs[i]
nextIndex++
}
// Update our remaining set of outputs before continuing with
// another attempt at publication.
breachInfo.breachedOutputs = inputs[:nextIndex]
return totalFunds, revokedFunds
}
// exactRetribution is a goroutine which is executed once a contract breach has
// been detected by a breachObserver. This function is responsible for
// punishing a counterparty for violating the channel contract by sweeping ALL
// the lingering funds within the channel into the daemon's wallet.
//
// NOTE: This MUST be run as a goroutine.
//
//nolint:funlen
func (b *BreachArbitrator) exactRetribution(
confChan *chainntnfs.ConfirmationEvent, breachInfo *retributionInfo) {
defer b.wg.Done()
// TODO(roasbeef): state needs to be checkpointed here
select {
case _, ok := <-confChan.Confirmed:
// If the second value is !ok, then the channel has been closed
// signifying a daemon shutdown, so we exit.
if !ok {
return
}
// Otherwise, if this is a real confirmation notification, then
// we fall through to complete our duty.
case <-b.quit:
return
}
brarLog.Debugf("Breach transaction %v has been confirmed, sweeping "+
"revoked funds", breachInfo.commitHash)
// We may have to wait for some of the HTLC outputs to be spent to the
// second level before broadcasting the justice tx. We'll store the
// SpendEvents between each attempt to not re-register unnecessarily.
spendNtfns := make(map[wire.OutPoint]*chainntnfs.SpendEvent)
// Compute both the total value of funds being swept and the
// amount of funds that were revoked from the counter party.
var totalFunds, revokedFunds btcutil.Amount
justiceTxBroadcast:
// With the breach transaction confirmed, we now create the
// justice tx which will claim ALL the funds within the
// channel.
justiceTxs, err := b.createJusticeTx(breachInfo.breachedOutputs)
if err != nil {
brarLog.Errorf("Unable to create justice tx: %v", err)
return
}
finalTx := justiceTxs.spendAll
brarLog.Debugf("Broadcasting justice tx: %v", lnutils.SpewLogClosure(
finalTx))
// As we're about to broadcast our breach transaction, we'll notify the
// aux sweeper of our broadcast attempt first.
err = fn.MapOptionZ(b.cfg.AuxSweeper, func(aux sweep.AuxSweeper) error {
bumpReq := sweep.BumpRequest{
Inputs: finalTx.inputs,
DeliveryAddress: finalTx.sweepAddr,
ExtraTxOut: finalTx.extraTxOut,
}
return aux.NotifyBroadcast(
&bumpReq, finalTx.justiceTx, finalTx.fee, nil,
)
})
if err != nil {
brarLog.Errorf("unable to notify broadcast: %w", err)
return
}
// We'll now attempt to broadcast the transaction which finalized the
// channel's retribution against the cheating counter party.
label := labels.MakeLabel(labels.LabelTypeJusticeTransaction, nil)
err = b.cfg.PublishTransaction(finalTx.justiceTx, label)
if err != nil {
brarLog.Errorf("Unable to broadcast justice tx: %v", err)
}
// Regardless of whether publication succeeded, we now wait for any of
// the inputs to be spent. If any input got spent by the remote, we
// must recreate our justice transaction.
var (
spendChan = make(chan []spend, 1)
errChan = make(chan error, 1)
wg sync.WaitGroup
)
wg.Add(1)
go func() {
defer wg.Done()
spends, err := b.waitForSpendEvent(breachInfo, spendNtfns)
if err != nil {
errChan <- err
return
}
spendChan <- spends
}()
// We'll also register for block notifications, such that in case our
// justice tx doesn't confirm within a reasonable timeframe, we can
// start to more aggressively sweep the time sensitive outputs.
newBlockChan, err := b.cfg.Notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
brarLog.Errorf("Unable to register for block notifications: %v",
err)
return
}
defer newBlockChan.Cancel()
Loop:
for {
select {
case spends := <-spendChan:
// Update the breach info with the new spends.
t, r := updateBreachInfo(breachInfo, spends)
totalFunds += t
revokedFunds += r
brarLog.Infof("%v spends from breach tx for "+
"ChannelPoint(%v) has been detected, %v "+
"revoked funds (%v total) have been claimed",
len(spends), breachInfo.chanPoint,
revokedFunds, totalFunds)
if len(breachInfo.breachedOutputs) == 0 {
brarLog.Infof("Justice for ChannelPoint(%v) "+
"has been served, %v revoked funds "+
"(%v total) have been claimed. No "+
"more outputs to sweep, marking fully "+
"resolved", breachInfo.chanPoint,
revokedFunds, totalFunds)
err = b.cleanupBreach(&breachInfo.chanPoint)
if err != nil {
brarLog.Errorf("Failed to cleanup "+
"breached ChannelPoint(%v): %v",
breachInfo.chanPoint, err)
}
// TODO(roasbeef): add peer to blacklist?
// TODO(roasbeef): close other active channels
// with offending peer
break Loop
}
brarLog.Infof("Attempting another justice tx "+
"with %d inputs",
len(breachInfo.breachedOutputs))
wg.Wait()
goto justiceTxBroadcast
// On every new block, we check whether we should republish the
// transactions.
case epoch, ok := <-newBlockChan.Epochs:
if !ok {
return
}
// If less than four blocks have passed since the
// breach confirmed, we'll continue waiting. It was
// published with a 2-block fee estimate, so it's not
// unexpected that four blocks without confirmation can
// pass.
splitHeight := breachInfo.breachHeight +
blocksPassedSplitPublish
if uint32(epoch.Height) < splitHeight {
continue Loop
}
brarLog.Warnf("Block height %v arrived without "+
"justice tx confirming (breached at "+
"height %v), splitting justice tx.",
epoch.Height, breachInfo.breachHeight)
// Otherwise we'll attempt to publish the two separate
// justice transactions that sweep the commitment
// outputs and the HTLC outputs separately. This is to
// mitigate the case where our "spend all" justice TX
// doesn't propagate because the HTLC outputs have been
// pinned by low fee HTLC txs.
label := labels.MakeLabel(
labels.LabelTypeJusticeTransaction, nil,
)
if justiceTxs.spendCommitOuts != nil {
tx := justiceTxs.spendCommitOuts
brarLog.Debugf("Broadcasting justice tx "+
"spending commitment outs: %v",
lnutils.SpewLogClosure(tx))
err = b.cfg.PublishTransaction(
tx.justiceTx, label,
)
if err != nil {
brarLog.Warnf("Unable to broadcast "+
"commit out spending justice "+
"tx: %v", err)
}
}
if justiceTxs.spendHTLCs != nil {
tx := justiceTxs.spendHTLCs
brarLog.Debugf("Broadcasting justice tx "+
"spending HTLC outs: %v",
lnutils.SpewLogClosure(tx))
err = b.cfg.PublishTransaction(
tx.justiceTx, label,
)
if err != nil {
brarLog.Warnf("Unable to broadcast "+
"HTLC out spending justice "+
"tx: %v", err)
}
}
for _, tx := range justiceTxs.spendSecondLevelHTLCs {
tx := tx
brarLog.Debugf("Broadcasting justice tx "+
"spending second-level HTLC output: %v",
lnutils.SpewLogClosure(tx))
err = b.cfg.PublishTransaction(
tx.justiceTx, label,
)
if err != nil {
brarLog.Warnf("Unable to broadcast "+
"second-level HTLC out "+
"spending justice tx: %v", err)
}
}
case err := <-errChan:
if err != errBrarShuttingDown {
brarLog.Errorf("error waiting for "+
"spend event: %v", err)
}
break Loop
case <-b.quit:
break Loop
}
}
// Wait for our go routine to exit.
wg.Wait()
}
// cleanupBreach marks the given channel point as fully resolved and removes the
// retribution for that channel from the retribution store.
func (b *BreachArbitrator) cleanupBreach(chanPoint *wire.OutPoint) error {
// With the channel closed, mark it in the database as such.
err := b.cfg.DB.MarkChanFullyClosed(chanPoint)
if err != nil {
return fmt.Errorf("unable to mark chan as closed: %w", err)
}
// Justice has been carried out; we can safely delete the retribution
// info from the database.
err = b.cfg.Store.Remove(chanPoint)
if err != nil {
return fmt.Errorf("unable to remove retribution from db: %w",
err)
}
// This is after the Remove call so that the chan passed in via
// SubscribeBreachComplete is always notified, no matter when it is
// called. Otherwise, if notifyBreachComplete was before Remove, a
// very rare edge case could occur in which SubscribeBreachComplete
// is called after notifyBreachComplete and before Remove, meaning the
// caller would never be notified.
b.notifyBreachComplete(chanPoint)
return nil
}
// handleBreachHandoff handles a new breach event by writing it to disk, then
// notifies the BreachArbitrator contract observer goroutine that a channel's
// contract has been breached by the prior counterparty. Once notified the
// BreachArbitrator will attempt to sweep ALL funds within the channel using the
// information provided within the BreachRetribution generated due to the
// breach of channel contract. The funds will be swept only after the breaching
// transaction receives the necessary number of confirmations.
//
// NOTE: This MUST be run as a goroutine.
func (b *BreachArbitrator) handleBreachHandoff(
breachEvent *ContractBreachEvent) {
defer b.wg.Done()
chanPoint := breachEvent.ChanPoint
brarLog.Debugf("Handling breach handoff for ChannelPoint(%v)",
chanPoint)
// A read from this channel indicates that a channel breach has been
// detected! So we notify the main coordination goroutine with the
// information needed to bring the counterparty to justice.
breachInfo := breachEvent.BreachRetribution
brarLog.Warnf("REVOKED STATE #%v FOR ChannelPoint(%v) "+
"broadcast, REMOTE PEER IS DOING SOMETHING "+
"SKETCHY!!!", breachInfo.RevokedStateNum,
chanPoint)
// Immediately notify the HTLC switch that this link has been
// breached in order to ensure any incoming or outgoing
// multi-hop HTLCs aren't sent over this link, nor any other
// links associated with this peer.
b.cfg.CloseLink(&chanPoint, CloseBreach)
// TODO(roasbeef): need to handle case of remote broadcast
// mid-local initiated state-transition, possible
// false-positive?
// Acquire the mutex to ensure consistency between the call to
// IsBreached and Add below.
b.Lock()
// We first check if this breach info is already added to the
// retribution store.
breached, err := b.cfg.Store.IsBreached(&chanPoint)
if err != nil {
b.Unlock()
brarLog.Errorf("Unable to check breach info in DB: %v", err)
// Notify about the failed lookup and return.
breachEvent.ProcessACK(err)
return
}
// If this channel is already marked as breached in the retribution
// store, we already have handled the handoff for this breach. In this
// case we can safely ACK the handoff, and return.
if breached {
b.Unlock()
breachEvent.ProcessACK(nil)
return
}
// Using the breach information provided by the wallet and the
// channel snapshot, construct the retribution information that
// will be persisted to disk.
retInfo := newRetributionInfo(&chanPoint, breachInfo)
// Persist the pending retribution state to disk.
err = b.cfg.Store.Add(retInfo)
b.Unlock()
if err != nil {
brarLog.Errorf("Unable to persist retribution "+
"info to db: %v", err)
}
// Now that the breach has been persisted, try to send an
// acknowledgment back to the close observer with the error. If
// the ack is successful, the close observer will mark the
// channel as pending-closed in the channeldb.
breachEvent.ProcessACK(err)
// Bail if we failed to persist retribution info.
if err != nil {
return
}
// Now that a new channel contract has been added to the retribution
// store, we first register for a notification to be dispatched once
// the breach transaction (the revoked commitment transaction) has been
// confirmed in the chain to ensure we're not dealing with a moving
// target.
breachTXID := &retInfo.commitHash
breachScript := retInfo.breachedOutputs[0].signDesc.Output.PkScript
cfChan, err := b.cfg.Notifier.RegisterConfirmationsNtfn(
breachTXID, breachScript, 1, retInfo.breachHeight,
)
if err != nil {
brarLog.Errorf("Unable to register for conf updates for "+
"txid: %v, err: %v", breachTXID, err)
return
}
brarLog.Warnf("A channel has been breached with txid: %v. Waiting "+
"for confirmation, then justice will be served!", breachTXID)
// With the retribution state persisted, channel close persisted, and
// notification registered, we launch a new goroutine which will
// finalize the channel retribution after the breach transaction has
// been confirmed.
b.wg.Add(1)
go b.exactRetribution(cfChan, retInfo)
}
// breachedOutput contains all the information needed to sweep a breached
// output. A breached output is an output that we are now entitled to due to a
// revoked commitment transaction being broadcast.
type breachedOutput struct {
amt btcutil.Amount
outpoint wire.OutPoint
witnessType input.StandardWitnessType
signDesc input.SignDescriptor
confHeight uint32
secondLevelWitnessScript []byte
secondLevelTapTweak [32]byte
witnessFunc input.WitnessGenerator
resolutionBlob fn.Option[tlv.Blob]
// TODO(roasbeef): function opt and hook into brar
}
// makeBreachedOutput assembles a new breachedOutput that can be used by the
// breach arbiter to construct a justice or sweep transaction.
func makeBreachedOutput(outpoint *wire.OutPoint,
witnessType input.StandardWitnessType, secondLevelScript []byte,
signDescriptor *input.SignDescriptor, confHeight uint32,
resolutionBlob fn.Option[tlv.Blob]) breachedOutput {
amount := signDescriptor.Output.Value
return breachedOutput{
amt: btcutil.Amount(amount),
outpoint: *outpoint,
secondLevelWitnessScript: secondLevelScript,
witnessType: witnessType,
signDesc: *signDescriptor,
confHeight: confHeight,
resolutionBlob: resolutionBlob,
}
}
// Amount returns the number of satoshis contained in the breached output.
func (bo *breachedOutput) Amount() btcutil.Amount {
return bo.amt
}
// OutPoint returns the breached output's identifier that is to be included as a
// transaction input.
func (bo *breachedOutput) OutPoint() wire.OutPoint {
return bo.outpoint
}
// RequiredTxOut returns a non-nil TxOut if the input commits to a certain
// transaction output. This is used in the SINGLE|ANYONECANPAY case to make
// sure any presigned input is still valid by including the output.
func (bo *breachedOutput) RequiredTxOut() *wire.TxOut {
return nil
}
// RequiredLockTime returns whether this input commits to a tx locktime that
// must be used in the transaction including it.
func (bo *breachedOutput) RequiredLockTime() (uint32, bool) {
return 0, false
}
// WitnessType returns the type of witness that must be generated to spend the
// breached output.
func (bo *breachedOutput) WitnessType() input.WitnessType {
return bo.witnessType
}
// SignDesc returns the breached output's SignDescriptor, which is used during
// signing to compute the witness.
func (bo *breachedOutput) SignDesc() *input.SignDescriptor {
return &bo.signDesc
}
// Preimage returns the preimage that was used to create the breached output.
func (bo *breachedOutput) Preimage() fn.Option[lntypes.Preimage] {
return fn.None[lntypes.Preimage]()
}
// CraftInputScript computes a valid witness that allows us to spend from the
// breached output. It does so by first generating and memoizing the witness
// generation function, which is parameterized primarily by the witness type and
// sign descriptor. The method then returns the witness computed by invoking
// this function on the first and subsequent calls.
func (bo *breachedOutput) CraftInputScript(signer input.Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher,
txinIdx int) (*input.Script, error) {
// First, we ensure that the witness generation function has been
// initialized for this breached output.
signDesc := bo.SignDesc()
signDesc.PrevOutputFetcher = prevOutputFetcher
bo.witnessFunc = bo.witnessType.WitnessGenerator(signer, signDesc)
// Now that we have ensured that the witness generation function has
// been initialized, we can proceed to execute it and generate the
// witness for this particular breached output.
return bo.witnessFunc(txn, hashCache, txinIdx)
}
// BlocksToMaturity returns the relative timelock, as a number of blocks, that
// must be built on top of the confirmation height before the output can be
// spent.
func (bo *breachedOutput) BlocksToMaturity() uint32 {
// If the output is a to_remote output we can claim, and it's of the
// confirmed type (or is a taproot channel that always has the CSV 1),
// we must wait one block before claiming it.
switch bo.witnessType {
case input.CommitmentToRemoteConfirmed, input.TaprootRemoteCommitSpend:
return 1
}
// All other breached outputs have no CSV delay.
return 0
}
// HeightHint returns the minimum height at which a confirmed spending tx can
// occur.
func (bo *breachedOutput) HeightHint() uint32 {
return bo.confHeight
}
// UnconfParent returns information about a possibly unconfirmed parent tx.
func (bo *breachedOutput) UnconfParent() *input.TxInfo {
return nil
}
// ResolutionBlob returns a special opaque blob to be used to sweep/resolve this
// input.
func (bo *breachedOutput) ResolutionBlob() fn.Option[tlv.Blob] {
return bo.resolutionBlob
}
// Add compile-time constraint ensuring breachedOutput implements the Input
// interface.
var _ input.Input = (*breachedOutput)(nil)
// retributionInfo encapsulates all the data needed to sweep all the contested
// funds within a channel whose contract has been breached by the prior
// counterparty. This struct is used to create the justice transaction which
// spends all outputs of the commitment transaction into an output controlled
// by the wallet.
type retributionInfo struct {
commitHash chainhash.Hash
chanPoint wire.OutPoint
chainHash chainhash.Hash
breachHeight uint32
breachedOutputs []breachedOutput
}
// newRetributionInfo constructs a retributionInfo containing all the
// information required by the breach arbiter to recover funds from breached
// channels. The information is primarily populated using the BreachRetribution
// delivered by the wallet when it detects a channel breach.
func newRetributionInfo(chanPoint *wire.OutPoint,
breachInfo *lnwallet.BreachRetribution) *retributionInfo {
// Determine the number of second layer HTLCs we will attempt to sweep.
nHtlcs := len(breachInfo.HtlcRetributions)
// Initialize a slice to hold the outputs we will attempt to sweep. The
// maximum capacity of the slice is set to 2+nHtlcs to handle the case
// where the local, remote, and all HTLCs are not dust outputs. All
// HTLC outputs provided by the wallet are guaranteed to be non-dust,
// though the commitment outputs are conditionally added depending on
// the nil-ness of their sign descriptors.
breachedOutputs := make([]breachedOutput, 0, nHtlcs+2)
isTaproot := func() bool {
if breachInfo.LocalOutputSignDesc != nil {
return txscript.IsPayToTaproot(
breachInfo.LocalOutputSignDesc.Output.PkScript,
)
}
return txscript.IsPayToTaproot(
breachInfo.RemoteOutputSignDesc.Output.PkScript,
)
}()
// First, record the breach information for the local channel point if
// it is not considered dust, which is signaled by a non-nil sign
// descriptor. Here we use CommitmentNoDelay (or
// CommitmentNoDelayTweakless for newer commitments) since this output
// belongs to us and has no time-based constraints on spending. For
// taproot channels, this is a normal spend from our output on the
// commitment of the remote party.
if breachInfo.LocalOutputSignDesc != nil {
var witnessType input.StandardWitnessType
switch {
case isTaproot:
witnessType = input.TaprootRemoteCommitSpend
case !isTaproot &&
breachInfo.LocalOutputSignDesc.SingleTweak == nil:
witnessType = input.CommitSpendNoDelayTweakless
case !isTaproot:
witnessType = input.CommitmentNoDelay
}
// If the local delay is non-zero, it means this output is of
// the confirmed to_remote type.
if !isTaproot && breachInfo.LocalDelay != 0 {
witnessType = input.CommitmentToRemoteConfirmed
}
localOutput := makeBreachedOutput(
&breachInfo.LocalOutpoint,
witnessType,
// No second level script as this is a commitment
// output.
nil,
breachInfo.LocalOutputSignDesc,
breachInfo.BreachHeight,
breachInfo.LocalResolutionBlob,
)
breachedOutputs = append(breachedOutputs, localOutput)
}
// Second, record the same information regarding the remote outpoint,
// again if it is not dust, which belongs to the party who tried to
// steal our money! Here we set witnessType of the breachedOutput to
// CommitmentRevoke, since we will be using a revoke key, withdrawing
// the funds from the commitment transaction immediately.
if breachInfo.RemoteOutputSignDesc != nil {
var witType input.StandardWitnessType
if isTaproot {
witType = input.TaprootCommitmentRevoke
} else {
witType = input.CommitmentRevoke
}
remoteOutput := makeBreachedOutput(
&breachInfo.RemoteOutpoint,
witType,
// No second level script as this is a commitment
// output.
nil,
breachInfo.RemoteOutputSignDesc,
breachInfo.BreachHeight,
breachInfo.RemoteResolutionBlob,
)
breachedOutputs = append(breachedOutputs, remoteOutput)
}
// Lastly, for each of the breached HTLC outputs, record each as a
// breached output with the appropriate witness type based on its
// directionality. All HTLC outputs provided by the wallet are assumed
// to be non-dust.
for i, breachedHtlc := range breachInfo.HtlcRetributions {
// Using the breachedHtlc's incoming flag, determine the
// appropriate witness type that needs to be generated in order
// to sweep the HTLC output.
var htlcWitnessType input.StandardWitnessType
switch {
case isTaproot && breachedHtlc.IsIncoming:
htlcWitnessType = input.TaprootHtlcAcceptedRevoke
case isTaproot && !breachedHtlc.IsIncoming:
htlcWitnessType = input.TaprootHtlcOfferedRevoke
case !isTaproot && breachedHtlc.IsIncoming:
htlcWitnessType = input.HtlcAcceptedRevoke
case !isTaproot && !breachedHtlc.IsIncoming:
htlcWitnessType = input.HtlcOfferedRevoke
}
htlcOutput := makeBreachedOutput(
&breachInfo.HtlcRetributions[i].OutPoint,
htlcWitnessType,
breachInfo.HtlcRetributions[i].SecondLevelWitnessScript,
&breachInfo.HtlcRetributions[i].SignDesc,
breachInfo.BreachHeight,
breachInfo.HtlcRetributions[i].ResolutionBlob,
)
// For taproot outputs, we also need to hold onto the second
// level tap tweak.
//nolint:ll
htlcOutput.secondLevelTapTweak = breachedHtlc.SecondLevelTapTweak
breachedOutputs = append(breachedOutputs, htlcOutput)
}
return &retributionInfo{
commitHash: breachInfo.BreachTxHash,
chainHash: breachInfo.ChainHash,
chanPoint: *chanPoint,
breachedOutputs: breachedOutputs,
breachHeight: breachInfo.BreachHeight,
}
}
// justiceTxVariants is a struct that holds transactions which exact "justice"
// by sweeping ALL the funds within the channel which we are now entitled to
// due to a breach of the channel's contract by the counterparty. There are
// four variants of justice transactions:
//
// 1. The "normal" justice tx that spends all breached outputs.
// 2. A tx that spends only the breached to_local output and to_remote output
// (can be nil if none of these exist).
// 3. A tx that spends all the breached commitment level HTLC outputs (can be
// nil if none of these exist or if all have been taken to the second level).
// 4. A set of txs that spend all the second-level HTLC outputs (can be empty if
// no HTLC second-level txs have been confirmed).
//
// The reason we create these four variants is that in certain cases (like
// with the anchor output HTLC malleability), the channel counterparty can pin
// the HTLC outputs with low fee children, hindering our normal justice tx that
// attempts to spend these outputs from propagating. In this case we want to
// spend the to_local output and commitment level HTLC outputs separately,
// before the CSV locks expire.
type justiceTxVariants struct {
spendAll *justiceTxCtx
spendCommitOuts *justiceTxCtx
spendHTLCs *justiceTxCtx
spendSecondLevelHTLCs []*justiceTxCtx
}
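// nonNilJusticeTxs is a hypothetical helper, included here purely as an
// illustration (it is not part of the production path) of how the
// variants are intended to be consumed: the aggregate tx is attempted
// first, with the split commitment/HTLC txs and the per-output
// second-level sweeps serving as fallbacks when the aggregate tx is
// pinned.
func (j *justiceTxVariants) nonNilJusticeTxs() []*justiceTxCtx {
	txs := make([]*justiceTxCtx, 0, 3+len(j.spendSecondLevelHTLCs))
	for _, txCtx := range []*justiceTxCtx{
		j.spendAll, j.spendCommitOuts, j.spendHTLCs,
	} {
		if txCtx != nil {
			txs = append(txs, txCtx)
		}
	}

	return append(txs, j.spendSecondLevelHTLCs...)
}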
// createJusticeTx creates transactions which exact "justice" by sweeping ALL
// the funds within the channel which we are now entitled to due to a breach of
// the channel's contract by the counterparty. This function returns *fully*
// signed transactions with the witness for each input fully in place.
func (b *BreachArbitrator) createJusticeTx(
breachedOutputs []breachedOutput) (*justiceTxVariants, error) {
var (
allInputs []input.Input
commitInputs []input.Input
htlcInputs []input.Input
secondLevelInputs []input.Input
)
for i := range breachedOutputs {
// Grab locally scoped reference to breached output.
inp := &breachedOutputs[i]
allInputs = append(allInputs, inp)
// Check if the input is from a commitment output, a commitment
// level HTLC output or a second level HTLC output.
switch inp.WitnessType() {
case input.HtlcAcceptedRevoke, input.HtlcOfferedRevoke,
input.TaprootHtlcAcceptedRevoke,
input.TaprootHtlcOfferedRevoke:
htlcInputs = append(htlcInputs, inp)
case input.HtlcSecondLevelRevoke,
input.TaprootHtlcSecondLevelRevoke:
secondLevelInputs = append(secondLevelInputs, inp)
default:
commitInputs = append(commitInputs, inp)
}
}
var (
txs = &justiceTxVariants{}
err error
)
// For each group of inputs, create a tx that spends them.
txs.spendAll, err = b.createSweepTx(allInputs...)
if err != nil {
return nil, err
}
txs.spendCommitOuts, err = b.createSweepTx(commitInputs...)
if err != nil {
brarLog.Errorf("could not create sweep tx for commitment "+
"outputs: %v", err)
}
txs.spendHTLCs, err = b.createSweepTx(htlcInputs...)
if err != nil {
brarLog.Errorf("could not create sweep tx for HTLC outputs: %v",
err)
}
// TODO(roasbeef): only register one of them?
secondLevelSweeps := make([]*justiceTxCtx, 0, len(secondLevelInputs))
for _, input := range secondLevelInputs {
sweepTx, err := b.createSweepTx(input)
if err != nil {
brarLog.Errorf("could not create sweep tx for "+
"second-level HTLC output: %v", err)
continue
}
secondLevelSweeps = append(secondLevelSweeps, sweepTx)
}
txs.spendSecondLevelHTLCs = secondLevelSweeps
return txs, nil
}
// justiceTxCtx contains the justice transaction along with other related meta
// data.
type justiceTxCtx struct {
justiceTx *wire.MsgTx
sweepAddr lnwallet.AddrWithKey
extraTxOut fn.Option[sweep.SweepOutput]
fee btcutil.Amount
inputs []input.Input
}
// createSweepTx creates a tx that sweeps the passed inputs back to our wallet.
func (b *BreachArbitrator) createSweepTx(
inputs ...input.Input) (*justiceTxCtx, error) {
if len(inputs) == 0 {
return nil, nil
}
// We will assemble the breached outputs into a slice of spendable
// outputs, while simultaneously computing the estimated weight of the
// transaction.
var (
spendableOutputs []input.Input
weightEstimate input.TxWeightEstimator
)
// Allocate enough space to potentially hold each of the breached
// outputs in the retribution info.
spendableOutputs = make([]input.Input, 0, len(inputs))
// The justice transaction we construct will be a segwit transaction
// that pays to a p2tr output. Components such as the version,
// nLockTime, and output are already included in the TxWeightEstimator.
weightEstimate.AddP2TROutput()
// If any of our inputs has a resolution blob, then we'll add another
// P2TR _output_, since we'll want to separate the custom channel
// outputs from the regular, BTC only outputs. So we only need one such
// output, which'll carry the custom channel "valuables" from both the
// breached commitment and HTLC outputs.
hasBlobs := fn.Any(inputs, func(i input.Input) bool {
return i.ResolutionBlob().IsSome()
})
if hasBlobs {
weightEstimate.AddP2TROutput()
}
// Next, we iterate over the breached outputs contained in the
// retribution info. For each, we switch over the witness type such
// that we contribute the appropriate weight for each input and
// witness, finally adding to our list of spendable outputs.
for i := range inputs {
// Grab locally scoped reference to breached output.
inp := inputs[i]
// First, determine the appropriate estimated witness weight
// for the given witness type of this breached output. If the
// witness weight cannot be estimated, we will omit it from the
// transaction.
witnessWeight, _, err := inp.WitnessType().SizeUpperBound()
if err != nil {
brarLog.Warnf("could not determine witness weight "+
"for breached output in retribution info: %v",
err)
continue
}
weightEstimate.AddWitnessInput(witnessWeight)
// Finally, append this input to our list of spendable outputs.
spendableOutputs = append(spendableOutputs, inp)
}
txWeight := weightEstimate.Weight()
return b.sweepSpendableOutputsTxn(txWeight, spendableOutputs...)
}
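// estimateJusticeTxWeight is a minimal sketch that mirrors the weight
// accounting in createSweepTx above (illustration only, not used by the
// production path): one P2TR sweep output, an optional second P2TR
// output when any input carries a resolution blob, and one witness input
// per breached output.
func estimateJusticeTxWeight(inputs []input.Input) (lntypes.WeightUnit,
	error) {

	var estimator input.TxWeightEstimator
	estimator.AddP2TROutput()

	hasBlobs := fn.Any(inputs, func(i input.Input) bool {
		return i.ResolutionBlob().IsSome()
	})
	if hasBlobs {
		estimator.AddP2TROutput()
	}

	for _, inp := range inputs {
		witnessWeight, _, err := inp.WitnessType().SizeUpperBound()
		if err != nil {
			return 0, err
		}
		estimator.AddWitnessInput(witnessWeight)
	}

	return estimator.Weight(), nil
}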
// sweepSpendableOutputsTxn creates a signed transaction from a sequence of
// spendable outputs by sweeping the funds into a single p2tr output.
func (b *BreachArbitrator) sweepSpendableOutputsTxn(txWeight lntypes.WeightUnit,
inputs ...input.Input) (*justiceTxCtx, error) {
// First, we obtain a new public key script from the wallet which we'll
// sweep the funds to.
// TODO(roasbeef): possibly create many outputs to minimize change in
// the future?
pkScript, err := b.cfg.GenSweepScript().Unpack()
if err != nil {
return nil, err
}
// Compute the total amount contained in the inputs.
var totalAmt btcutil.Amount
for _, inp := range inputs {
totalAmt += btcutil.Amount(inp.SignDesc().Output.Value)
}
// We'll actually attempt to target inclusion within the next two
// blocks as we'd like to sweep these funds back into our wallet ASAP.
feePerKw, err := b.cfg.Estimator.EstimateFeePerKW(justiceTxConfTarget)
if err != nil {
return nil, err
}
txFee := feePerKw.FeeForWeight(txWeight)
// At this point, we'll check to see if we have any extra outputs to
// add from the aux sweeper.
extraChangeOut := fn.MapOptionZ(
b.cfg.AuxSweeper,
func(aux sweep.AuxSweeper) fn.Result[sweep.SweepOutput] {
return aux.DeriveSweepAddr(inputs, pkScript)
},
)
if err := extraChangeOut.Err(); err != nil {
return nil, err
}
// TODO(roasbeef): already start to siphon their funds into fees
sweepAmt := int64(totalAmt - txFee)
// With the fee calculated, we can now create the transaction using the
// information gathered above and the provided retribution information.
txn := wire.NewMsgTx(2)
// First, we'll add the extra sweep output if it exists, subtracting the
// amount from the sweep amt.
if b.cfg.AuxSweeper.IsSome() {
extraChangeOut.WhenOk(func(o sweep.SweepOutput) {
sweepAmt -= o.Value
txn.AddTxOut(&o.TxOut)
})
}
// Next, we'll add the output to which our funds will be deposited.
txn.AddTxOut(&wire.TxOut{
PkScript: pkScript.DeliveryAddress,
Value: sweepAmt,
})
// TODO(roasbeef): add other output change modify sweep amt
// Next, we add all of the spendable outputs as inputs to the
// transaction.
for _, inp := range inputs {
txn.AddTxIn(&wire.TxIn{
PreviousOutPoint: inp.OutPoint(),
Sequence: inp.BlocksToMaturity(),
})
}
// Before signing the transaction, check to ensure that it meets some
// basic validity requirements.
btx := btcutil.NewTx(txn)
if err := blockchain.CheckTransactionSanity(btx); err != nil {
return nil, err
}
// Create a sighash cache to improve the performance of hashing and
// signing SigHashAll inputs.
prevOutputFetcher, err := input.MultiPrevOutFetcher(inputs)
if err != nil {
return nil, err
}
hashCache := txscript.NewTxSigHashes(txn, prevOutputFetcher)
// Create a closure that encapsulates the process of initializing a
// particular output's witness generation function, computing the
// witness, and attaching it to the transaction. This function accepts
// an integer index representing the intended txin index, and the
// breached output from which it will spend.
addWitness := func(idx int, so input.Input) error {
// First, we construct a valid witness for this outpoint and
// transaction using the SpendableOutput's witness generation
// function.
inputScript, err := so.CraftInputScript(
b.cfg.Signer, txn, hashCache, prevOutputFetcher, idx,
)
if err != nil {
return err
}
// Then, we add the witness to the transaction at the
// appropriate txin index.
txn.TxIn[idx].Witness = inputScript.Witness
return nil
}
// Finally, generate a witness for each output and attach it to the
// transaction.
for i, inp := range inputs {
if err := addWitness(i, inp); err != nil {
return nil, err
}
}
return &justiceTxCtx{
justiceTx: txn,
sweepAddr: pkScript,
extraTxOut: extraChangeOut.OkToSome(),
fee: txFee,
inputs: inputs,
}, nil
}
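// justiceSweepAmount is an illustrative sketch of the fee arithmetic
// used above (not part of the production path): the sweep amount is the
// sum of the input values minus the fee for the estimated weight. The
// chainfee type is assumed to come from lnd's lnwallet/chainfee package,
// matching the estimator used in sweepSpendableOutputsTxn.
func justiceSweepAmount(inputs []input.Input,
	feePerKw chainfee.SatPerKWeight,
	txWeight lntypes.WeightUnit) btcutil.Amount {

	var totalAmt btcutil.Amount
	for _, inp := range inputs {
		totalAmt += btcutil.Amount(inp.SignDesc().Output.Value)
	}

	return totalAmt - feePerKw.FeeForWeight(txWeight)
}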
// RetributionStore handles persistence of retribution states to disk and is
// backed by a boltdb bucket. The primary responsibility of the retribution
// store is to ensure that we can recover from a restart in the middle of a
// breached contract retribution.
type RetributionStore struct {
db kvdb.Backend
}
// NewRetributionStore creates a new instance of a RetributionStore.
func NewRetributionStore(db kvdb.Backend) *RetributionStore {
return &RetributionStore{
db: db,
}
}
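// exampleRetributionStoreUsage is a hypothetical sketch (illustration
// only) of the intended lifecycle: persist the retribution as soon as a
// breach is detected, then consult IsBreached before ever re-adding a
// link for that channel.
func exampleRetributionStoreUsage(db kvdb.Backend,
	retInfo *retributionInfo) (bool, error) {

	store := NewRetributionStore(db)
	if err := store.Add(retInfo); err != nil {
		return false, err
	}

	return store.IsBreached(&retInfo.chanPoint)
}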
// taprootBriefcaseFromRetInfo creates a taprootBriefcase from a retribution
// info struct. This stores all the tap tweak information we need in order to
// be able to handle breaches after a restart.
func taprootBriefcaseFromRetInfo(retInfo *retributionInfo) *taprootBriefcase {
tapCase := newTaprootBriefcase()
for _, bo := range retInfo.breachedOutputs {
switch bo.WitnessType() {
// For spending from our commitment output on the remote
// commitment, we'll need to stash the control block.
case input.TaprootRemoteCommitSpend:
//nolint:ll
tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock = bo.signDesc.ControlBlock
bo.resolutionBlob.WhenSome(func(blob tlv.Blob) {
tapCase.SettledCommitBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType2](
blob,
),
)
})
// To spend the revoked output again, we'll store the same
// control block value as above, but in a different place.
case input.TaprootCommitmentRevoke:
//nolint:ll
tapCase.CtrlBlocks.Val.RevokeSweepCtrlBlock = bo.signDesc.ControlBlock
bo.resolutionBlob.WhenSome(func(blob tlv.Blob) {
tapCase.BreachedCommitBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType3](
blob,
),
)
})
// For spending the HTLC outputs, we'll store the first and
// second level tweak values.
case input.TaprootHtlcAcceptedRevoke:
fallthrough
case input.TaprootHtlcOfferedRevoke:
resID := newResolverID(bo.OutPoint())
var firstLevelTweak [32]byte
copy(firstLevelTweak[:], bo.signDesc.TapTweak)
secondLevelTweak := bo.secondLevelTapTweak
//nolint:ll
tapCase.TapTweaks.Val.BreachedHtlcTweaks[resID] = firstLevelTweak
//nolint:ll
tapCase.TapTweaks.Val.BreachedSecondLevelHltcTweaks[resID] = secondLevelTweak
}
}
return tapCase
}
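// roundTripTaprootRetInfo is an illustrative sketch (not part of the
// production path) of how this helper pairs with applyTaprootRetInfo
// below across a restart: the briefcase extracted before persisting must
// re-apply cleanly onto a freshly decoded retributionInfo.
func roundTripTaprootRetInfo(retInfo *retributionInfo) error {
	tapCase := taprootBriefcaseFromRetInfo(retInfo)

	var b bytes.Buffer
	if err := tapCase.Encode(&b); err != nil {
		return err
	}

	var decoded taprootBriefcase
	if err := decoded.Decode(bytes.NewReader(b.Bytes())); err != nil {
		return err
	}

	return applyTaprootRetInfo(&decoded, retInfo)
}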
// applyTaprootRetInfo attaches the taproot specific information in the tapCase
// to the passed retInfo struct.
func applyTaprootRetInfo(tapCase *taprootBriefcase,
retInfo *retributionInfo) error {
for i := range retInfo.breachedOutputs {
bo := retInfo.breachedOutputs[i]
switch bo.WitnessType() {
// For spending from our commitment output on the remote
// commitment, we'll apply the control block.
case input.TaprootRemoteCommitSpend:
//nolint:ll
bo.signDesc.ControlBlock = tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock
tapCase.SettledCommitBlob.WhenSomeV(
func(blob tlv.Blob) {
bo.resolutionBlob = fn.Some(blob)
},
)
// To spend the revoked output again, we'll apply the same
// control block value as above, but to a different place.
case input.TaprootCommitmentRevoke:
//nolint:ll
bo.signDesc.ControlBlock = tapCase.CtrlBlocks.Val.RevokeSweepCtrlBlock
tapCase.BreachedCommitBlob.WhenSomeV(
func(blob tlv.Blob) {
bo.resolutionBlob = fn.Some(blob)
},
)
// For spending the HTLC outputs, we'll apply the first and
// second level tweak values.
case input.TaprootHtlcAcceptedRevoke:
fallthrough
case input.TaprootHtlcOfferedRevoke:
resID := newResolverID(bo.OutPoint())
//nolint:ll
tap1, ok := tapCase.TapTweaks.Val.BreachedHtlcTweaks[resID]
if !ok {
return fmt.Errorf("unable to find taproot "+
"tweak for: %v", bo.OutPoint())
}
bo.signDesc.TapTweak = tap1[:]
//nolint:ll
tap2, ok := tapCase.TapTweaks.Val.BreachedSecondLevelHltcTweaks[resID]
if !ok {
return fmt.Errorf("unable to find taproot "+
"tweak for: %v", bo.OutPoint())
}
bo.secondLevelTapTweak = tap2
}
retInfo.breachedOutputs[i] = bo
}
return nil
}
// Add adds a retribution state to the RetributionStore, which is then persisted
// to disk.
func (rs *RetributionStore) Add(ret *retributionInfo) error {
return kvdb.Update(rs.db, func(tx kvdb.RwTx) error {
// If this is our first contract breach, the retributionBucket
// won't exist, in which case, we just create a new bucket.
retBucket, err := tx.CreateTopLevelBucket(retributionBucket)
if err != nil {
return err
}
tapRetBucket, err := tx.CreateTopLevelBucket(
taprootRetributionBucket,
)
if err != nil {
return err
}
var outBuf bytes.Buffer
err = graphdb.WriteOutpoint(&outBuf, &ret.chanPoint)
if err != nil {
return err
}
var retBuf bytes.Buffer
if err := ret.Encode(&retBuf); err != nil {
return err
}
err = retBucket.Put(outBuf.Bytes(), retBuf.Bytes())
if err != nil {
return err
}
// If this isn't a taproot channel, then we can exit early here
// as there's no extra data to write.
switch {
case len(ret.breachedOutputs) == 0:
return nil
case !txscript.IsPayToTaproot(
ret.breachedOutputs[0].signDesc.Output.PkScript,
):
return nil
}
// We'll also map the ret info into the taproot storage
// structure we need for taproot channels.
var b bytes.Buffer
tapRetcase := taprootBriefcaseFromRetInfo(ret)
if err := tapRetcase.Encode(&b); err != nil {
return err
}
return tapRetBucket.Put(outBuf.Bytes(), b.Bytes())
}, func() {})
}
// IsBreached queries the retribution store to discern if this channel was
// previously breached. This is used when connecting to a peer to determine if
// it is safe to add a link to the htlcswitch, as we should never add a channel
// that has already been breached.
func (rs *RetributionStore) IsBreached(chanPoint *wire.OutPoint) (bool, error) {
var found bool
err := kvdb.View(rs.db, func(tx kvdb.RTx) error {
retBucket := tx.ReadBucket(retributionBucket)
if retBucket == nil {
return nil
}
var chanBuf bytes.Buffer
err := graphdb.WriteOutpoint(&chanBuf, chanPoint)
if err != nil {
return err
}
retInfo := retBucket.Get(chanBuf.Bytes())
if retInfo != nil {
found = true
}
return nil
}, func() {
found = false
})
return found, err
}
// Remove removes a retribution state and finalized justice transaction by
// channel point from the retribution store.
func (rs *RetributionStore) Remove(chanPoint *wire.OutPoint) error {
return kvdb.Update(rs.db, func(tx kvdb.RwTx) error {
retBucket := tx.ReadWriteBucket(retributionBucket)
tapRetBucket, err := tx.CreateTopLevelBucket(
taprootRetributionBucket,
)
if err != nil {
return err
}
// We return an error if the bucket is not already created,
// since normal operation of the breach arbiter should never
// try to remove a finalized retribution state that is not
// already stored in the db.
if retBucket == nil {
return errors.New("unable to remove retribution " +
"because the retribution bucket doesn't exist")
}
// Serialize the channel point we are intending to remove.
var chanBuf bytes.Buffer
err = graphdb.WriteOutpoint(&chanBuf, chanPoint)
if err != nil {
return err
}
chanBytes := chanBuf.Bytes()
// Remove the persisted retribution info and finalized justice
// transaction.
if err := retBucket.Delete(chanBytes); err != nil {
return err
}
return tapRetBucket.Delete(chanBytes)
}, func() {})
}
// ForAll iterates through all stored retributions and executes the passed
// callback function on each retribution.
func (rs *RetributionStore) ForAll(cb func(*retributionInfo) error,
reset func()) error {
return kvdb.View(rs.db, func(tx kvdb.RTx) error {
// If the bucket does not exist, then there are no pending
// retributions.
retBucket := tx.ReadBucket(retributionBucket)
if retBucket == nil {
return nil
}
tapRetBucket := tx.ReadBucket(
taprootRetributionBucket,
)
// Otherwise, we fetch each serialized retribution info,
// deserialize it, and execute the passed in callback function
// on it.
return retBucket.ForEach(func(k, retBytes []byte) error {
ret := &retributionInfo{}
err := ret.Decode(bytes.NewBuffer(retBytes))
if err != nil {
return err
}
tapInfoBytes := tapRetBucket.Get(k)
if tapInfoBytes != nil {
var tapCase taprootBriefcase
err := tapCase.Decode(
bytes.NewReader(tapInfoBytes),
)
if err != nil {
return err
}
err = applyTaprootRetInfo(&tapCase, ret)
if err != nil {
return err
}
}
return cb(ret)
})
}, reset)
}
// Encode serializes the retribution into the passed byte stream.
func (ret *retributionInfo) Encode(w io.Writer) error {
var scratch [4]byte
if _, err := w.Write(ret.commitHash[:]); err != nil {
return err
}
if err := graphdb.WriteOutpoint(w, &ret.chanPoint); err != nil {
return err
}
if _, err := w.Write(ret.chainHash[:]); err != nil {
return err
}
binary.BigEndian.PutUint32(scratch[:], ret.breachHeight)
if _, err := w.Write(scratch[:]); err != nil {
return err
}
nOutputs := len(ret.breachedOutputs)
if err := wire.WriteVarInt(w, 0, uint64(nOutputs)); err != nil {
return err
}
for _, output := range ret.breachedOutputs {
if err := output.Encode(w); err != nil {
return err
}
}
return nil
}
// Decode deserializes a retribution from the passed byte stream.
func (ret *retributionInfo) Decode(r io.Reader) error {
var scratch [32]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return err
}
hash, err := chainhash.NewHash(scratch[:])
if err != nil {
return err
}
ret.commitHash = *hash
if err := graphdb.ReadOutpoint(r, &ret.chanPoint); err != nil {
return err
}
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return err
}
chainHash, err := chainhash.NewHash(scratch[:])
if err != nil {
return err
}
ret.chainHash = *chainHash
if _, err := io.ReadFull(r, scratch[:4]); err != nil {
return err
}
ret.breachHeight = binary.BigEndian.Uint32(scratch[:4])
nOutputsU64, err := wire.ReadVarInt(r, 0)
if err != nil {
return err
}
nOutputs := int(nOutputsU64)
ret.breachedOutputs = make([]breachedOutput, nOutputs)
for i := range ret.breachedOutputs {
if err := ret.breachedOutputs[i].Decode(r); err != nil {
return err
}
}
return nil
}
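// exampleRetributionRoundTrip is an illustrative sketch (not used by the
// production path) showing that the Encode/Decode pair above are
// inverses of each other.
func exampleRetributionRoundTrip(ret *retributionInfo) (*retributionInfo,
	error) {

	var b bytes.Buffer
	if err := ret.Encode(&b); err != nil {
		return nil, err
	}

	decoded := &retributionInfo{}
	if err := decoded.Decode(&b); err != nil {
		return nil, err
	}

	return decoded, nil
}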
// Encode serializes a breachedOutput into the passed byte stream.
func (bo *breachedOutput) Encode(w io.Writer) error {
var scratch [8]byte
binary.BigEndian.PutUint64(scratch[:8], uint64(bo.amt))
if _, err := w.Write(scratch[:8]); err != nil {
return err
}
if err := graphdb.WriteOutpoint(w, &bo.outpoint); err != nil {
return err
}
err := input.WriteSignDescriptor(w, &bo.signDesc)
if err != nil {
return err
}
err = wire.WriteVarBytes(w, 0, bo.secondLevelWitnessScript)
if err != nil {
return err
}
binary.BigEndian.PutUint16(scratch[:2], uint16(bo.witnessType))
if _, err := w.Write(scratch[:2]); err != nil {
return err
}
return nil
}
// Decode deserializes a breachedOutput from the passed byte stream.
func (bo *breachedOutput) Decode(r io.Reader) error {
var scratch [8]byte
if _, err := io.ReadFull(r, scratch[:8]); err != nil {
return err
}
bo.amt = btcutil.Amount(binary.BigEndian.Uint64(scratch[:8]))
if err := graphdb.ReadOutpoint(r, &bo.outpoint); err != nil {
return err
}
if err := input.ReadSignDescriptor(r, &bo.signDesc); err != nil {
return err
}
wScript, err := wire.ReadVarBytes(r, 0, 1000, "witness script")
if err != nil {
return err
}
bo.secondLevelWitnessScript = wScript
if _, err := io.ReadFull(r, scratch[:2]); err != nil {
return err
}
bo.witnessType = input.StandardWitnessType(
binary.BigEndian.Uint16(scratch[:2]),
)
return nil
}
package contractcourt
import (
"encoding/binary"
"fmt"
"io"
"github.com/lightningnetwork/lnd/channeldb"
)
// breachResolver is a resolver that will handle breached closes. In the
// future, this will likely take over the duties the current BreachArbitrator
// has.
type breachResolver struct {
// subscribed denotes whether or not the breach resolver has subscribed
// to the BreachArbitrator for breach resolution.
subscribed bool
// replyChan is closed when the breach arbiter has completed serving
// justice.
replyChan chan struct{}
contractResolverKit
}
// newBreachResolver instantiates a new breach resolver.
func newBreachResolver(resCfg ResolverConfig) *breachResolver {
r := &breachResolver{
contractResolverKit: *newContractResolverKit(resCfg),
replyChan: make(chan struct{}),
}
r.initLogger(fmt.Sprintf("%T(%v)", r, r.ChanPoint))
return r
}
// ResolverKey returns the unique identifier for this resolver.
func (b *breachResolver) ResolverKey() []byte {
key := newResolverID(b.ChanPoint)
return key[:]
}
// Resolve queries the BreachArbitrator to see if the justice transaction has
// been broadcast.
//
// NOTE: Part of the ContractResolver interface.
//
// TODO(yy): let sweeper handle the breach inputs.
func (b *breachResolver) Resolve() (ContractResolver, error) {
if !b.subscribed {
complete, err := b.SubscribeBreachComplete(
&b.ChanPoint, b.replyChan,
)
if err != nil {
return nil, err
}
// If the breach resolution process is already complete, then
// we can cleanup and checkpoint the resolved state.
if complete {
b.markResolved()
return nil, b.Checkpoint(b)
}
// Prevent duplicate subscriptions.
b.subscribed = true
}
select {
case <-b.replyChan:
// The replyChan has been closed, signalling that the breach
// has been fully resolved. Checkpoint the resolved state and
// exit.
b.markResolved()
return nil, b.Checkpoint(b)
case <-b.quit:
}
return nil, errResolverShuttingDown
}
// Stop signals the breachResolver to stop.
func (b *breachResolver) Stop() {
b.log.Debugf("stopping...")
close(b.quit)
}
// SupplementState adds additional state to the breachResolver.
func (b *breachResolver) SupplementState(_ *channeldb.OpenChannel) {
}
// Encode encodes the breachResolver to the passed writer.
func (b *breachResolver) Encode(w io.Writer) error {
return binary.Write(w, endian, b.IsResolved())
}
// newBreachResolverFromReader attempts to decode an encoded breachResolver
// from the passed Reader instance, returning an active breachResolver.
func newBreachResolverFromReader(r io.Reader, resCfg ResolverConfig) (
*breachResolver, error) {
b := &breachResolver{
contractResolverKit: *newContractResolverKit(resCfg),
replyChan: make(chan struct{}),
}
var resolved bool
if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
if resolved {
b.markResolved()
}
b.initLogger(fmt.Sprintf("%T(%v)", b, b.ChanPoint))
return b, nil
}
// A compile time assertion to ensure breachResolver meets the ContractResolver
// interface.
var _ ContractResolver = (*breachResolver)(nil)
// Launch offers the breach outputs to the sweeper - currently it's a NOOP as
// the outputs here are not offered to the sweeper.
//
// NOTE: Part of the ContractResolver interface.
//
// TODO(yy): implement it once the outputs are offered to the sweeper.
func (b *breachResolver) Launch() error {
if b.isLaunched() {
b.log.Tracef("already launched")
return nil
}
b.log.Debugf("launching resolver...")
b.markLaunched()
return nil
}
package contractcourt
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/tlv"
)
// ContractResolutions is a wrapper struct around the two forms of resolutions
// we may need to carry out once a contract is closing: resolving the
// commitment output, and resolving any incoming+outgoing HTLC's still present
// in the commitment.
type ContractResolutions struct {
// CommitHash is the txid of the commitment transaction.
CommitHash chainhash.Hash
// CommitResolution contains all data required to fully resolve a
// commitment output.
CommitResolution *lnwallet.CommitOutputResolution
// HtlcResolutions contains all data required to fully resolve any
// incoming+outgoing HTLC's present within the commitment transaction.
HtlcResolutions lnwallet.HtlcResolutions
// AnchorResolution contains the data required to sweep the anchor
// output. If the channel type doesn't include anchors, the value of
// this field will be nil.
AnchorResolution *lnwallet.AnchorResolution
// BreachResolution contains the data required to manage the lifecycle
// of a breach in the ChannelArbitrator.
BreachResolution *BreachResolution
}
// IsEmpty returns true if the set of resolutions is "empty". A resolution is
// empty if: our commitment output has been trimmed, we don't have any
// incoming or outgoing HTLC's active, there is no anchor output to sweep,
// and there are no breached outputs to resolve.
func (c *ContractResolutions) IsEmpty() bool {
return c.CommitResolution == nil &&
len(c.HtlcResolutions.IncomingHTLCs) == 0 &&
len(c.HtlcResolutions.OutgoingHTLCs) == 0 &&
c.AnchorResolution == nil && c.BreachResolution == nil
}
// ArbitratorLog is the primary source of persistent storage for the
// ChannelArbitrator. The log stores the current state of the
// ChannelArbitrator's internal state machine, any items that are required to
// properly make a state transition, and any unresolved contracts.
type ArbitratorLog interface {
// TODO(roasbeef): document on interface the errors expected to be
// returned
// CurrentState returns the current state of the ChannelArbitrator. It
// takes an optional database transaction, which will be used if it is
// non-nil, otherwise the lookup will be done in its own transaction.
CurrentState(tx kvdb.RTx) (ArbitratorState, error)
// CommitState persists the current state of the chain attendant.
CommitState(ArbitratorState) error
// InsertUnresolvedContracts inserts a set of unresolved contracts into
// the log. The log will then persistently store each contract until
// they've been swapped out, or resolved. It also takes a set of reports
// which, if non-nil, will be written to disk as well.
InsertUnresolvedContracts(reports []*channeldb.ResolverReport,
resolvers ...ContractResolver) error
// FetchUnresolvedContracts returns all unresolved contracts that have
// been previously written to the log.
FetchUnresolvedContracts() ([]ContractResolver, error)
// SwapContract performs an atomic swap of the old contract for the new
// contract. This method is used when, after a contract has been fully
// resolved, it produces another contract that needs to be resolved.
SwapContract(old ContractResolver, new ContractResolver) error
// ResolveContract marks a contract as fully resolved. Once a contract
// has been fully resolved, it is deleted from persistent storage.
ResolveContract(ContractResolver) error
// LogContractResolutions stores a complete contract resolution for the
// contract under watch. This method will be called once the
// ChannelArbitrator either force closes a channel, or detects that the
// remote party has broadcast their commitment on chain.
LogContractResolutions(*ContractResolutions) error
// FetchContractResolutions fetches the set of previously stored
// contract resolutions from persistent storage.
FetchContractResolutions() (*ContractResolutions, error)
// InsertConfirmedCommitSet stores the known set of active HTLCs at the
// time channel closure. We'll use this to reconstruct our set of chain
// actions anew based on the confirmed and pending commitment state.
InsertConfirmedCommitSet(c *CommitSet) error
// FetchConfirmedCommitSet fetches the known confirmed active HTLC set
// from the database. It takes an optional database transaction, which
// will be used if it is non-nil, otherwise the lookup will be done in
// its own transaction.
FetchConfirmedCommitSet(tx kvdb.RTx) (*CommitSet, error)
// FetchChainActions attempts to fetch the set of previously stored
// chain actions. We'll use this upon restart to properly advance our
// state machine forward.
//
// NOTE: This method only exists in order to be able to serve nodes that had
// channels in the process of closing before the CommitSet struct was
// introduced.
FetchChainActions() (ChainActionMap, error)
// WipeHistory is to be called ONLY once *all* contracts have been
// fully resolved, and the channel closure is finalized. This method
// will delete all on-disk state within the persistent log.
WipeHistory() error
}
// ArbitratorState is an enum that details the current state of the
// ChannelArbitrator's state machine.
type ArbitratorState uint8
const (
// While some state transition is allowed, certain transitions are not
// possible. Listed below is the full state transition map which
// contains all possible paths. We start at StateDefault and end at
// StateFullyResolved, or StateError (not listed, as it's a possible state
// in every path). The format is:
// -> state: conditions we switch to this state.
//
// StateDefault
// |
// |-> StateDefault: no actions and chain trigger
// |
// |-> StateBroadcastCommit: chain/user trigger
// | |
// | |-> StateCommitmentBroadcasted: chain/user trigger
// | | |
// | | |-> StateCommitmentBroadcasted: chain/user trigger
// | | |
// | | |-> StateContractClosed: local/remote/breach close trigger
// | | | |
// | | | |-> StateWaitingFullResolution: contract resolutions not empty
// | | | | |
// | | | | |-> StateWaitingFullResolution: contract resolutions not empty
// | | | | |
// | | | | |-> StateFullyResolved: contract resolutions empty
// | | | |
// | | | |-> StateFullyResolved: contract resolutions empty
// | | |
// | | |-> StateFullyResolved: coop/breach(legacy) close trigger
// | |
// | |-> StateContractClosed: local/remote/breach close trigger
// | | |
// | | |-> StateWaitingFullResolution: contract resolutions not empty
// | | | |
// | | | |-> StateWaitingFullResolution: contract resolutions not empty
// | | | |
// | | | |-> StateFullyResolved: contract resolutions empty
// | | |
// | | |-> StateFullyResolved: contract resolutions empty
// | |
// | |-> StateFullyResolved: coop/breach(legacy) close trigger
// |
// |-> StateContractClosed: local/remote/breach close trigger
// | |
// | |-> StateWaitingFullResolution: contract resolutions not empty
// | | |
// | | |-> StateWaitingFullResolution: contract resolutions not empty
// | | |
// | | |-> StateFullyResolved: contract resolutions empty
// | |
// | |-> StateFullyResolved: contract resolutions empty
// |
// |-> StateFullyResolved: coop/breach(legacy) close trigger.
// StateDefault is the default state. In this state, no major actions
// need to be executed.
StateDefault ArbitratorState = 0
// StateBroadcastCommit is a state that indicates that the attendant
// has decided to broadcast the commitment transaction, but hasn't done
// so yet.
StateBroadcastCommit ArbitratorState = 1
// StateCommitmentBroadcasted is a state that indicates that the
// attendant has broadcasted the commitment transaction, and is now
// waiting for it to confirm.
StateCommitmentBroadcasted ArbitratorState = 6
// StateContractClosed is a state that indicates the contract has
// already been "closed", meaning the commitment is confirmed on chain.
// At this point, we can now examine our active contracts, in order to
// create the proper resolver for each one.
StateContractClosed ArbitratorState = 2
// StateWaitingFullResolution is a state that indicates that the
// commitment transaction has been confirmed, and the attendant is now
// waiting for all unresolved contracts to be fully resolved.
StateWaitingFullResolution ArbitratorState = 3
// StateFullyResolved is the final state of the attendant. In this
// state, all related contracts have been resolved, and the attendant
// can now be garbage collected.
StateFullyResolved ArbitratorState = 4
// StateError is the only error state of the resolver. If we enter this
// state, then we cannot proceed with manual intervention as a state
// transition failed.
StateError ArbitratorState = 5
)
// String returns a human readable string describing the ArbitratorState.
func (a ArbitratorState) String() string {
switch a {
case StateDefault:
return "StateDefault"
case StateBroadcastCommit:
return "StateBroadcastCommit"
case StateCommitmentBroadcasted:
return "StateCommitmentBroadcasted"
case StateContractClosed:
return "StateContractClosed"
case StateWaitingFullResolution:
return "StateWaitingFullResolution"
case StateFullyResolved:
return "StateFullyResolved"
case StateError:
return "StateError"
default:
return "unknown state"
}
}
// IsContractClosed returns a bool to indicate whether the closing/breaching tx
// has been confirmed onchain. If the state is StateContractClosed,
// StateWaitingFullResolution, or StateFullyResolved, it means the contract has
// been closed and all related contracts have been launched.
func (a ArbitratorState) IsContractClosed() bool {
return a == StateContractClosed || a == StateWaitingFullResolution ||
a == StateFullyResolved
}
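// exampleNeedsCloseConfirmation is a hypothetical helper (illustration
// only) showing how a caller might use the predicate above: any
// non-error state prior to contract closure still needs to watch the
// chain for the closing/breaching transaction to confirm.
func exampleNeedsCloseConfirmation(a ArbitratorState) bool {
	return !a.IsContractClosed() && a != StateError
}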
// resolverType is an enum that enumerates the various types of resolvers. When
// writing resolvers to disk, we prepend this to the raw bytes stored. This
// allows us to properly decode the resolver into the proper type.
type resolverType uint8
const (
// resolverTimeout is the type of a resolver that's tasked with
// resolving an outgoing HTLC that is very close to timing out.
resolverTimeout resolverType = 0
// resolverSuccess is the type of a resolver that's tasked with
// resolving an incoming HTLC that we already know the preimage of.
resolverSuccess resolverType = 1
// resolverOutgoingContest is the type of a resolver that's tasked with
// resolving an outgoing HTLC that hasn't yet timed out.
resolverOutgoingContest resolverType = 2
// resolverIncomingContest is the type of a resolver that's tasked with
// resolving an incoming HTLC that we don't yet know the preimage to.
resolverIncomingContest resolverType = 3
// resolverUnilateralSweep is the type of resolver that's tasked with
// sweeping our direct commitment output from the remote party's
// commitment transaction.
resolverUnilateralSweep resolverType = 4
// resolverBreach is the type of resolver that manages a contract
// breach on-chain.
resolverBreach resolverType = 5
)
// resolverIDLen is the size of the resolver ID key. This is 36 bytes as we get
// 32 bytes from the hash of the prev tx, and 4 bytes for the output index.
const resolverIDLen = 36
// resolverID is a key that uniquely identifies a resolver within a particular
// chain. For this value we use the full outpoint of the resolver.
type resolverID [resolverIDLen]byte
// newResolverID returns a resolverID given the outpoint of a contract.
func newResolverID(op wire.OutPoint) resolverID {
var r resolverID
copy(r[:], op.Hash[:])
endian.PutUint32(r[32:], op.Index)
return r
}
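// splitResolverID is an illustrative inverse of newResolverID (not part
// of the production path), recovering the outpoint from a resolver key.
// It assumes the package-level endian used above is big-endian.
func splitResolverID(r resolverID) wire.OutPoint {
	var op wire.OutPoint
	copy(op.Hash[:], r[:32])
	op.Index = endian.Uint32(r[32:])

	return op
}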
// logScope is a key that we use to scope the storage of a ChannelArbitrator
// within the global log. We use this key to create a unique bucket within the
// database and ensure that we don't have any key collisions. The log's scope
// is defined as: chainHash || chanPoint, where chanPoint is the chan point of
// the original channel.
type logScope [32 + 36]byte
// newLogScope creates a new logScope key from the passed chainhash and
// chanPoint.
func newLogScope(chain chainhash.Hash, op wire.OutPoint) (*logScope, error) {
var l logScope
b := bytes.NewBuffer(l[0:0])
if _, err := b.Write(chain[:]); err != nil {
return nil, err
}
if _, err := b.Write(op.Hash[:]); err != nil {
return nil, err
}
if err := binary.Write(b, endian, op.Index); err != nil {
return nil, err
}
return &l, nil
}
var (
// stateKey is the key that we use to store the current state of the
// arbitrator.
stateKey = []byte("state")
// contractsBucketKey is the bucket within the logScope that will store
// all the active unresolved contracts.
contractsBucketKey = []byte("contractkey")
// resolutionsKey is the key under the logScope that we'll use to store
// the full set of resolutions for a channel.
resolutionsKey = []byte("resolutions")
// resolutionsSignDetailsKey is the key under the logScope where we
// will store input.SignDetails for each HTLC resolution. If this is
// not found under the logScope, it means it was written before
// SignDetails was introduced, and should be set nil for each HTLC
// resolution.
resolutionsSignDetailsKey = []byte("resolutions-sign-details")
// anchorResolutionKey is the key under the logScope that we'll use to
// store the anchor resolution, if any.
anchorResolutionKey = []byte("anchor-resolution")
// breachResolutionKey is the key under the logScope that we'll use to
// store the breach resolution, if any. This is used rather than the
// resolutionsKey.
breachResolutionKey = []byte("breach-resolution")
// actionsBucketKey is the key under the logScope that we'll use to
// store all chain actions once they're determined.
actionsBucketKey = []byte("chain-actions")
// commitSetKey is the primary key under the logScope that we'll use to
// store the confirmed active HTLC sets once we learn that a channel
// has closed out on chain.
commitSetKey = []byte("commit-set")
// taprootDataKey is the key we'll use to store taproot specific data
// for the set of channels we'll need to sweep/claim.
taprootDataKey = []byte("taproot-data")
)
var (
// errScopeBucketNoExist is returned when we can't find the proper
// bucket for an arbitrator's scope.
errScopeBucketNoExist = fmt.Errorf("scope bucket not found")
// errNoContracts is returned when no contracts are found within the
// log.
errNoContracts = fmt.Errorf("no stored contracts")
// errNoResolutions is returned when the log doesn't contain any active
// chain resolutions.
errNoResolutions = fmt.Errorf("no contract resolutions exist")
// errNoActions is returned when the log doesn't contain any stored
// chain actions.
errNoActions = fmt.Errorf("no chain actions exist")
// errNoCommitSet is returned when the log doesn't contain a CommitSet.
// This can happen if the channel hasn't closed yet, or a client is
// running an older version that didn't yet write this state.
errNoCommitSet = fmt.Errorf("no commit set exists")
)
// boltArbitratorLog is an implementation of the ArbitratorLog interface backed
// by a bolt DB instance.
type boltArbitratorLog struct {
db kvdb.Backend
cfg ChannelArbitratorConfig
scopeKey logScope
}
// newBoltArbitratorLog returns a new instance of the boltArbitratorLog given
// an arbitrator config, and the items needed to create its log scope.
func newBoltArbitratorLog(db kvdb.Backend, cfg ChannelArbitratorConfig,
chainHash chainhash.Hash, chanPoint wire.OutPoint) (*boltArbitratorLog, error) {
scope, err := newLogScope(chainHash, chanPoint)
if err != nil {
return nil, err
}
return &boltArbitratorLog{
db: db,
cfg: cfg,
scopeKey: *scope,
}, nil
}
// A compile time check to ensure boltArbitratorLog meets the ArbitratorLog
// interface.
var _ ArbitratorLog = (*boltArbitratorLog)(nil)
func fetchContractReadBucket(tx kvdb.RTx, scopeKey []byte) (kvdb.RBucket, error) {
scopeBucket := tx.ReadBucket(scopeKey)
if scopeBucket == nil {
return nil, errScopeBucketNoExist
}
contractBucket := scopeBucket.NestedReadBucket(contractsBucketKey)
if contractBucket == nil {
return nil, errNoContracts
}
return contractBucket, nil
}
func fetchContractWriteBucket(tx kvdb.RwTx, scopeKey []byte) (kvdb.RwBucket, error) {
scopeBucket, err := tx.CreateTopLevelBucket(scopeKey)
if err != nil {
return nil, err
}
contractBucket, err := scopeBucket.CreateBucketIfNotExists(
contractsBucketKey,
)
if err != nil {
return nil, err
}
return contractBucket, nil
}
// writeResolver is a helper method that writes a contract resolver and stores
// it within the passed contractBucket under its unique resolver key.
func (b *boltArbitratorLog) writeResolver(contractBucket kvdb.RwBucket,
res ContractResolver) error {
// Only persist resolvers that are stateful. Stateless resolvers don't
// expose a resolver key.
resKey := res.ResolverKey()
if resKey == nil {
return nil
}
// First, we'll write to the buffer the type of this resolver. Using
// this byte, we can later properly deserialize the resolver.
var (
buf bytes.Buffer
rType resolverType
)
switch res.(type) {
case *htlcTimeoutResolver:
rType = resolverTimeout
case *htlcSuccessResolver:
rType = resolverSuccess
case *htlcOutgoingContestResolver:
rType = resolverOutgoingContest
case *htlcIncomingContestResolver:
rType = resolverIncomingContest
case *commitSweepResolver:
rType = resolverUnilateralSweep
case *breachResolver:
rType = resolverBreach
}
if _, err := buf.Write([]byte{byte(rType)}); err != nil {
return err
}
// With the type of the resolver written, we can then write out the raw
// bytes of the resolver itself.
if err := res.Encode(&buf); err != nil {
return err
}
return contractBucket.Put(resKey, buf.Bytes())
}
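// peekResolverType is an illustrative sketch of the on-disk framing
// produced above (hypothetical helper): a single resolver-type byte
// followed by the resolver's own encoding. FetchUnresolvedContracts
// below snips the same byte back off before dispatching on it.
func peekResolverType(resBytes []byte) (resolverType, []byte, error) {
	if len(resBytes) == 0 {
		return 0, nil, fmt.Errorf("empty resolver bytes")
	}

	return resolverType(resBytes[0]), resBytes[1:], nil
}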
// CurrentState returns the current state of the ChannelArbitrator. It takes an
// optional database transaction, which will be used if it is non-nil, otherwise
// the lookup will be done in its own transaction.
//
// NOTE: Part of the ArbitratorLog interface.
func (b *boltArbitratorLog) CurrentState(tx kvdb.RTx) (ArbitratorState, error) {
var (
s ArbitratorState
err error
)
if tx != nil {
s, err = b.currentState(tx)
} else {
err = kvdb.View(b.db, func(tx kvdb.RTx) error {
s, err = b.currentState(tx)
return err
}, func() {
s = 0
})
}
if err != nil && err != errScopeBucketNoExist {
return s, err
}
return s, nil
}
func (b *boltArbitratorLog) currentState(tx kvdb.RTx) (ArbitratorState, error) {
scopeBucket := tx.ReadBucket(b.scopeKey[:])
if scopeBucket == nil {
return 0, errScopeBucketNoExist
}
stateBytes := scopeBucket.Get(stateKey)
if stateBytes == nil {
return 0, nil
}
return ArbitratorState(stateBytes[0]), nil
}
// CommitState persists the current state of the chain attendant.
//
// NOTE: Part of the ArbitratorLog interface.
func (b *boltArbitratorLog) CommitState(s ArbitratorState) error {
return kvdb.Batch(b.db, func(tx kvdb.RwTx) error {
scopeBucket, err := tx.CreateTopLevelBucket(b.scopeKey[:])
if err != nil {
return err
}
return scopeBucket.Put(stateKey[:], []byte{uint8(s)})
})
}
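// exampleStateRoundTrip is a hypothetical sketch (illustration only)
// pairing the two methods above: commit a state, then read it back in a
// fresh transaction by passing a nil tx.
func exampleStateRoundTrip(log ArbitratorLog,
	s ArbitratorState) (ArbitratorState, error) {

	if err := log.CommitState(s); err != nil {
		return 0, err
	}

	return log.CurrentState(nil)
}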
// FetchUnresolvedContracts returns all unresolved contracts that have been
// previously written to the log.
//
// NOTE: Part of the ArbitratorLog interface.
func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, error) {
resolverCfg := ResolverConfig{
ChannelArbitratorConfig: b.cfg,
Checkpoint: b.checkpointContract,
}
var contracts []ContractResolver
err := kvdb.View(b.db, func(tx kvdb.RTx) error {
contractBucket, err := fetchContractReadBucket(tx, b.scopeKey[:])
if err != nil {
return err
}
return contractBucket.ForEach(func(resKey, resBytes []byte) error {
if len(resKey) != resolverIDLen {
return nil
}
var res ContractResolver
// We'll snip off the first byte of the raw resolver
// bytes in order to extract what type of resolver
// we're about to decode.
resType := resolverType(resBytes[0])
// Then we'll create a reader using the remaining
// bytes.
resReader := bytes.NewReader(resBytes[1:])
switch resType {
case resolverTimeout:
res, err = newTimeoutResolverFromReader(
resReader, resolverCfg,
)
case resolverSuccess:
res, err = newSuccessResolverFromReader(
resReader, resolverCfg,
)
case resolverOutgoingContest:
res, err = newOutgoingContestResolverFromReader(
resReader, resolverCfg,
)
case resolverIncomingContest:
res, err = newIncomingContestResolverFromReader(
resReader, resolverCfg,
)
case resolverUnilateralSweep:
res, err = newCommitSweepResolverFromReader(
resReader, resolverCfg,
)
case resolverBreach:
res, err = newBreachResolverFromReader(
resReader, resolverCfg,
)
default:
return fmt.Errorf("unknown resolver type: %v", resType)
}
if err != nil {
return err
}
contracts = append(contracts, res)
return nil
})
}, func() {
contracts = nil
})
if err != nil && err != errScopeBucketNoExist && err != errNoContracts {
return nil, err
}
return contracts, nil
}
// InsertUnresolvedContracts inserts a set of unresolved contracts into the
// log. The log will then persistently store each contract until they've been
// swapped out, or resolved.
//
// NOTE: Part of the ArbitratorLog interface.
func (b *boltArbitratorLog) InsertUnresolvedContracts(reports []*channeldb.ResolverReport,
resolvers ...ContractResolver) error {
return kvdb.Batch(b.db, func(tx kvdb.RwTx) error {
contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:])
if err != nil {
return err
}
for _, resolver := range resolvers {
err = b.writeResolver(contractBucket, resolver)
if err != nil {
return err
}
}
// Persist any reports that are present.
for _, report := range reports {
err := b.cfg.PutResolverReport(tx, report)
if err != nil {
return err
}
}
return nil
})
}
// SwapContract performs an atomic swap of the old contract for the new
// contract. This method is used when, after a contract has been fully resolved,
// it produces another contract that needs to be resolved.
//
// NOTE: Part of the ArbitratorLog interface.
func (b *boltArbitratorLog) SwapContract(oldContract, newContract ContractResolver) error {
return kvdb.Batch(b.db, func(tx kvdb.RwTx) error {
contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:])
if err != nil {
return err
}
oldContractkey := oldContract.ResolverKey()
if err := contractBucket.Delete(oldContractkey); err != nil {
return err
}
return b.writeResolver(contractBucket, newContract)
})
}
// ResolveContract marks a contract as fully resolved. Once a contract has been
// fully resolved, it is deleted from persistent storage.
//
// NOTE: Part of the ArbitratorLog interface.
func (b *boltArbitratorLog) ResolveContract(res ContractResolver) error {
return kvdb.Batch(b.db, func(tx kvdb.RwTx) error {
contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:])
if err != nil {
return err
}
resKey := res.ResolverKey()
return contractBucket.Delete(resKey)
})
}
// LogContractResolutions stores a set of chain actions which are derived from
// our set of active contracts, and the on-chain state. We'll write this set
// of resolutions when: we decide to go on-chain to resolve a contract, or we detect
// that the remote party has gone on-chain.
//
// NOTE: Part of the ArbitratorLog interface.
func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error {
return kvdb.Batch(b.db, func(tx kvdb.RwTx) error {
scopeBucket, err := tx.CreateTopLevelBucket(b.scopeKey[:])
if err != nil {
return err
}
var b bytes.Buffer
if _, err := b.Write(c.CommitHash[:]); err != nil {
return err
}
// First, we'll write out the commit output's resolution.
if c.CommitResolution == nil {
if err := binary.Write(&b, endian, false); err != nil {
return err
}
} else {
if err := binary.Write(&b, endian, true); err != nil {
return err
}
err = encodeCommitResolution(&b, c.CommitResolution)
if err != nil {
return err
}
}
// As we write the HTLC resolutions, we'll serialize the sign
// details for each, to store under a new key.
var signDetailsBuf bytes.Buffer
// With the output for the commitment transaction written, we
// can now write out the resolutions for the incoming and
// outgoing HTLC's.
numIncoming := uint32(len(c.HtlcResolutions.IncomingHTLCs))
if err := binary.Write(&b, endian, numIncoming); err != nil {
return err
}
for _, htlc := range c.HtlcResolutions.IncomingHTLCs {
err := encodeIncomingResolution(&b, &htlc)
if err != nil {
return err
}
err = encodeSignDetails(&signDetailsBuf, htlc.SignDetails)
if err != nil {
return err
}
}
numOutgoing := uint32(len(c.HtlcResolutions.OutgoingHTLCs))
if err := binary.Write(&b, endian, numOutgoing); err != nil {
return err
}
for _, htlc := range c.HtlcResolutions.OutgoingHTLCs {
err := encodeOutgoingResolution(&b, &htlc)
if err != nil {
return err
}
err = encodeSignDetails(&signDetailsBuf, htlc.SignDetails)
if err != nil {
return err
}
}
// Put the resolutions under the resolutionsKey.
err = scopeBucket.Put(resolutionsKey, b.Bytes())
if err != nil {
return err
}
// We'll put the serialized sign details under its own key to
// stay backwards compatible.
err = scopeBucket.Put(
resolutionsSignDetailsKey, signDetailsBuf.Bytes(),
)
if err != nil {
return err
}
// Write out the anchor resolution if present.
if c.AnchorResolution != nil {
var b bytes.Buffer
err := encodeAnchorResolution(&b, c.AnchorResolution)
if err != nil {
return err
}
err = scopeBucket.Put(anchorResolutionKey, b.Bytes())
if err != nil {
return err
}
}
// Write out the breach resolution if present.
if c.BreachResolution != nil {
var b bytes.Buffer
err := encodeBreachResolution(&b, c.BreachResolution)
if err != nil {
return err
}
err = scopeBucket.Put(breachResolutionKey, b.Bytes())
if err != nil {
return err
}
}
// If this isn't a taproot channel, then we can exit early here
// as there's no extra data to write.
switch {
case c.AnchorResolution == nil:
return nil
case !txscript.IsPayToTaproot(
c.AnchorResolution.AnchorSignDescriptor.Output.PkScript,
):
return nil
}
// With everything else encoded, we'll now populate the taproot
// specific items we need to store for the musig2 channels.
var tb bytes.Buffer
err = encodeTaprootAuxData(&tb, c)
if err != nil {
return err
}
return scopeBucket.Put(taprootDataKey, tb.Bytes())
})
}
// FetchContractResolutions fetches the set of previously stored contract
// resolutions from persistent storage.
//
// NOTE: Part of the ArbitratorLog interface.
func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) {
var c *ContractResolutions
err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:])
if scopeBucket == nil {
return errScopeBucketNoExist
}
resolutionBytes := scopeBucket.Get(resolutionsKey)
if resolutionBytes == nil {
return errNoResolutions
}
resReader := bytes.NewReader(resolutionBytes)
_, err := io.ReadFull(resReader, c.CommitHash[:])
if err != nil {
return err
}
// First, we'll attempt to read out the commit resolution (if
// it exists).
var haveCommitRes bool
err = binary.Read(resReader, endian, &haveCommitRes)
if err != nil {
return err
}
if haveCommitRes {
c.CommitResolution = &lnwallet.CommitOutputResolution{}
err = decodeCommitResolution(
resReader, c.CommitResolution,
)
if err != nil {
return fmt.Errorf("unable to decode "+
"commit res: %w", err)
}
}
var (
numIncoming uint32
numOutgoing uint32
)
// Next, we'll read out the incoming and outgoing HTLC
// resolutions.
err = binary.Read(resReader, endian, &numIncoming)
if err != nil {
return err
}
c.HtlcResolutions.IncomingHTLCs = make([]lnwallet.IncomingHtlcResolution, numIncoming)
for i := uint32(0); i < numIncoming; i++ {
err := decodeIncomingResolution(
resReader, &c.HtlcResolutions.IncomingHTLCs[i],
)
if err != nil {
return fmt.Errorf("unable to decode "+
"incoming res: %w", err)
}
}
err = binary.Read(resReader, endian, &numOutgoing)
if err != nil {
return err
}
c.HtlcResolutions.OutgoingHTLCs = make([]lnwallet.OutgoingHtlcResolution, numOutgoing)
for i := uint32(0); i < numOutgoing; i++ {
err := decodeOutgoingResolution(
resReader, &c.HtlcResolutions.OutgoingHTLCs[i],
)
if err != nil {
return fmt.Errorf("unable to decode "+
"outgoing res: %w", err)
}
}
// Now we attempt to get the sign details for our HTLC
// resolutions. If not present the channel is of a type that
// doesn't need them. If present there will be SignDetails
// encoded for each HTLC resolution.
signDetailsBytes := scopeBucket.Get(resolutionsSignDetailsKey)
if signDetailsBytes != nil {
r := bytes.NewReader(signDetailsBytes)
// They will be encoded in the same order as the
// resolutions: first incoming HTLCs, then outgoing.
for i := uint32(0); i < numIncoming; i++ {
htlc := &c.HtlcResolutions.IncomingHTLCs[i]
htlc.SignDetails, err = decodeSignDetails(r)
if err != nil {
return fmt.Errorf("unable to decode "+
"incoming sign desc: %w", err)
}
}
for i := uint32(0); i < numOutgoing; i++ {
htlc := &c.HtlcResolutions.OutgoingHTLCs[i]
htlc.SignDetails, err = decodeSignDetails(r)
if err != nil {
return fmt.Errorf("unable to decode "+
"outgoing sign desc: %w", err)
}
}
}
anchorResBytes := scopeBucket.Get(anchorResolutionKey)
if anchorResBytes != nil {
c.AnchorResolution = &lnwallet.AnchorResolution{}
resReader := bytes.NewReader(anchorResBytes)
err := decodeAnchorResolution(
resReader, c.AnchorResolution,
)
if err != nil {
return fmt.Errorf("unable to read anchor "+
"data: %w", err)
}
}
breachResBytes := scopeBucket.Get(breachResolutionKey)
if breachResBytes != nil {
c.BreachResolution = &BreachResolution{}
resReader := bytes.NewReader(breachResBytes)
err := decodeBreachResolution(
resReader, c.BreachResolution,
)
if err != nil {
return fmt.Errorf("unable to read breach "+
"data: %w", err)
}
}
tapCaseBytes := scopeBucket.Get(taprootDataKey)
if tapCaseBytes != nil {
err = decodeTapRootAuxData(
bytes.NewReader(tapCaseBytes), c,
)
if err != nil {
return fmt.Errorf("unable to read taproot "+
"data: %w", err)
}
}
return nil
}, func() {
c = &ContractResolutions{}
})
if err != nil {
return nil, err
}
return c, nil
}
// FetchChainActions attempts to fetch the set of previously stored chain
// actions. We'll use this upon restart to properly advance our state machine
// forward.
//
// NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) {
var actionsMap ChainActionMap
err := kvdb.View(b.db, func(tx kvdb.RTx) error {
scopeBucket := tx.ReadBucket(b.scopeKey[:])
if scopeBucket == nil {
return errScopeBucketNoExist
}
actionsBucket := scopeBucket.NestedReadBucket(actionsBucketKey)
if actionsBucket == nil {
return errNoActions
}
return actionsBucket.ForEach(func(action, htlcBytes []byte) error {
if htlcBytes == nil {
return nil
}
chainAction := ChainAction(action[0])
htlcReader := bytes.NewReader(htlcBytes)
htlcs, err := channeldb.DeserializeHtlcs(htlcReader)
if err != nil {
return err
}
actionsMap[chainAction] = htlcs
return nil
})
}, func() {
actionsMap = make(ChainActionMap)
})
if err != nil {
return nil, err
}
return actionsMap, nil
}
// InsertConfirmedCommitSet stores the known set of active HTLCs at the time
// channel closure. We'll use this to reconstruct our set of chain actions anew
// based on the confirmed and pending commitment state.
//
// NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) InsertConfirmedCommitSet(c *CommitSet) error {
return kvdb.Batch(b.db, func(tx kvdb.RwTx) error {
scopeBucket, err := tx.CreateTopLevelBucket(b.scopeKey[:])
if err != nil {
return err
}
var b bytes.Buffer
if err := encodeCommitSet(&b, c); err != nil {
return err
}
return scopeBucket.Put(commitSetKey, b.Bytes())
})
}
// FetchConfirmedCommitSet fetches the known confirmed active HTLC set from the
// database. It takes an optional database transaction, which will be used if it
// is non-nil, otherwise the lookup will be done in its own transaction.
//
// NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchConfirmedCommitSet(tx kvdb.RTx) (*CommitSet, error) {
if tx != nil {
return b.fetchConfirmedCommitSet(tx)
}
var c *CommitSet
err := kvdb.View(b.db, func(tx kvdb.RTx) error {
var err error
c, err = b.fetchConfirmedCommitSet(tx)
return err
}, func() {
c = nil
})
if err != nil {
return nil, err
}
return c, nil
}
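// exampleFetchCommitSet is an illustrative sketch (not part of the original
// source) of the optional-transaction pattern used by FetchConfirmedCommitSet
// above: passing a nil kvdb.RTx makes the lookup run in its own read
// transaction, while passing an open one reuses it so callers can batch
// several reads atomically.
func exampleFetchCommitSet(arbLog *boltArbitratorLog) (*CommitSet, error) {
	// Standalone lookup: a fresh read transaction is created internally.
	return arbLog.FetchConfirmedCommitSet(nil)
}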
func (b *boltArbitratorLog) fetchConfirmedCommitSet(tx kvdb.RTx) (*CommitSet,
error) {
scopeBucket := tx.ReadBucket(b.scopeKey[:])
if scopeBucket == nil {
return nil, errScopeBucketNoExist
}
commitSetBytes := scopeBucket.Get(commitSetKey)
if commitSetBytes == nil {
return nil, errNoCommitSet
}
return decodeCommitSet(bytes.NewReader(commitSetBytes))
}
// WipeHistory is to be called ONLY once *all* contracts have been fully
// resolved, and the channel closure is finalized. This method will delete all
// on-disk state within the persistent log.
//
// NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) WipeHistory() error {
return kvdb.Update(b.db, func(tx kvdb.RwTx) error {
scopeBucket, err := tx.CreateTopLevelBucket(b.scopeKey[:])
if err != nil {
return err
}
// Once we have the main top-level bucket, we'll delete the key
// that stores the state of the arbitrator.
if err := scopeBucket.Delete(stateKey[:]); err != nil {
return err
}
// Next, we'll delete any lingering contract state within the
// contracts bucket by removing the bucket itself.
err = scopeBucket.DeleteNestedBucket(contractsBucketKey)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
// Next, we'll delete storage of any lingering contract
// resolutions.
if err := scopeBucket.Delete(resolutionsKey); err != nil {
return err
}
err = scopeBucket.Delete(resolutionsSignDetailsKey)
if err != nil {
return err
}
// We'll delete any chain actions that are still stored by
// removing the enclosing bucket.
err = scopeBucket.DeleteNestedBucket(actionsBucketKey)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
// Finally, we'll delete the enclosing bucket itself.
return tx.DeleteTopLevelBucket(b.scopeKey[:])
}, func() {})
}
// checkpointContract is a private method that will be fed into
// ContractResolver instances to checkpoint their state once they reach
// milestones during contract resolution. Any resolver reports provided will
// also be recorded.
func (b *boltArbitratorLog) checkpointContract(c ContractResolver,
reports ...*channeldb.ResolverReport) error {
return kvdb.Update(b.db, func(tx kvdb.RwTx) error {
contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:])
if err != nil {
return err
}
if err := b.writeResolver(contractBucket, c); err != nil {
return err
}
for _, report := range reports {
if err := b.cfg.PutResolverReport(tx, report); err != nil {
return err
}
}
return nil
}, func() {})
}
// encodeSignDetails encodes the given SignDetails struct to the writer.
// SignDetails is allowed to be nil, in which case we will encode that it is not
// present.
func encodeSignDetails(w io.Writer, s *input.SignDetails) error {
// If we don't have sign details, write false and return.
if s == nil {
return binary.Write(w, endian, false)
}
// Otherwise write true, and the contents of the SignDetails.
if err := binary.Write(w, endian, true); err != nil {
return err
}
err := input.WriteSignDescriptor(w, &s.SignDesc)
if err != nil {
return err
}
err = binary.Write(w, endian, uint32(s.SigHashType))
if err != nil {
return err
}
// Write the DER-encoded signature.
b := s.PeerSig.Serialize()
if err := wire.WriteVarBytes(w, 0, b); err != nil {
return err
}
return nil
}
// decodeSignDetails extracts a single SignDetails from the reader. It is
// allowed to return nil in case the SignDetails were empty.
func decodeSignDetails(r io.Reader) (*input.SignDetails, error) {
var present bool
if err := binary.Read(r, endian, &present); err != nil {
return nil, err
}
// Simply return nil if the next SignDetails was not present.
if !present {
return nil, nil
}
// Otherwise decode the elements of the SignDetails.
s := input.SignDetails{}
err := input.ReadSignDescriptor(r, &s.SignDesc)
if err != nil {
return nil, err
}
var sigHash uint32
err = binary.Read(r, endian, &sigHash)
if err != nil {
return nil, err
}
s.SigHashType = txscript.SigHashType(sigHash)
// Read DER-encoded signature.
rawSig, err := wire.ReadVarBytes(r, 0, 200, "signature")
if err != nil {
return nil, err
}
s.PeerSig, err = input.ParseSignature(rawSig)
if err != nil {
return nil, err
}
return &s, nil
}
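// exampleSignDetailsRoundTrip is an illustrative sketch (not part of the
// original source) of the bool-prefixed optional encoding above: a nil
// SignDetails serializes to a single false boolean, and decoding that byte
// yields nil again without error.
func exampleSignDetailsRoundTrip() error {
	var buf bytes.Buffer
	if err := encodeSignDetails(&buf, nil); err != nil {
		return err
	}
	// buf now holds a single 0x00 byte.
	details, err := decodeSignDetails(&buf)
	if err != nil {
		return err
	}
	if details != nil {
		return fmt.Errorf("expected nil SignDetails")
	}
	return nil
}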
func encodeIncomingResolution(w io.Writer, i *lnwallet.IncomingHtlcResolution) error {
if _, err := w.Write(i.Preimage[:]); err != nil {
return err
}
if i.SignedSuccessTx == nil {
if err := binary.Write(w, endian, false); err != nil {
return err
}
} else {
if err := binary.Write(w, endian, true); err != nil {
return err
}
if err := i.SignedSuccessTx.Serialize(w); err != nil {
return err
}
}
if err := binary.Write(w, endian, i.CsvDelay); err != nil {
return err
}
if _, err := w.Write(i.ClaimOutpoint.Hash[:]); err != nil {
return err
}
if err := binary.Write(w, endian, i.ClaimOutpoint.Index); err != nil {
return err
}
err := input.WriteSignDescriptor(w, &i.SweepSignDesc)
if err != nil {
return err
}
return nil
}
func decodeIncomingResolution(r io.Reader, h *lnwallet.IncomingHtlcResolution) error {
if _, err := io.ReadFull(r, h.Preimage[:]); err != nil {
return err
}
var txPresent bool
if err := binary.Read(r, endian, &txPresent); err != nil {
return err
}
if txPresent {
h.SignedSuccessTx = &wire.MsgTx{}
if err := h.SignedSuccessTx.Deserialize(r); err != nil {
return err
}
}
err := binary.Read(r, endian, &h.CsvDelay)
if err != nil {
return err
}
_, err = io.ReadFull(r, h.ClaimOutpoint.Hash[:])
if err != nil {
return err
}
err = binary.Read(r, endian, &h.ClaimOutpoint.Index)
if err != nil {
return err
}
return input.ReadSignDescriptor(r, &h.SweepSignDesc)
}
func encodeOutgoingResolution(w io.Writer, o *lnwallet.OutgoingHtlcResolution) error {
if err := binary.Write(w, endian, o.Expiry); err != nil {
return err
}
if o.SignedTimeoutTx == nil {
if err := binary.Write(w, endian, false); err != nil {
return err
}
} else {
if err := binary.Write(w, endian, true); err != nil {
return err
}
if err := o.SignedTimeoutTx.Serialize(w); err != nil {
return err
}
}
if err := binary.Write(w, endian, o.CsvDelay); err != nil {
return err
}
if _, err := w.Write(o.ClaimOutpoint.Hash[:]); err != nil {
return err
}
if err := binary.Write(w, endian, o.ClaimOutpoint.Index); err != nil {
return err
}
return input.WriteSignDescriptor(w, &o.SweepSignDesc)
}
func decodeOutgoingResolution(r io.Reader, o *lnwallet.OutgoingHtlcResolution) error {
err := binary.Read(r, endian, &o.Expiry)
if err != nil {
return err
}
var txPresent bool
if err := binary.Read(r, endian, &txPresent); err != nil {
return err
}
if txPresent {
o.SignedTimeoutTx = &wire.MsgTx{}
if err := o.SignedTimeoutTx.Deserialize(r); err != nil {
return err
}
}
err = binary.Read(r, endian, &o.CsvDelay)
if err != nil {
return err
}
_, err = io.ReadFull(r, o.ClaimOutpoint.Hash[:])
if err != nil {
return err
}
err = binary.Read(r, endian, &o.ClaimOutpoint.Index)
if err != nil {
return err
}
return input.ReadSignDescriptor(r, &o.SweepSignDesc)
}
func encodeCommitResolution(w io.Writer,
c *lnwallet.CommitOutputResolution) error {
if _, err := w.Write(c.SelfOutPoint.Hash[:]); err != nil {
return err
}
err := binary.Write(w, endian, c.SelfOutPoint.Index)
if err != nil {
return err
}
err = input.WriteSignDescriptor(w, &c.SelfOutputSignDesc)
if err != nil {
return err
}
return binary.Write(w, endian, c.MaturityDelay)
}
func decodeCommitResolution(r io.Reader,
c *lnwallet.CommitOutputResolution) error {
_, err := io.ReadFull(r, c.SelfOutPoint.Hash[:])
if err != nil {
return err
}
err = binary.Read(r, endian, &c.SelfOutPoint.Index)
if err != nil {
return err
}
err = input.ReadSignDescriptor(r, &c.SelfOutputSignDesc)
if err != nil {
return err
}
return binary.Read(r, endian, &c.MaturityDelay)
}
func encodeAnchorResolution(w io.Writer,
a *lnwallet.AnchorResolution) error {
if _, err := w.Write(a.CommitAnchor.Hash[:]); err != nil {
return err
}
err := binary.Write(w, endian, a.CommitAnchor.Index)
if err != nil {
return err
}
return input.WriteSignDescriptor(w, &a.AnchorSignDescriptor)
}
func decodeAnchorResolution(r io.Reader,
a *lnwallet.AnchorResolution) error {
_, err := io.ReadFull(r, a.CommitAnchor.Hash[:])
if err != nil {
return err
}
err = binary.Read(r, endian, &a.CommitAnchor.Index)
if err != nil {
return err
}
return input.ReadSignDescriptor(r, &a.AnchorSignDescriptor)
}
func encodeBreachResolution(w io.Writer, b *BreachResolution) error {
if _, err := w.Write(b.FundingOutPoint.Hash[:]); err != nil {
return err
}
return binary.Write(w, endian, b.FundingOutPoint.Index)
}
func decodeBreachResolution(r io.Reader, b *BreachResolution) error {
_, err := io.ReadFull(r, b.FundingOutPoint.Hash[:])
if err != nil {
return err
}
return binary.Read(r, endian, &b.FundingOutPoint.Index)
}
func encodeHtlcSetKey(w io.Writer, htlcSetKey HtlcSetKey) error {
err := binary.Write(w, endian, htlcSetKey.IsRemote)
if err != nil {
return err
}
return binary.Write(w, endian, htlcSetKey.IsPending)
}
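// As an illustrative note (not part of the original source): with the
// encoding above, an HtlcSetKey serializes to exactly two bytes, one per
// boolean, e.g. {IsRemote: true, IsPending: false} encodes as 0x01 0x00.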
func encodeCommitSet(w io.Writer, c *CommitSet) error {
confCommitKey, err := c.ConfCommitKey.UnwrapOrErr(
fmt.Errorf("HtlcSetKey is not set"),
)
if err != nil {
return err
}
if err := encodeHtlcSetKey(w, confCommitKey); err != nil {
return err
}
numSets := uint8(len(c.HtlcSets))
if err := binary.Write(w, endian, numSets); err != nil {
return err
}
for htlcSetKey, htlcs := range c.HtlcSets {
if err := encodeHtlcSetKey(w, htlcSetKey); err != nil {
return err
}
if err := channeldb.SerializeHtlcs(w, htlcs...); err != nil {
return err
}
}
return nil
}
func decodeHtlcSetKey(r io.Reader, h *HtlcSetKey) error {
err := binary.Read(r, endian, &h.IsRemote)
if err != nil {
return err
}
return binary.Read(r, endian, &h.IsPending)
}
func decodeCommitSet(r io.Reader) (*CommitSet, error) {
confCommitKey := HtlcSetKey{}
if err := decodeHtlcSetKey(r, &confCommitKey); err != nil {
return nil, err
}
c := &CommitSet{
ConfCommitKey: fn.Some(confCommitKey),
HtlcSets: make(map[HtlcSetKey][]channeldb.HTLC),
}
var numSets uint8
if err := binary.Read(r, endian, &numSets); err != nil {
return nil, err
}
for i := uint8(0); i < numSets; i++ {
var htlcSetKey HtlcSetKey
if err := decodeHtlcSetKey(r, &htlcSetKey); err != nil {
return nil, err
}
htlcs, err := channeldb.DeserializeHtlcs(r)
if err != nil {
return nil, err
}
c.HtlcSets[htlcSetKey] = htlcs
}
return c, nil
}
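// exampleCommitSetRoundTrip is an illustrative sketch (not part of the
// original source) showing a CommitSet surviving a serialization round trip
// with encodeCommitSet/decodeCommitSet: the confirmed commitment key and
// each (possibly empty) HTLC set are recovered intact.
func exampleCommitSetRoundTrip() (*CommitSet, error) {
	c := &CommitSet{
		ConfCommitKey: fn.Some(LocalHtlcSet),
		HtlcSets: map[HtlcSetKey][]channeldb.HTLC{
			// An empty HTLC set for the local commitment.
			LocalHtlcSet: nil,
		},
	}
	var buf bytes.Buffer
	if err := encodeCommitSet(&buf, c); err != nil {
		return nil, err
	}
	return decodeCommitSet(&buf)
}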
func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error {
tapCase := newTaprootBriefcase()
if c.CommitResolution != nil {
commitResolution := c.CommitResolution
commitSignDesc := commitResolution.SelfOutputSignDesc
//nolint:ll
tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock = commitSignDesc.ControlBlock
c.CommitResolution.ResolutionBlob.WhenSome(func(b []byte) {
tapCase.SettledCommitBlob = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType2](b),
)
})
}
htlcBlobs := newAuxHtlcBlobs()
for _, htlc := range c.HtlcResolutions.IncomingHTLCs {
htlc := htlc
htlcSignDesc := htlc.SweepSignDesc
ctrlBlock := htlcSignDesc.ControlBlock
if ctrlBlock == nil {
continue
}
var resID resolverID
if htlc.SignedSuccessTx != nil {
resID = newResolverID(
htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint,
)
//nolint:ll
tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID] = ctrlBlock
// For HTLCs that we need to go to the second level for,
// we also need to store the control block needed to
// publish the second-level transaction.
if htlc.SignDetails != nil {
//nolint:ll
bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock
//nolint:ll
tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] = bridgeCtrlBlock
}
} else {
resID = newResolverID(htlc.ClaimOutpoint)
//nolint:ll
tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] = ctrlBlock
}
htlc.ResolutionBlob.WhenSome(func(b []byte) {
htlcBlobs[resID] = b
})
}
for _, htlc := range c.HtlcResolutions.OutgoingHTLCs {
htlc := htlc
htlcSignDesc := htlc.SweepSignDesc
ctrlBlock := htlcSignDesc.ControlBlock
if ctrlBlock == nil {
continue
}
var resID resolverID
if htlc.SignedTimeoutTx != nil {
resID = newResolverID(
htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint,
)
//nolint:ll
tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID] = ctrlBlock
// For HTLCs that we need to go to the second level for,
// we also need to store the control block needed to
// publish the second-level transaction.
//
//nolint:ll
if htlc.SignDetails != nil {
//nolint:ll
bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock
//nolint:ll
tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] = bridgeCtrlBlock
}
} else {
resID = newResolverID(htlc.ClaimOutpoint)
//nolint:ll
tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] = ctrlBlock
}
htlc.ResolutionBlob.WhenSome(func(b []byte) {
htlcBlobs[resID] = b
})
}
if c.AnchorResolution != nil {
anchorSignDesc := c.AnchorResolution.AnchorSignDescriptor
tapCase.TapTweaks.Val.AnchorTweak = anchorSignDesc.TapTweak
}
if len(htlcBlobs) != 0 {
tapCase.HtlcBlobs = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType4](htlcBlobs),
)
}
return tapCase.Encode(w)
}
func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error {
tapCase := newTaprootBriefcase()
if err := tapCase.Decode(r); err != nil {
return err
}
if c.CommitResolution != nil {
c.CommitResolution.SelfOutputSignDesc.ControlBlock =
tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock
tapCase.SettledCommitBlob.WhenSomeV(func(b []byte) {
c.CommitResolution.ResolutionBlob = fn.Some(b)
})
}
htlcBlobs := tapCase.HtlcBlobs.ValOpt().UnwrapOr(newAuxHtlcBlobs())
for i := range c.HtlcResolutions.IncomingHTLCs {
htlc := c.HtlcResolutions.IncomingHTLCs[i]
var resID resolverID
if htlc.SignedSuccessTx != nil {
resID = newResolverID(
htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint,
)
//nolint:ll
ctrlBlock := tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock
//nolint:ll
if htlc.SignDetails != nil {
bridgeCtrlBlock := tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID]
htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock
}
} else {
resID = newResolverID(htlc.ClaimOutpoint)
//nolint:ll
ctrlBlock := tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock
}
if htlcBlob, ok := htlcBlobs[resID]; ok {
htlc.ResolutionBlob = fn.Some(htlcBlob)
}
c.HtlcResolutions.IncomingHTLCs[i] = htlc
}
for i := range c.HtlcResolutions.OutgoingHTLCs {
htlc := c.HtlcResolutions.OutgoingHTLCs[i]
var resID resolverID
if htlc.SignedTimeoutTx != nil {
resID = newResolverID(
htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint,
)
//nolint:ll
ctrlBlock := tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock
//nolint:ll
if htlc.SignDetails != nil {
bridgeCtrlBlock := tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID]
htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock
}
} else {
resID = newResolverID(htlc.ClaimOutpoint)
//nolint:ll
ctrlBlock := tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID]
htlc.SweepSignDesc.ControlBlock = ctrlBlock
}
if htlcBlob, ok := htlcBlobs[resID]; ok {
htlc.ResolutionBlob = fn.Some(htlcBlob)
}
c.HtlcResolutions.OutgoingHTLCs[i] = htlc
}
if c.AnchorResolution != nil {
c.AnchorResolution.AnchorSignDescriptor.TapTweak =
tapCase.TapTweaks.Val.AnchorTweak
}
return nil
}
package contractcourt
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
)
// ErrChainArbExiting signals that the chain arbitrator is shutting down.
var ErrChainArbExiting = errors.New("ChainArbitrator exiting")
// ResolutionMsg is a message sent by resolvers to outside sub-systems once an
// outgoing contract has been fully resolved. For multi-hop contracts, if we
// resolve the outgoing contract, we'll also need to ensure that the incoming
// contract is resolved as well. We package the items required to resolve the
// incoming contracts within this message.
type ResolutionMsg struct {
// SourceChan identifies the channel that this message is being sent
// from. This is the channel's short channel ID.
SourceChan lnwire.ShortChannelID
// HtlcIndex is the index of the contract within the original
// commitment trace.
HtlcIndex uint64
// Failure will be non-nil if the incoming contract should be canceled
// altogether. This can happen if the outgoing contract was dust, or if
// the outgoing HTLC timed out.
Failure lnwire.FailureMessage
// PreImage will be non-nil if the incoming contract can successfully
// be redeemed. This can happen if we learn of the preimage from the
// outgoing HTLC on-chain.
PreImage *[32]byte
}
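// As an illustrative sketch (not part of the original source; the values
// below are placeholders), a resolver that learned a preimage on-chain would
// package a settle message for the incoming link roughly as follows:
//
//	preimage := [32]byte{ /* preimage learned on-chain */ }
//	msg := ResolutionMsg{
//		SourceChan: lnwire.NewShortChanIDFromInt(42),
//		HtlcIndex:  0,
//		PreImage:   &preimage,
//	}
//	// cfg is an assumed ChainArbitratorConfig instance.
//	_ = cfg.DeliverResolutionMsg(msg)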
// ChainArbitratorConfig is a configuration struct that contains all the
// function closures and interfaces required to arbitrate on-chain
// contracts for a particular chain.
type ChainArbitratorConfig struct {
// ChainHash is the chain that this arbitrator is to operate within.
ChainHash chainhash.Hash
// IncomingBroadcastDelta is the delta that we'll use to decide when to
// broadcast our commitment transaction if we have incoming htlcs. This
// value should be set based on our current fee estimation of the
// commitment transaction. We use this to determine when we should
// broadcast, rather than waiting until the HTLC timeout itself, as we want
// to ensure that the commitment transaction is already confirmed by the
// time the HTLC expires. Otherwise we may end up not settling the HTLC on-chain
// because the other party managed to time it out.
IncomingBroadcastDelta uint32
// OutgoingBroadcastDelta is the delta that we'll use to decide when to
// broadcast our commitment transaction if there are active outgoing
// htlcs. This value can be lower than the incoming broadcast delta.
OutgoingBroadcastDelta uint32
// NewSweepAddr is a function that returns a new address under control
// by the wallet. We'll use this to sweep any no-delay outputs as a
// result of unilateral channel closes.
//
// NOTE: This SHOULD return a p2wkh script.
NewSweepAddr func() ([]byte, error)
// PublishTx reliably broadcasts a transaction to the network. Once
// this function exits without an error, then the transaction MUST
// continually be rebroadcast if needed.
PublishTx func(*wire.MsgTx, string) error
// DeliverResolutionMsg is a function that will append an outgoing
// message to the "out box" for a ChannelLink. This is used to cancel
// backwards any HTLCs that are either dust, that we're timing out, or
// settling on-chain to the incoming link.
DeliverResolutionMsg func(...ResolutionMsg) error
// MarkLinkInactive is a function closure that the ChainArbitrator will
// use to mark that active HTLCs shouldn't be routed over a particular
// channel. This function will be called when a ChannelArbitrator decides
// that it needs to go to chain in order to
// resolve contracts.
//
// TODO(roasbeef): rename, routing based
MarkLinkInactive func(wire.OutPoint) error
// ContractBreach is a function closure that the ChainArbitrator will
// use to notify the BreachArbitrator about a contract breach. It should
// only return a non-nil error when the BreachArbitrator has preserved
// the necessary breach info for this channel point. Once the breach
// resolution is persisted in the ChannelArbitrator, it will be safe
// to mark the channel closed.
ContractBreach func(wire.OutPoint, *lnwallet.BreachRetribution) error
// IsOurAddress is a function that returns true if the passed address
// is known to the underlying wallet. Otherwise, false should be
// returned.
IsOurAddress func(btcutil.Address) bool
// IncubateOutputs sends either an incoming HTLC, an outgoing HTLC, or
// both to the utxo nursery. Once this function returns, the nursery
// should have safely persisted the outputs to disk, and should start
// the process of incubation. This is used when a resolver wishes to
// pass off the output to the nursery as we're only waiting on an
// absolute/relative timelock.
IncubateOutputs func(wire.OutPoint,
fn.Option[lnwallet.OutgoingHtlcResolution],
fn.Option[lnwallet.IncomingHtlcResolution],
uint32, fn.Option[int32]) error
// PreimageDB is a global store of all known pre-images. We'll use this
// to decide if we should broadcast a commitment transaction to claim
// an HTLC on-chain.
PreimageDB WitnessBeacon
// Notifier is an instance of a chain notifier we'll use to watch for
// certain on-chain events.
Notifier chainntnfs.ChainNotifier
// Mempool is a mempool watcher that allows us to watch for events
// happening in the mempool.
Mempool chainntnfs.MempoolWatcher
// Signer is a signer backed by the active lnd node. This should be
// capable of producing a signature as specified by a valid
// SignDescriptor.
Signer input.Signer
// FeeEstimator will be used to return fee estimates.
FeeEstimator chainfee.Estimator
// ChainIO allows us to query the state of the current main chain.
ChainIO lnwallet.BlockChainIO
// DisableChannel disables a channel, resulting in it not being able to
// forward payments.
DisableChannel func(wire.OutPoint) error
// Sweeper allows resolvers to sweep their final outputs.
Sweeper UtxoSweeper
// Registry is the invoice database that is used by resolvers to lookup
// preimages and settle invoices.
Registry Registry
// NotifyClosedChannel is a function closure that the ChainArbitrator
// will use to notify the ChannelNotifier about a newly closed channel.
NotifyClosedChannel func(wire.OutPoint)
// NotifyFullyResolvedChannel is a function closure that the
// ChainArbitrator will use to notify the ChannelNotifier about a newly
// resolved channel. The main difference from NotifyClosedChannel is that
// in the case of a local force close, NotifyClosedChannel is called when
// the published commitment transaction confirms, while
// NotifyFullyResolvedChannel is only called once the channel is fully
// resolved (which includes sweeping any time-locked funds).
NotifyFullyResolvedChannel func(point wire.OutPoint)
// OnionProcessor is used to decode onion payloads for on-chain
// resolution.
OnionProcessor OnionProcessor
// PaymentsExpirationGracePeriod indicates a time window during which we
// allow the other node to cancel an outgoing HTLC that our node has
// initiated and that has timed out.
PaymentsExpirationGracePeriod time.Duration
// IsForwardedHTLC checks whether a given HTLC, identified by channel ID
// and htlcIndex, is a forwarded one.
IsForwardedHTLC func(chanID lnwire.ShortChannelID, htlcIndex uint64) bool
// Clock is the clock implementation that ChannelArbitrator uses.
// It is useful for testing.
Clock clock.Clock
// SubscribeBreachComplete is used by the breachResolver to register a
// subscription that notifies when the breach resolution process is
// complete.
SubscribeBreachComplete func(op *wire.OutPoint, c chan struct{}) (
bool, error)
// PutFinalHtlcOutcome stores the final outcome of an htlc in the
// database.
PutFinalHtlcOutcome func(chanId lnwire.ShortChannelID,
htlcId uint64, settled bool) error
// HtlcNotifier is an interface that htlc events are sent to.
HtlcNotifier HtlcNotifier
// Budget is the configured budget for the arbitrator.
Budget BudgetConfig
// QueryIncomingCircuit is used to find the outgoing HTLC's
// corresponding incoming HTLC circuit. It queries the circuit map for
// a given outgoing circuit key and returns the incoming circuit key.
//
// TODO(yy): this is a hacky way to get around the cycling import issue
// as we cannot import `htlcswitch` here. A proper way is to define an
// interface here that asks for method `LookupOpenCircuit`,
// meanwhile, turn `PaymentCircuit` into an interface or bring it to a
// lower package.
QueryIncomingCircuit func(circuit models.CircuitKey) *models.CircuitKey
// AuxLeafStore is an optional store that can be used to store auxiliary
// leaves for certain custom channel types.
AuxLeafStore fn.Option[lnwallet.AuxLeafStore]
// AuxSigner is an optional signer that can be used to sign auxiliary
// leaves for certain custom channel types.
AuxSigner fn.Option[lnwallet.AuxSigner]
// AuxResolver is an optional interface that can be used to modify the
// way contracts are resolved.
AuxResolver fn.Option[lnwallet.AuxContractResolver]
}
// ChainArbitrator is a sub-system that oversees the on-chain resolution of all
// active channels, and channels that are in the "pending close" state. Within
// the contractcourt package, the ChainArbitrator manages a set of active
// ContractArbitrators. Each ContractArbitrator is responsible for watching
// the chain for any activity that affects the state of the channel, and also
// for monitoring each contract in order to determine if any on-chain activity is
// required. Outside sub-systems interact with the ChainArbitrator in order to
// forcibly exit a contract, update the set of live signals for each contract,
// and to receive reports on the state of contract resolution.
type ChainArbitrator struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.
// Embed the blockbeat consumer struct to get access to the method
// `NotifyBlockProcessed` and the `BlockbeatChan`.
chainio.BeatConsumer
sync.Mutex
// activeChannels is a map of all the active contracts that are still
// open, and not fully resolved.
activeChannels map[wire.OutPoint]*ChannelArbitrator
// activeWatchers is a map of all the active chainWatchers for channels
// that are still considered open.
activeWatchers map[wire.OutPoint]*chainWatcher
// cfg is the config struct for the arbitrator that contains all
// methods and interfaces it needs to operate.
cfg ChainArbitratorConfig
// chanSource will be used by the ChainArbitrator to fetch all the
// active channels that it must still watch over.
chanSource *channeldb.DB
// beat is the current best known blockbeat.
beat chainio.Blockbeat
quit chan struct{}
wg sync.WaitGroup
}
// NewChainArbitrator returns a new instance of the ChainArbitrator using the
// passed config struct, and backing persistent database.
func NewChainArbitrator(cfg ChainArbitratorConfig,
db *channeldb.DB) *ChainArbitrator {
c := &ChainArbitrator{
cfg: cfg,
activeChannels: make(map[wire.OutPoint]*ChannelArbitrator),
activeWatchers: make(map[wire.OutPoint]*chainWatcher),
chanSource: db,
quit: make(chan struct{}),
}
// Mount the block consumer.
c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name())
return c
}
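// An illustrative usage sketch (not part of the original source; cfg, db,
// and beat are assumed to be prepared by the caller):
//
//	arb := NewChainArbitrator(cfg, db)
//	if err := arb.Start(beat); err != nil {
//		// handle startup error
//	}
//	defer func() { _ = arb.Stop() }()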
// Compile-time check for the chainio.Consumer interface.
var _ chainio.Consumer = (*ChainArbitrator)(nil)
// arbChannel is a wrapper around an open channel that channel arbitrators
// interact with.
type arbChannel struct {
// channel is the in-memory channel state.
channel *channeldb.OpenChannel
// c references the chain arbitrator and is used by arbChannel
// internally.
c *ChainArbitrator
}
// NewAnchorResolutions returns the anchor resolutions for currently valid
// commitment transactions.
//
// NOTE: Part of the ArbChannel interface.
func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions,
error) {
// Get a fresh copy of the database state to base the anchor resolutions
// on. Unfortunately the channel instance that we have here isn't the
// same instance that is used by the link.
chanPoint := a.channel.FundingOutpoint
channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(chanPoint)
if err != nil {
return nil, err
}
var chanOpts []lnwallet.ChannelOpt
a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) {
chanOpts = append(chanOpts, lnwallet.WithLeafStore(s))
})
a.c.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) {
chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s))
})
a.c.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) {
chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s))
})
chanMachine, err := lnwallet.NewLightningChannel(
a.c.cfg.Signer, channel, nil, chanOpts...,
)
if err != nil {
return nil, err
}
return chanMachine.NewAnchorResolutions()
}
// ForceCloseChan should force close the contract that this attendant is
// watching over. We'll use this when we decide that we need to go to chain. It
// should in addition tell the switch to remove the corresponding link, such
// that we won't accept any new updates.
//
// NOTE: Part of the ArbChannel interface.
func (a *arbChannel) ForceCloseChan() (*wire.MsgTx, error) {
// First, we mark the channel as borked. This ensures
// that no new state transitions can happen, and also
// that the link won't be loaded into the switch.
if err := a.channel.MarkBorked(); err != nil {
return nil, err
}
// With the channel marked as borked, we'll now remove
// the link from the switch if it's there. If the link
// is active, then this method will block until it
// exits.
chanPoint := a.channel.FundingOutpoint
if err := a.c.cfg.MarkLinkInactive(chanPoint); err != nil {
log.Errorf("unable to mark link inactive: %v", err)
}
// Now that we know the link can't mutate the channel
// state, we'll read the target channel from disk
// according to its channel point.
channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(chanPoint)
if err != nil {
return nil, err
}
var chanOpts []lnwallet.ChannelOpt
a.c.cfg.AuxLeafStore.WhenSome(func(s lnwallet.AuxLeafStore) {
chanOpts = append(chanOpts, lnwallet.WithLeafStore(s))
})
a.c.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) {
chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s))
})
a.c.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) {
chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s))
})
// Finally, we'll force close the channel, completing
// the force close workflow.
chanMachine, err := lnwallet.NewLightningChannel(
a.c.cfg.Signer, channel, nil, chanOpts...,
)
if err != nil {
return nil, err
}
closeSummary, err := chanMachine.ForceClose(
lnwallet.WithSkipContractResolutions(),
)
if err != nil {
return nil, err
}
return closeSummary.CloseTx, nil
}
// newActiveChannelArbitrator creates a new instance of an active channel
// arbitrator given the state of the target channel.
func newActiveChannelArbitrator(channel *channeldb.OpenChannel,
c *ChainArbitrator, chanEvents *ChainEventSubscription) (*ChannelArbitrator, error) {
// TODO(roasbeef): fetch best height (or pass in) so can ensure block
// epoch delivers all the notifications to
chanPoint := channel.FundingOutpoint
log.Tracef("Creating ChannelArbitrator for ChannelPoint(%v)", chanPoint)
// Next we'll create the matching configuration struct that contains
// all interfaces and methods the arbitrator needs to do its job.
arbCfg := ChannelArbitratorConfig{
ChanPoint: chanPoint,
Channel: c.getArbChannel(channel),
ShortChanID: channel.ShortChanID(),
MarkCommitmentBroadcasted: channel.MarkCommitmentBroadcasted,
MarkChannelClosed: func(summary *channeldb.ChannelCloseSummary,
statuses ...channeldb.ChannelStatus) error {
err := channel.CloseChannel(summary, statuses...)
if err != nil {
return err
}
c.cfg.NotifyClosedChannel(summary.ChanPoint)
return nil
},
IsPendingClose: false,
ChainArbitratorConfig: c.cfg,
ChainEvents: chanEvents,
PutResolverReport: func(tx kvdb.RwTx,
report *channeldb.ResolverReport) error {
return c.chanSource.PutResolverReport(
tx, c.cfg.ChainHash, &chanPoint, report,
)
},
FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) {
chanStateDB := c.chanSource.ChannelStateDB()
return chanStateDB.FetchHistoricalChannel(&chanPoint)
},
FindOutgoingHTLCDeadline: func(
htlc channeldb.HTLC) fn.Option[int32] {
return c.FindOutgoingHTLCDeadline(
channel.ShortChanID(), htlc,
)
},
}
// The final component needed is an arbitrator log that the arbitrator
// will use to keep track of its internal state using a
// disk-backed persistent log.
//
// TODO(roasbeef): abstraction leak...
// * rework: adaptor method to set log scope w/ factory func
chanLog, err := newBoltArbitratorLog(
c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint,
)
if err != nil {
return nil, err
}
arbCfg.MarkChannelResolved = func() error {
if c.cfg.NotifyFullyResolvedChannel != nil {
c.cfg.NotifyFullyResolvedChannel(chanPoint)
}
return c.ResolveContract(chanPoint)
}
// Finally, we'll need to construct a series of htlc Sets based on all
// currently known valid commitments.
htlcSets := make(map[HtlcSetKey]htlcSet)
htlcSets[LocalHtlcSet] = newHtlcSet(channel.LocalCommitment.Htlcs)
htlcSets[RemoteHtlcSet] = newHtlcSet(channel.RemoteCommitment.Htlcs)
pendingRemoteCommitment, err := channel.RemoteCommitChainTip()
if err != nil && err != channeldb.ErrNoPendingCommit {
return nil, err
}
if pendingRemoteCommitment != nil {
htlcSets[RemotePendingHtlcSet] = newHtlcSet(
pendingRemoteCommitment.Commitment.Htlcs,
)
}
return NewChannelArbitrator(
arbCfg, htlcSets, chanLog,
), nil
}
// getArbChannel returns an open channel wrapper for use by channel arbitrators.
func (c *ChainArbitrator) getArbChannel(
channel *channeldb.OpenChannel) *arbChannel {
return &arbChannel{
channel: channel,
c: c,
}
}
// ResolveContract marks a contract as fully resolved within the database.
// This is only to be done once all contracts which were live on the channel
// before hitting the chain have been resolved.
func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error {
log.Infof("Marking ChannelPoint(%v) fully resolved", chanPoint)
// First, we'll mark the channel as fully closed from the PoV of
// the channel source.
err := c.chanSource.ChannelStateDB().MarkChanFullyClosed(&chanPoint)
if err != nil {
log.Errorf("ChainArbitrator: unable to mark ChannelPoint(%v) "+
"fully closed: %v", chanPoint, err)
return err
}
// Now that the channel has been marked as fully closed, we'll stop
// both the channel arbitrator and chain watcher for this channel if
// they're still active.
var arbLog ArbitratorLog
c.Lock()
chainArb := c.activeChannels[chanPoint]
delete(c.activeChannels, chanPoint)
chainWatcher := c.activeWatchers[chanPoint]
delete(c.activeWatchers, chanPoint)
c.Unlock()
if chainArb != nil {
arbLog = chainArb.log
if err := chainArb.Stop(); err != nil {
log.Warnf("unable to stop ChannelArbitrator(%v): %v",
chanPoint, err)
}
}
if chainWatcher != nil {
if err := chainWatcher.Stop(); err != nil {
log.Warnf("unable to stop ChainWatcher(%v): %v",
chanPoint, err)
}
}
// Once this has been marked as resolved, we'll wipe the log that the
// channel arbitrator was using to store its persistent state. We do
// this after marking the channel resolved, as otherwise, the
// arbitrator would be re-created, and think it was starting from the
// default state.
if arbLog != nil {
if err := arbLog.WipeHistory(); err != nil {
return err
}
}
return nil
}
// Start launches all goroutines that the ChainArbitrator needs to operate.
func (c *ChainArbitrator) Start(beat chainio.Blockbeat) error {
if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
return nil
}
// Set the current beat.
c.beat = beat
// First, we'll fetch all the channels that are still open, in order to
// collect them within our set of active contracts.
if err := c.loadOpenChannels(); err != nil {
return err
}
// In addition to the channels that we know to be open, we'll also
// launch arbitrators to finishing resolving any channels that are in
// the pending close state.
if err := c.loadPendingCloseChannels(); err != nil {
return err
}
// Now, we'll start all chain watchers in parallel to shorten start up
// duration. In neutrino mode, this allows spend registrations to take
// advantage of batch spend reporting, instead of doing a single rescan
// per chain watcher.
//
// NOTE: After this point, if any error occurs we must Stop the chain arb
// to ensure that any lingering goroutines are cleaned up before exiting.
watcherErrs := make(chan error, len(c.activeWatchers))
var wg sync.WaitGroup
for _, watcher := range c.activeWatchers {
wg.Add(1)
go func(w *chainWatcher) {
defer wg.Done()
select {
case watcherErrs <- w.Start():
case <-c.quit:
watcherErrs <- ErrChainArbExiting
}
}(watcher)
}
// Once all chain watchers have been started, seal the err chan to
// signal the end of the err stream.
go func() {
wg.Wait()
close(watcherErrs)
}()
// stopAndLog is a helper function which shuts down the chain arb and
// logs errors if they occur.
stopAndLog := func() {
if err := c.Stop(); err != nil {
log.Errorf("ChainArbitrator could not shutdown: %v", err)
}
}
// Handle all errors returned from spawning our chain watchers. If any
// of them failed, we will stop the chain arb to shutdown any active
// goroutines.
for err := range watcherErrs {
if err != nil {
stopAndLog()
return err
}
}
// Before we start all of our arbitrators, we do a preliminary state
// lookup so that we can combine all of these lookups in a single db
// transaction.
var startStates map[wire.OutPoint]*chanArbStartState
err := kvdb.View(c.chanSource, func(tx walletdb.ReadTx) error {
for _, arbitrator := range c.activeChannels {
startState, err := arbitrator.getStartState(tx)
if err != nil {
return err
}
startStates[arbitrator.cfg.ChanPoint] = startState
}
return nil
}, func() {
startStates = make(
map[wire.OutPoint]*chanArbStartState,
len(c.activeChannels),
)
})
if err != nil {
stopAndLog()
return err
}
// Launch all the goroutines for each arbitrator so they can carry out
// their duties.
for _, arbitrator := range c.activeChannels {
startState, ok := startStates[arbitrator.cfg.ChanPoint]
if !ok {
stopAndLog()
return fmt.Errorf("arbitrator: %v has no start state",
arbitrator.cfg.ChanPoint)
}
if err := arbitrator.Start(startState, c.beat); err != nil {
stopAndLog()
return err
}
}
// Start our goroutine which will dispatch blocks to each arbitrator.
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.dispatchBlocks()
}()
log.Infof("ChainArbitrator starting at height %d with %d chain "+
"watchers, %d channel arbitrators, and budget config=[%v]",
c.beat.Height(), len(c.activeWatchers), len(c.activeChannels),
&c.cfg.Budget)
// TODO(roasbeef): eventually move all breach watching here
return nil
}
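// startAllConcurrently is an illustrative sketch (not part of the original
// source) of the fan-out/fan-in pattern used by Start above: every starter
// runs in its own goroutine, errors are funneled into a buffered channel,
// and the channel is closed once all goroutines have reported so that
// ranging over it terminates.
func startAllConcurrently(starters []func() error) error {
	errs := make(chan error, len(starters))
	var wg sync.WaitGroup
	for _, start := range starters {
		wg.Add(1)
		go func(start func() error) {
			defer wg.Done()
			errs <- start()
		}(start)
	}
	// Seal the channel once every starter has reported its result. The
	// channel is buffered, so an early return below leaks no goroutines.
	go func() {
		wg.Wait()
		close(errs)
	}()
	for err := range errs {
		if err != nil {
			return err
		}
	}
	return nil
}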
// dispatchBlocks consumes a block epoch notification stream and dispatches
// blocks to each of the chain arb's active channel arbitrators. This function
// must be run in a goroutine.
func (c *ChainArbitrator) dispatchBlocks() {
// Consume block epochs until we receive the instruction to shutdown.
for {
select {
// Consume block epochs, exiting if our subscription is
// terminated.
case beat := <-c.BlockbeatChan:
// Set the current blockbeat.
c.beat = beat
// Send this blockbeat to all the active channels and
// wait for them to finish processing it.
c.handleBlockbeat(beat)
// Exit if the chain arbitrator is shutting down.
case <-c.quit:
return
}
}
}
// handleBlockbeat sends the blockbeat to all active channel arbitrators in
// parallel and waits for them to finish processing it.
func (c *ChainArbitrator) handleBlockbeat(beat chainio.Blockbeat) {
// Read the active channels in a lock.
c.Lock()
// Create slices to record the active channel arbitrators and chain
// watchers.
channels := make([]chainio.Consumer, 0, len(c.activeChannels))
watchers := make([]chainio.Consumer, 0, len(c.activeWatchers))
// Copy the active channels to the slice.
for _, channel := range c.activeChannels {
channels = append(channels, channel)
}
for _, watcher := range c.activeWatchers {
watchers = append(watchers, watcher)
}
c.Unlock()
// Iterate all the copied watchers and send the blockbeat to them.
err := chainio.DispatchConcurrent(beat, watchers)
if err != nil {
log.Errorf("Notify blockbeat for chainWatcher failed: %v", err)
}
// Iterate all the copied channels and send the blockbeat to them.
//
// NOTE: This method will timeout if the processing of blocks of the
// subsystems is too long (60s).
err = chainio.DispatchConcurrent(beat, channels)
if err != nil {
log.Errorf("Notify blockbeat for ChannelArbitrator failed: %v",
err)
}
// Notify that the chain arbitrator has processed the block.
c.NotifyBlockProcessed(beat, err)
}
// republishClosingTxs will load any stored cooperative or unilateral closing
// transactions and republish them. This helps ensure propagation of the
// transactions in the event that prior publications failed.
func (c *ChainArbitrator) republishClosingTxs(
channel *channeldb.OpenChannel) error {
// If the channel has had its unilateral close broadcasted already,
// republish it in case it didn't propagate.
if channel.HasChanStatus(channeldb.ChanStatusCommitBroadcasted) {
err := c.rebroadcast(
channel, channeldb.ChanStatusCommitBroadcasted,
)
if err != nil {
return err
}
}
// If the channel has had its cooperative close broadcasted
// already, republish it in case it didn't propagate.
if channel.HasChanStatus(channeldb.ChanStatusCoopBroadcasted) {
err := c.rebroadcast(
channel, channeldb.ChanStatusCoopBroadcasted,
)
if err != nil {
return err
}
}
return nil
}
// rebroadcast is a helper method which will republish the unilateral or
// cooperative close transaction of a channel in a particular state.
//
// NOTE: There is no risk to calling this method if the channel isn't in either
// CommitmentBroadcasted or CoopBroadcasted, but the logs will be misleading.
func (c *ChainArbitrator) rebroadcast(channel *channeldb.OpenChannel,
state channeldb.ChannelStatus) error {
chanPoint := channel.FundingOutpoint
var (
closeTx *wire.MsgTx
kind string
err error
)
switch state {
case channeldb.ChanStatusCommitBroadcasted:
kind = "force"
closeTx, err = channel.BroadcastedCommitment()
case channeldb.ChanStatusCoopBroadcasted:
kind = "coop"
closeTx, err = channel.BroadcastedCooperative()
default:
return fmt.Errorf("unknown closing state: %v", state)
}
switch {
// This can happen for channels that had their closing tx published
// before we started storing it to disk.
case err == channeldb.ErrNoCloseTx:
log.Warnf("Channel %v is in state %v, but no %s closing tx "+
"to re-publish...", chanPoint, state, kind)
return nil
case err != nil:
return err
}
log.Infof("Re-publishing %s close tx(%v) for channel %v",
kind, closeTx.TxHash(), chanPoint)
label := labels.MakeLabel(
labels.LabelTypeChannelClose, &channel.ShortChannelID,
)
err = c.cfg.PublishTx(closeTx, label)
if err != nil && err != lnwallet.ErrDoubleSpend {
log.Warnf("Unable to broadcast %s close tx(%v): %v",
kind, closeTx.TxHash(), err)
}
return nil
}
// Stop signals the ChainArbitrator to trigger a graceful shutdown. Any active
// channel arbitrators will be signalled to exit, and this method will block
// until they've all exited.
func (c *ChainArbitrator) Stop() error {
if !atomic.CompareAndSwapInt32(&c.stopped, 0, 1) {
return nil
}
log.Info("ChainArbitrator shutting down...")
defer log.Debug("ChainArbitrator shutdown complete")
close(c.quit)
var (
activeWatchers = make(map[wire.OutPoint]*chainWatcher)
activeChannels = make(map[wire.OutPoint]*ChannelArbitrator)
)
// Copy the current set of active watchers and arbitrators to shutdown.
// We don't want to hold the lock when shutting down each watcher or
// arbitrator individually, as they may need to acquire this mutex.
c.Lock()
for chanPoint, watcher := range c.activeWatchers {
activeWatchers[chanPoint] = watcher
}
for chanPoint, arbitrator := range c.activeChannels {
activeChannels[chanPoint] = arbitrator
}
c.Unlock()
for chanPoint, watcher := range activeWatchers {
log.Tracef("Attempting to stop ChainWatcher(%v)",
chanPoint)
if err := watcher.Stop(); err != nil {
log.Errorf("unable to stop watcher for "+
"ChannelPoint(%v): %v", chanPoint, err)
}
}
for chanPoint, arbitrator := range activeChannels {
log.Tracef("Attempting to stop ChannelArbitrator(%v)",
chanPoint)
if err := arbitrator.Stop(); err != nil {
log.Errorf("unable to stop arbitrator for "+
"ChannelPoint(%v): %v", chanPoint, err)
}
}
c.wg.Wait()
return nil
}
// ContractUpdate is a message that packages the latest set of active HTLCs on a
// commitment, and also identifies which commitment received a new set of
// HTLCs.
type ContractUpdate struct {
// HtlcKey identifies which commitment the HTLCs below are present on.
HtlcKey HtlcSetKey
// Htlcs are the set of active HTLCs on the commitment identified by the
// above HtlcKey.
Htlcs []channeldb.HTLC
}
// ContractSignals is used by outside subsystems to notify a channel arbitrator
// of its ShortChannelID.
type ContractSignals struct {
// ShortChanID is the up-to-date short channel ID for a contract. This
// can change either if the contract didn't yet have a stable identifier
// when it was added, or in the case of a reorg.
ShortChanID lnwire.ShortChannelID
}
// UpdateContractSignals sends a set of active, up-to-date contract signals to
// the ChannelArbitrator that has been assigned to the channel identified by
// the passed channel point.
func (c *ChainArbitrator) UpdateContractSignals(chanPoint wire.OutPoint,
signals *ContractSignals) error {
log.Infof("Attempting to update ContractSignals for ChannelPoint(%v)",
chanPoint)
c.Lock()
arbitrator, ok := c.activeChannels[chanPoint]
c.Unlock()
if !ok {
return fmt.Errorf("unable to find arbitrator")
}
arbitrator.UpdateContractSignals(signals)
return nil
}
// NotifyContractUpdate lets a channel arbitrator know that a new
// ContractUpdate is available. This calls the ChannelArbitrator's internal
// method NotifyContractUpdate which waits for a response on a done chan before
// returning. This method will return an error if the ChannelArbitrator is not
// in the activeChannels map. However, this only happens if the arbitrator is
// resolved and the related link would already be shut down.
func (c *ChainArbitrator) NotifyContractUpdate(chanPoint wire.OutPoint,
update *ContractUpdate) error {
c.Lock()
arbitrator, ok := c.activeChannels[chanPoint]
c.Unlock()
if !ok {
return fmt.Errorf("can't find arbitrator for %v", chanPoint)
}
arbitrator.notifyContractUpdate(update)
return nil
}
// GetChannelArbitrator safely returns the channel arbitrator for a given
// channel outpoint.
func (c *ChainArbitrator) GetChannelArbitrator(chanPoint wire.OutPoint) (
*ChannelArbitrator, error) {
c.Lock()
arbitrator, ok := c.activeChannels[chanPoint]
c.Unlock()
if !ok {
return nil, fmt.Errorf("unable to find arbitrator")
}
return arbitrator, nil
}
// forceCloseReq is a request sent from an outside sub-system to the arbitrator
// that watches a particular channel to broadcast the commitment transaction,
// and enter the resolution phase of the channel.
type forceCloseReq struct {
// errResp is a channel that will be sent upon either in the case of a
// successful force close (nil error), or in the case of an error.
//
// NOTE: This channel MUST be buffered.
errResp chan error
// closeTx is a channel that carries the transaction which ultimately
// closed out the channel.
closeTx chan *wire.MsgTx
}
// ForceCloseContract attempts to force close the channel identified by the passed
// channel point. A force close will immediately terminate the contract,
// causing it to enter the resolution phase. If the force close was successful,
// then the force close transaction itself will be returned.
//
// TODO(roasbeef): just return the summary itself?
func (c *ChainArbitrator) ForceCloseContract(chanPoint wire.OutPoint) (*wire.MsgTx, error) {
c.Lock()
arbitrator, ok := c.activeChannels[chanPoint]
c.Unlock()
if !ok {
return nil, fmt.Errorf("unable to find arbitrator")
}
log.Infof("Attempting to force close ChannelPoint(%v)", chanPoint)
// Before closing, we'll attempt to send a disable update for the
// channel. We do so before closing the channel as otherwise the current
// edge policy won't be retrievable from the graph.
if err := c.cfg.DisableChannel(chanPoint); err != nil {
log.Warnf("Unable to disable channel %v on "+
"close: %v", chanPoint, err)
}
errChan := make(chan error, 1)
respChan := make(chan *wire.MsgTx, 1)
// With the channel found, and the request crafted, we'll send over a
// force close request to the arbitrator that watches this channel.
select {
case arbitrator.forceCloseReqs <- &forceCloseReq{
errResp: errChan,
closeTx: respChan,
}:
case <-c.quit:
return nil, ErrChainArbExiting
}
// We'll await two responses: the error response, and the transaction
// that closed out the channel.
select {
case err := <-errChan:
if err != nil {
return nil, err
}
case <-c.quit:
return nil, ErrChainArbExiting
}
var closeTx *wire.MsgTx
select {
case closeTx = <-respChan:
case <-c.quit:
return nil, ErrChainArbExiting
}
return closeTx, nil
}
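// An illustrative usage sketch (not part of the original source; arb and
// chanPoint are assumed to be in scope):
//
//	closeTx, err := arb.ForceCloseContract(chanPoint)
//	if err != nil {
//		// handle error
//	}
//	log.Infof("Force close tx published: %v", closeTx.TxHash())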
// WatchNewChannel sends the ChainArbitrator a message to create a
// ChannelArbitrator tasked with watching over a new channel. Once a new
// channel has finished its final funding flow, it should be registered with
// the ChainArbitrator so we can properly react to any on-chain events.
func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error {
c.Lock()
defer c.Unlock()
chanPoint := newChan.FundingOutpoint
log.Infof("Creating new chainWatcher and ChannelArbitrator for "+
"ChannelPoint(%v)", chanPoint)
// If we're already watching this channel, then we'll ignore this
// request.
if _, ok := c.activeChannels[chanPoint]; ok {
return nil
}
// First, also create an active chainWatcher for this channel to ensure
// that we detect any relevant on chain events.
chainWatcher, err := newChainWatcher(
chainWatcherConfig{
chanState: newChan,
notifier: c.cfg.Notifier,
signer: c.cfg.Signer,
isOurAddr: c.cfg.IsOurAddress,
contractBreach: func(
retInfo *lnwallet.BreachRetribution) error {
return c.cfg.ContractBreach(
chanPoint, retInfo,
)
},
extractStateNumHint: lnwallet.GetStateNumHint,
auxLeafStore: c.cfg.AuxLeafStore,
auxResolver: c.cfg.AuxResolver,
},
)
if err != nil {
return err
}
c.activeWatchers[chanPoint] = chainWatcher
// We'll also create a new channel arbitrator instance using this new
// channel, and our internal state.
channelArb, err := newActiveChannelArbitrator(
newChan, c, chainWatcher.SubscribeChannelEvents(),
)
if err != nil {
return err
}
// With the arbitrator created, we'll add it to our set of active
// arbitrators, then launch it.
c.activeChannels[chanPoint] = channelArb
if err := channelArb.Start(nil, c.beat); err != nil {
return err
}
return chainWatcher.Start()
}
// SubscribeChannelEvents returns a new active subscription for the set of
// possible on-chain events for a particular channel. The struct can be used by
// callers to be notified whenever an event that changes the state of the
// channel on-chain occurs.
func (c *ChainArbitrator) SubscribeChannelEvents(
chanPoint wire.OutPoint) (*ChainEventSubscription, error) {
// First, we'll attempt to look up the active watcher for this channel.
// If we can't find it, then we'll return an error back to the caller.
c.Lock()
watcher, ok := c.activeWatchers[chanPoint]
c.Unlock()
if !ok {
return nil, fmt.Errorf("unable to find watcher for: %v",
chanPoint)
}
// With the watcher located, we'll request for it to create a new chain
// event subscription client.
return watcher.SubscribeChannelEvents(), nil
}
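// An illustrative usage sketch (not part of the original source; arb and
// chanPoint are assumed to be in scope):
//
//	sub, err := arb.SubscribeChannelEvents(chanPoint)
//	if err != nil {
//		// handle the missing watcher
//	}
//	// sub now delivers on-chain events affecting this channel.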
// FindOutgoingHTLCDeadline returns the deadline in absolute block height for
// the specified outgoing HTLC. For an outgoing HTLC, its deadline is defined
// by the timeout height of its corresponding incoming HTLC - this is the
// expiry height at which the remote peer can spend their outgoing HTLC via
// the timeout path.
func (c *ChainArbitrator) FindOutgoingHTLCDeadline(scid lnwire.ShortChannelID,
outgoingHTLC channeldb.HTLC) fn.Option[int32] {
// Find the outgoing HTLC's corresponding incoming HTLC in the circuit
// map.
rHash := outgoingHTLC.RHash
circuit := models.CircuitKey{
ChanID: scid,
HtlcID: outgoingHTLC.HtlcIndex,
}
incomingCircuit := c.cfg.QueryIncomingCircuit(circuit)
// If there's no incoming circuit found, we will use the default
// deadline.
if incomingCircuit == nil {
log.Warnf("ChannelArbitrator(%v): incoming circuit key not "+
"found for rHash=%x, using default deadline instead",
scid, rHash)
return fn.None[int32]()
}
// If this is a locally initiated HTLC, it means we are the first hop.
// In this case, we can relax the deadline.
if incomingCircuit.ChanID.IsDefault() {
log.Infof("ChannelArbitrator(%v): using default deadline for "+
"locally initiated HTLC for rHash=%x", scid, rHash)
return fn.None[int32]()
}
log.Debugf("Found incoming circuit %v for rHash=%x using outgoing "+
"circuit %v", incomingCircuit, rHash, circuit)
c.Lock()
defer c.Unlock()
// Iterate over all active channels to find the incoming HTLC specified
// by its circuit key.
for cp, channelArb := range c.activeChannels {
// Skip if the SCID doesn't match.
if channelArb.cfg.ShortChanID != incomingCircuit.ChanID {
continue
}
// Make sure the channel arbitrator has the latest view of its
// active HTLCs.
channelArb.updateActiveHTLCs()
// Iterate all the known HTLCs to find the targeted incoming
// HTLC.
for _, htlcs := range channelArb.activeHTLCs {
for _, htlc := range htlcs.incomingHTLCs {
// Skip if the index doesn't match.
if htlc.HtlcIndex != incomingCircuit.HtlcID {
continue
}
log.Debugf("ChannelArbitrator(%v): found "+
"incoming HTLC in channel=%v using "+
"rHash=%x, refundTimeout=%v", scid,
cp, rHash, htlc.RefundTimeout)
return fn.Some(int32(htlc.RefundTimeout))
}
}
}
// If there's no incoming HTLC found, yet we have the incoming circuit,
// something is wrong - in this case, we return fn.None so the default
// deadline is used.
log.Errorf("ChannelArbitrator(%v): incoming HTLC not found for "+
"rHash=%x, using default deadline instead", scid, rHash)
return fn.None[int32]()
}
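
// For example (illustrative numbers only): if the matched incoming HTLC has
// RefundTimeout=800_000, the returned deadline for the outgoing HTLC is
// fn.Some(int32(800_000)), since past that height the remote peer can claim
// the incoming HTLC via the timeout path.
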
// TODO(roasbeef): arbitration reports
// * types: contested, waiting for success conf, etc

// Name returns the name of the ChainArbitrator.
//
// NOTE: part of the `chainio.Consumer` interface.
func (c *ChainArbitrator) Name() string {
return "ChainArbitrator"
}
// loadOpenChannels loads all channels that are currently open in the database
// and registers them with the chainWatcher for future notification.
func (c *ChainArbitrator) loadOpenChannels() error {
openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels()
if err != nil {
return err
}
if len(openChannels) == 0 {
return nil
}
log.Infof("Creating ChannelArbitrators for %v active channels",
len(openChannels))
// For each open channel, we'll configure then launch a corresponding
// ChannelArbitrator.
for _, channel := range openChannels {
chanPoint := channel.FundingOutpoint
channel := channel
// First, we'll create an active chainWatcher for this channel
// to ensure that we detect any relevant on chain events.
breachClosure := func(ret *lnwallet.BreachRetribution) error {
return c.cfg.ContractBreach(chanPoint, ret)
}
chainWatcher, err := newChainWatcher(
chainWatcherConfig{
chanState: channel,
notifier: c.cfg.Notifier,
signer: c.cfg.Signer,
isOurAddr: c.cfg.IsOurAddress,
contractBreach: breachClosure,
extractStateNumHint: lnwallet.GetStateNumHint,
auxLeafStore: c.cfg.AuxLeafStore,
auxResolver: c.cfg.AuxResolver,
},
)
if err != nil {
return err
}
c.activeWatchers[chanPoint] = chainWatcher
channelArb, err := newActiveChannelArbitrator(
channel, c, chainWatcher.SubscribeChannelEvents(),
)
if err != nil {
return err
}
c.activeChannels[chanPoint] = channelArb
// Republish any closing transactions for this channel.
err = c.republishClosingTxs(channel)
if err != nil {
log.Errorf("Failed to republish closing txs for "+
"channel %v", chanPoint)
}
}
return nil
}
// loadPendingCloseChannels loads all channels that are currently pending
// closure in the database and registers them with the ChannelArbitrator to
// continue the resolution process.
func (c *ChainArbitrator) loadPendingCloseChannels() error {
chanStateDB := c.chanSource.ChannelStateDB()
closingChannels, err := chanStateDB.FetchClosedChannels(true)
if err != nil {
return err
}
if len(closingChannels) == 0 {
return nil
}
log.Infof("Creating ChannelArbitrators for %v closing channels",
len(closingChannels))
// Next, for each channel in the closing state, we'll launch a
// corresponding more restricted resolver, as we don't have to watch
// the chain any longer, only resolve the contracts on the confirmed
// commitment.
//nolint:ll
for _, closeChanInfo := range closingChannels {
// We can leave off the CloseContract and ForceCloseChan
// methods as the channel is already closed at this point.
chanPoint := closeChanInfo.ChanPoint
arbCfg := ChannelArbitratorConfig{
ChanPoint: chanPoint,
ShortChanID: closeChanInfo.ShortChanID,
ChainArbitratorConfig: c.cfg,
ChainEvents: &ChainEventSubscription{},
IsPendingClose: true,
ClosingHeight: closeChanInfo.CloseHeight,
CloseType: closeChanInfo.CloseType,
PutResolverReport: func(tx kvdb.RwTx,
report *channeldb.ResolverReport) error {
return c.chanSource.PutResolverReport(
tx, c.cfg.ChainHash, &chanPoint, report,
)
},
FetchHistoricalChannel: func() (*channeldb.OpenChannel, error) {
return chanStateDB.FetchHistoricalChannel(&chanPoint)
},
FindOutgoingHTLCDeadline: func(
htlc channeldb.HTLC) fn.Option[int32] {
return c.FindOutgoingHTLCDeadline(
closeChanInfo.ShortChanID, htlc,
)
},
}
chanLog, err := newBoltArbitratorLog(
c.chanSource.Backend, arbCfg, c.cfg.ChainHash, chanPoint,
)
if err != nil {
return err
}
arbCfg.MarkChannelResolved = func() error {
if c.cfg.NotifyFullyResolvedChannel != nil {
c.cfg.NotifyFullyResolvedChannel(chanPoint)
}
return c.ResolveContract(chanPoint)
}
// We create an empty map of HTLCs here since it's possible
// that the channel is in StateDefault and updateActiveHTLCs is
// called. We want to avoid writing to a nil map. Since the
// channel is already in the process of being resolved, no new
// HTLCs will be added.
c.activeChannels[chanPoint] = NewChannelArbitrator(
arbCfg, make(map[HtlcSetKey]htlcSet), chanLog,
)
}
return nil
}
// RedispatchBlockbeat resends the current blockbeat to the channels specified
// by the chanPoints. It is used when a channel is added to the chain
// arbitrator after it has been started, e.g., during the channel restore
// process.
func (c *ChainArbitrator) RedispatchBlockbeat(chanPoints []wire.OutPoint) {
// Get the current blockbeat.
beat := c.beat
// Prepare two sets of consumers.
channels := make([]chainio.Consumer, 0, len(chanPoints))
watchers := make([]chainio.Consumer, 0, len(chanPoints))
// Read the active channels in a lock.
c.Lock()
for _, op := range chanPoints {
if channel, ok := c.activeChannels[op]; ok {
channels = append(channels, channel)
}
if watcher, ok := c.activeWatchers[op]; ok {
watchers = append(watchers, watcher)
}
}
c.Unlock()
// Iterate all the copied watchers and send the blockbeat to them.
err := chainio.DispatchConcurrent(beat, watchers)
if err != nil {
log.Errorf("Notify blockbeat for chainWatcher failed: %v", err)
}
// Iterate all the copied channels and send the blockbeat to them.
err = chainio.DispatchConcurrent(beat, channels)
if err != nil {
// Log the error if the blockbeat fails to be processed.
log.Errorf("Notify blockbeat for ChannelArbitrator failed: %v",
err)
}
}
package contractcourt
import (
"bytes"
"fmt"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/mempool"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
)
const (
// minCommitPointPollTimeout is the minimum time we'll wait before
// polling the database for a channel's commitpoint.
minCommitPointPollTimeout = 1 * time.Second
// maxCommitPointPollTimeout is the maximum time we'll wait before
// polling the database for a channel's commitpoint.
maxCommitPointPollTimeout = 10 * time.Minute
)
// LocalUnilateralCloseInfo encapsulates all the information we need to act on
// a local force close that gets confirmed.
type LocalUnilateralCloseInfo struct {
*chainntnfs.SpendDetail
*lnwallet.LocalForceCloseSummary
*channeldb.ChannelCloseSummary
// CommitSet is the set of known valid commitments at the time our
// commitment hit the chain.
CommitSet CommitSet
}
// CooperativeCloseInfo encapsulates all the information we need to act on a
// cooperative close that gets confirmed.
type CooperativeCloseInfo struct {
*channeldb.ChannelCloseSummary
}
// RemoteUnilateralCloseInfo wraps the normal UnilateralCloseSummary to couple
// the CommitSet at the time of channel closure.
type RemoteUnilateralCloseInfo struct {
*lnwallet.UnilateralCloseSummary
// CommitSet is the set of known valid commitments at the time the
// remote party's commitment hit the chain.
CommitSet CommitSet
}
// BreachResolution wraps the outpoint of the breached channel.
type BreachResolution struct {
FundingOutPoint wire.OutPoint
}
// BreachCloseInfo wraps the BreachResolution with a CommitSet for the latest,
// non-breached state, with the AnchorResolution for the breached state.
type BreachCloseInfo struct {
*BreachResolution
*lnwallet.AnchorResolution
// CommitHash is the hash of the commitment transaction.
CommitHash chainhash.Hash
// CommitSet is the set of known valid commitments at the time the
// breach occurred on-chain.
CommitSet CommitSet
// CloseSummary gives the recipient of the BreachCloseInfo information
// to mark the channel closed in the database.
CloseSummary channeldb.ChannelCloseSummary
}
// CommitSet is a collection of the set of known valid commitments at a given
// instant. If ConfCommitKey is set, then the commitment identified by the
// HtlcSetKey has hit the chain. This struct will be used to examine all live
// HTLCs to determine if any additional actions need to be taken based on the
// remote party's commitments.
type CommitSet struct {
// When the ConfCommitKey is set, it signals that the commitment tx was
// confirmed in the chain.
ConfCommitKey fn.Option[HtlcSetKey]
// HtlcSets stores the set of all known active HTLCs for each active
// commitment at the time of channel closure.
HtlcSets map[HtlcSetKey][]channeldb.HTLC
}
// IsEmpty returns true if there are no HTLCs at all within all commitments
// that are a part of this commitment diff.
func (c *CommitSet) IsEmpty() bool {
if c == nil {
return true
}
for _, htlcs := range c.HtlcSets {
if len(htlcs) != 0 {
return false
}
}
return true
}
// toActiveHTLCSets returns the set of all active HTLCs across all commitment
// transactions.
func (c *CommitSet) toActiveHTLCSets() map[HtlcSetKey]htlcSet {
htlcSets := make(map[HtlcSetKey]htlcSet)
for htlcSetKey, htlcs := range c.HtlcSets {
htlcSets[htlcSetKey] = newHtlcSet(htlcs)
}
return htlcSets
}
// String returns a human-readable representation of the CommitSet.
func (c *CommitSet) String() string {
if c == nil {
return "nil"
}
// Create a descriptive string for the ConfCommitKey.
commitKey := "none"
c.ConfCommitKey.WhenSome(func(k HtlcSetKey) {
commitKey = k.String()
})
// Create a map to hold all the htlcs.
htlcSet := make(map[string]string)
for k, htlcs := range c.HtlcSets {
// Create a slice describing this particular set.
desc := make([]string, len(htlcs))
for i, htlc := range htlcs {
desc[i] = fmt.Sprintf("%x", htlc.RHash)
}
// Add the description to the set key.
htlcSet[k.String()] = fmt.Sprintf("count: %v, htlcs=%v",
len(htlcs), desc)
}
return fmt.Sprintf("ConfCommitKey=%v, HtlcSets=%v", commitKey, htlcSet)
}
// ChainEventSubscription is a struct that houses a subscription to be notified
// for any on-chain events related to a channel. There are three types of
// possible on-chain events: a cooperative channel closure, a unilateral
// channel closure, and a channel breach. The fourth type, a force close, is
// locally initiated, so we don't provide any event stream for said event.
type ChainEventSubscription struct {
// ChanPoint is the channel that chain events will be dispatched for.
ChanPoint wire.OutPoint
// RemoteUnilateralClosure is a channel that will be sent upon in the
// event that the remote party's commitment transaction is confirmed.
RemoteUnilateralClosure chan *RemoteUnilateralCloseInfo
// LocalUnilateralClosure is a channel that will be sent upon in the
// event that our commitment transaction is confirmed.
LocalUnilateralClosure chan *LocalUnilateralCloseInfo
// CooperativeClosure is a channel that will be sent upon once a
// cooperative channel closure has been detected and confirmed.
CooperativeClosure chan *CooperativeCloseInfo
// ContractBreach is a channel that will be sent upon if we detect a
// contract breach. The struct sent across the channel contains all the
// material required to bring the cheating channel peer to justice.
ContractBreach chan *BreachCloseInfo
// Cancel cancels the subscription to the event stream for a particular
// channel. This method should be called once the caller no longer needs to
// be notified of any on-chain events for a particular channel.
Cancel func()
}
// chainWatcherConfig encapsulates all the necessary functions and interfaces
// needed to watch and act on on-chain events for a particular channel.
type chainWatcherConfig struct {
// chanState is a snapshot of the persistent state of the channel that
// we're watching. In the event of an on-chain event, we'll query the
// database to ensure that we act using the most up to date state.
chanState *channeldb.OpenChannel
// notifier is a reference to the channel notifier that we'll use to be
// notified of output spends and when transactions are confirmed.
notifier chainntnfs.ChainNotifier
// signer is the main signer instances that will be responsible for
// signing any HTLC and commitment transaction generated by the state
// machine.
signer input.Signer
// contractBreach is a method that will be called by the watcher if it
// detects that a contract breach transaction has been confirmed. It
// will only return a nil error when the BreachArbitrator has
// preserved the necessary breach info for this channel point.
contractBreach func(*lnwallet.BreachRetribution) error
// isOurAddr is a function that returns true if the passed address is
// known to us.
isOurAddr func(btcutil.Address) bool
// extractStateNumHint extracts the encoded state hint using the passed
// obfuscator. This is used by the chain watcher to identify which
// state was broadcast and confirmed on-chain.
extractStateNumHint func(*wire.MsgTx, [lnwallet.StateHintSize]byte) uint64
// auxLeafStore can be used to fetch information for custom channels.
auxLeafStore fn.Option[lnwallet.AuxLeafStore]
// auxResolver is used to supplement contract resolution.
auxResolver fn.Option[lnwallet.AuxContractResolver]
}
// chainWatcher is a system that's assigned to every active channel. The duty
// of this system is to watch the chain for spends of the channel's channel
// point. If a spend is detected, then the chain watcher will notify all
// subscribers that the channel has been closed, and also give them the
// materials necessary to sweep the funds of the channel on chain eventually.
type chainWatcher struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.
// Embed the blockbeat consumer struct to get access to the method
// `NotifyBlockProcessed` and the `BlockbeatChan`.
chainio.BeatConsumer
quit chan struct{}
wg sync.WaitGroup
cfg chainWatcherConfig
// stateHintObfuscator is a 48-bit state hint that's used to obfuscate
// the current state number on the commitment transactions.
stateHintObfuscator [lnwallet.StateHintSize]byte
// All the fields below are protected by this mutex.
sync.Mutex
// clientID is an ephemeral counter used to keep track of each
// individual client subscription.
clientID uint64
// clientSubscriptions is a map that keeps track of all the active
// client subscriptions for events related to this channel.
clientSubscriptions map[uint64]*ChainEventSubscription
// fundingSpendNtfn is the spending notification subscription for the
// funding outpoint.
fundingSpendNtfn *chainntnfs.SpendEvent
// fundingConfirmedNtfn is the confirmation notification subscription
// for the funding outpoint. This is only created if the channel is
// both taproot and pending confirmation.
//
// For taproot pkscripts, `RegisterSpendNtfn` will only notify on the
// outpoint being spent and not the outpoint+pkscript due to
// `ComputePkScript` being unable to compute the pkscript if a key
// spend is used. We need to add a `RegisterConfirmationsNtfn` here to
// ensure that the outpoint+pkscript pair is confirmed before calling
// `RegisterSpendNtfn`.
fundingConfirmedNtfn *chainntnfs.ConfirmationEvent
}
// newChainWatcher returns a new instance of a chainWatcher for a channel given
// the chan point to watch, and also a notifier instance that will allow us to
// detect on chain events.
func newChainWatcher(cfg chainWatcherConfig) (*chainWatcher, error) {
// In order to be able to detect the nature of a potential channel
// closure we'll need to reconstruct the state hint bytes used to
// obfuscate the commitment state number encoded in the lock time and
// sequence fields.
var stateHint [lnwallet.StateHintSize]byte
chanState := cfg.chanState
if chanState.IsInitiator {
stateHint = lnwallet.DeriveStateHintObfuscator(
chanState.LocalChanCfg.PaymentBasePoint.PubKey,
chanState.RemoteChanCfg.PaymentBasePoint.PubKey,
)
} else {
stateHint = lnwallet.DeriveStateHintObfuscator(
chanState.RemoteChanCfg.PaymentBasePoint.PubKey,
chanState.LocalChanCfg.PaymentBasePoint.PubKey,
)
}
// Get the pkScript for the funding output.
fundingPkScript, err := deriveFundingPkScript(chanState)
if err != nil {
return nil, err
}
// Get the channel opening block height.
heightHint := chanState.DeriveHeightHint()
// We'll register for a notification to be dispatched if the funding
// output is spent.
spendNtfn, err := cfg.notifier.RegisterSpendNtfn(
&chanState.FundingOutpoint, fundingPkScript, heightHint,
)
if err != nil {
return nil, err
}
c := &chainWatcher{
cfg: cfg,
stateHintObfuscator: stateHint,
quit: make(chan struct{}),
clientSubscriptions: make(map[uint64]*ChainEventSubscription),
fundingSpendNtfn: spendNtfn,
}
// If this is a pending taproot channel, we need to register for a
// confirmation notification of the funding tx. Check the docs in
// `fundingConfirmedNtfn` for details.
if c.cfg.chanState.IsPending && c.cfg.chanState.ChanType.IsTaproot() {
confNtfn, err := cfg.notifier.RegisterConfirmationsNtfn(
&chanState.FundingOutpoint.Hash, fundingPkScript, 1,
heightHint,
)
if err != nil {
return nil, err
}
c.fundingConfirmedNtfn = confNtfn
}
// Mount the block consumer.
c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name())
return c, nil
}
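
// A minimal construction sketch (illustrative only; `chanState`, `notifier`
// and `signer` are assumed to be initialized elsewhere, and error handling
// is elided):
//
//	watcher, err := newChainWatcher(chainWatcherConfig{
//		chanState:           chanState,
//		notifier:            notifier,
//		signer:              signer,
//		isOurAddr:           func(btcutil.Address) bool { return false },
//		contractBreach:      func(*lnwallet.BreachRetribution) error { return nil },
//		extractStateNumHint: lnwallet.GetStateNumHint,
//	})
//	if err != nil {
//		return err
//	}
//	if err := watcher.Start(); err != nil {
//		return err
//	}
//	defer func() { _ = watcher.Stop() }()
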
// Compile-time check for the chainio.Consumer interface.
var _ chainio.Consumer = (*chainWatcher)(nil)
// Name returns the name of the watcher.
//
// NOTE: part of the `chainio.Consumer` interface.
func (c *chainWatcher) Name() string {
return fmt.Sprintf("ChainWatcher(%v)", c.cfg.chanState.FundingOutpoint)
}
// Start starts all goroutines that the chainWatcher needs to perform its
// duties.
func (c *chainWatcher) Start() error {
if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
return nil
}
log.Debugf("Starting chain watcher for ChannelPoint(%v)",
c.cfg.chanState.FundingOutpoint)
c.wg.Add(1)
go c.closeObserver()
return nil
}
// Stop signals the close observer to gracefully exit.
func (c *chainWatcher) Stop() error {
if !atomic.CompareAndSwapInt32(&c.stopped, 0, 1) {
return nil
}
close(c.quit)
c.wg.Wait()
return nil
}
// SubscribeChannelEvents returns an active subscription to the set of channel
// events for the channel watched by this chain watcher. Once clients no longer
// require the subscription, they should call the Cancel() method to allow the
// watcher to regain those committed resources.
func (c *chainWatcher) SubscribeChannelEvents() *ChainEventSubscription {
c.Lock()
clientID := c.clientID
c.clientID++
c.Unlock()
log.Debugf("New ChainEventSubscription(id=%v) for ChannelPoint(%v)",
clientID, c.cfg.chanState.FundingOutpoint)
sub := &ChainEventSubscription{
ChanPoint: c.cfg.chanState.FundingOutpoint,
RemoteUnilateralClosure: make(chan *RemoteUnilateralCloseInfo, 1),
LocalUnilateralClosure: make(chan *LocalUnilateralCloseInfo, 1),
CooperativeClosure: make(chan *CooperativeCloseInfo, 1),
ContractBreach: make(chan *BreachCloseInfo, 1),
Cancel: func() {
c.Lock()
delete(c.clientSubscriptions, clientID)
c.Unlock()
},
}
c.Lock()
c.clientSubscriptions[clientID] = sub
c.Unlock()
return sub
}
// handleUnknownLocalState checks whether the passed spend _could_ be a local
// state that for some reason is unknown to us. This could be a state published
// by us before we lost state, which we will try to sweep. Or it could be one
// of our revoked states that somehow made it to the chain. If that's the case
// we cannot really hope that we'll be able to get our money back, but we'll
// try to sweep it anyway. If this is not an unknown local state, false is
// returned.
func (c *chainWatcher) handleUnknownLocalState(
commitSpend *chainntnfs.SpendDetail, broadcastStateNum uint64,
chainSet *chainSet) (bool, error) {
// If the spend was a local commitment, at this point it must either be
// a past state (we breached!) or a future state (we lost state!). In
// either case, the only thing we can do is to attempt to sweep what is
// there.
// First, we'll re-derive our commitment point for this state since
// this is what we use to randomize each of the keys for this state.
commitSecret, err := c.cfg.chanState.RevocationProducer.AtIndex(
broadcastStateNum,
)
if err != nil {
return false, err
}
commitPoint := input.ComputeCommitmentPoint(commitSecret[:])
// Now that we have the commit point, we'll derive the tweaked local
// and remote keys for this state. We use our point as only we can
// revoke our own commitment.
commitKeyRing := lnwallet.DeriveCommitmentKeys(
commitPoint, lntypes.Local, c.cfg.chanState.ChanType,
&c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg,
)
auxResult, err := fn.MapOptionZ(
c.cfg.auxLeafStore,
//nolint:ll
func(s lnwallet.AuxLeafStore) fn.Result[lnwallet.CommitDiffAuxResult] {
return s.FetchLeavesFromCommit(
lnwallet.NewAuxChanState(c.cfg.chanState),
c.cfg.chanState.LocalCommitment, *commitKeyRing,
lntypes.Local,
)
},
).Unpack()
if err != nil {
return false, fmt.Errorf("unable to fetch aux leaves: %w", err)
}
// With the keys derived, we'll construct the remote script that'll be
// present if they have a non-dust balance on the commitment.
var leaseExpiry uint32
if c.cfg.chanState.ChanType.HasLeaseExpiration() {
leaseExpiry = c.cfg.chanState.ThawHeight
}
remoteAuxLeaf := fn.FlatMapOption(
func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
},
)(auxResult.AuxLeaves)
remoteScript, _, err := lnwallet.CommitScriptToRemote(
c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator,
commitKeyRing.ToRemoteKey, leaseExpiry,
remoteAuxLeaf,
)
if err != nil {
return false, err
}
// Next, we'll derive our script that includes the revocation base for
// the remote party allowing them to claim this output before the CSV
// delay if we breach.
localAuxLeaf := fn.FlatMapOption(
func(l lnwallet.CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
},
)(auxResult.AuxLeaves)
localScript, err := lnwallet.CommitScriptToSelf(
c.cfg.chanState.ChanType, c.cfg.chanState.IsInitiator,
commitKeyRing.ToLocalKey, commitKeyRing.RevocationKey,
uint32(c.cfg.chanState.LocalChanCfg.CsvDelay), leaseExpiry,
localAuxLeaf,
)
if err != nil {
return false, err
}
// With all our scripts assembled, we'll examine the outputs of the
// commitment transaction to determine if this is a local force close
// or not.
ourCommit := false
for _, output := range commitSpend.SpendingTx.TxOut {
pkScript := output.PkScript
switch {
case bytes.Equal(localScript.PkScript(), pkScript):
ourCommit = true
case bytes.Equal(remoteScript.PkScript(), pkScript):
ourCommit = true
}
}
// If the script is not present, this cannot be our commit.
if !ourCommit {
return false, nil
}
log.Warnf("Detected local unilateral close of unknown state %v "+
"(our state=%v)", broadcastStateNum,
chainSet.localCommit.CommitHeight)
// If this is our commitment transaction, then we try to act even
// though we won't be able to sweep HTLCs.
chainSet.commitSet.ConfCommitKey = fn.Some(LocalHtlcSet)
if err := c.dispatchLocalForceClose(
commitSpend, broadcastStateNum, chainSet.commitSet,
); err != nil {
return false, fmt.Errorf("unable to handle local"+
"close for chan_point=%v: %v",
c.cfg.chanState.FundingOutpoint, err)
}
return true, nil
}
// chainSet includes all the information we need to dispatch a channel close
// event to any subscribers.
type chainSet struct {
// remoteStateNum is the commitment number of the lowest valid
// commitment the remote party holds from our PoV. This value is used
// to determine if the remote party is playing a state that's behind,
// in line, or ahead of the latest state we know for it.
remoteStateNum uint64
// commitSet includes information pertaining to the set of active HTLCs
// on each commitment.
commitSet CommitSet
// remoteCommit is the current commitment of the remote party.
remoteCommit channeldb.ChannelCommitment
// localCommit is our current commitment.
localCommit channeldb.ChannelCommitment
// remotePendingCommit points to the dangling commitment of the remote
// party, if it exists. If there's no dangling commitment, then this
// pointer will be nil.
remotePendingCommit *channeldb.ChannelCommitment
}
// newChainSet creates a new chainSet given the current up to date channel
// state.
func newChainSet(chanState *channeldb.OpenChannel) (*chainSet, error) {
// First, we'll grab the current unrevoked commitments for ourselves
// and the remote party.
localCommit, remoteCommit, err := chanState.LatestCommitments()
if err != nil {
return nil, fmt.Errorf("unable to fetch channel state for "+
"chan_point=%v: %v", chanState.FundingOutpoint, err)
}
log.Tracef("ChannelPoint(%v): local_commit_type=%v, local_commit=%v",
chanState.FundingOutpoint, chanState.ChanType,
spew.Sdump(localCommit))
log.Tracef("ChannelPoint(%v): remote_commit_type=%v, remote_commit=%v",
chanState.FundingOutpoint, chanState.ChanType,
spew.Sdump(remoteCommit))
// Fetch the current known commit height for the remote party, and
// their pending commitment chain tip if it exists.
remoteStateNum := remoteCommit.CommitHeight
remoteChainTip, err := chanState.RemoteCommitChainTip()
if err != nil && err != channeldb.ErrNoPendingCommit {
return nil, fmt.Errorf("unable to obtain chain tip for "+
"ChannelPoint(%v): %v",
chanState.FundingOutpoint, err)
}
// Now that we have all the possible valid commitments, we'll make the
// CommitSet the ChannelArbitrator will need in order to carry out its
// duty.
commitSet := CommitSet{
HtlcSets: map[HtlcSetKey][]channeldb.HTLC{
LocalHtlcSet: localCommit.Htlcs,
RemoteHtlcSet: remoteCommit.Htlcs,
},
}
var remotePendingCommit *channeldb.ChannelCommitment
if remoteChainTip != nil {
remotePendingCommit = &remoteChainTip.Commitment
log.Tracef("ChannelPoint(%v): remote_pending_commit_type=%v, "+
"remote_pending_commit=%v", chanState.FundingOutpoint,
chanState.ChanType,
spew.Sdump(remoteChainTip.Commitment))
htlcs := remoteChainTip.Commitment.Htlcs
commitSet.HtlcSets[RemotePendingHtlcSet] = htlcs
}
// We'll now retrieve the latest state of the revocation store so we
// can populate the revocation information within the channel state
// object that we have.
//
// TODO(roasbeef): mutation is bad mkay
_, err = chanState.RemoteRevocationStore()
if err != nil {
return nil, fmt.Errorf("unable to fetch revocation state for "+
"chan_point=%v", chanState.FundingOutpoint)
}
return &chainSet{
remoteStateNum: remoteStateNum,
commitSet: commitSet,
localCommit: *localCommit,
remoteCommit: *remoteCommit,
remotePendingCommit: remotePendingCommit,
}, nil
}
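
// In the common case the CommitSet built above contains two HTLC sets, keyed
// by LocalHtlcSet and RemoteHtlcSet, plus a RemotePendingHtlcSet entry
// whenever the remote party holds an unrevoked dangling commitment.
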
// closeObserver is a dedicated goroutine that will watch for any closes of the
// channel that it's watching on chain. In the event of an on-chain event, the
// close observer will assemble the proper materials required to claim the
// funds of the channel on-chain (if required), then dispatch these as
// notifications to all subscribers.
func (c *chainWatcher) closeObserver() {
defer c.wg.Done()
defer c.fundingSpendNtfn.Cancel()
log.Infof("Close observer for ChannelPoint(%v) active",
c.cfg.chanState.FundingOutpoint)
for {
select {
// A new block is received, we will check whether this block
// contains a spending tx that we are interested in.
case beat := <-c.BlockbeatChan:
log.Debugf("ChainWatcher(%v) received blockbeat %v",
c.cfg.chanState.FundingOutpoint, beat.Height())
// Process the block.
c.handleBlockbeat(beat)
// If the funding outpoint is spent, we now go ahead and handle
// it. Note that we cannot rely solely on the `block` event
// above to trigger a close event, as deep down, the receiving
// of block notifications and the receiving of spending
// notifications are done in two different goroutines, so the
// expected order: [receive block -> receive spend] is not
// guaranteed.
case spend, ok := <-c.fundingSpendNtfn.Spend:
// If the channel was closed, then this means that the
// notifier exited, so we will as well.
if !ok {
return
}
err := c.handleCommitSpend(spend)
if err != nil {
log.Errorf("Failed to handle commit spend: %v",
err)
}
// The chainWatcher has been signalled to exit, so we'll do so
// now.
case <-c.quit:
return
}
}
}
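
// Note that both dispatch paths above drain the same fundingSpendNtfn.Spend
// channel, so a detected spend is consumed and handled exactly once,
// whichever path happens to observe it first.
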
// handleKnownLocalState checks whether the passed spend is a local state that
// is known to us (the current state). If so we will act on this state using
// the passed chainSet. If this is not a known local state, false is returned.
func (c *chainWatcher) handleKnownLocalState(
commitSpend *chainntnfs.SpendDetail, broadcastStateNum uint64,
chainSet *chainSet) (bool, error) {
// If the channel is recovered, we won't have a local commit to check
// against, so immediately return.
if c.cfg.chanState.HasChanStatus(channeldb.ChanStatusRestored) {
return false, nil
}
commitTxBroadcast := commitSpend.SpendingTx
commitHash := commitTxBroadcast.TxHash()
// Check whether our latest local state hit the chain.
if chainSet.localCommit.CommitTx.TxHash() != commitHash {
return false, nil
}
chainSet.commitSet.ConfCommitKey = fn.Some(LocalHtlcSet)
if err := c.dispatchLocalForceClose(
commitSpend, broadcastStateNum, chainSet.commitSet,
); err != nil {
return false, fmt.Errorf("unable to handle local"+
"close for chan_point=%v: %v",
c.cfg.chanState.FundingOutpoint, err)
}
return true, nil
}
// handleKnownRemoteState checks whether the passed spend is a remote state
// that is known to us (a revoked, current or pending state). If so we will act
// on this state using the passed chainSet. If this is not a known remote
// state, false is returned.
func (c *chainWatcher) handleKnownRemoteState(
commitSpend *chainntnfs.SpendDetail, broadcastStateNum uint64,
chainSet *chainSet) (bool, error) {
// If the channel is recovered, we won't have any remote commit to
// check against, so immediately return.
if c.cfg.chanState.HasChanStatus(channeldb.ChanStatusRestored) {
return false, nil
}
commitTxBroadcast := commitSpend.SpendingTx
commitHash := commitTxBroadcast.TxHash()
switch {
// If the spending transaction matches the current latest state, then
// they've initiated a unilateral close. So we'll trigger the
// unilateral close signal so subscribers can clean up the state as
// necessary.
case chainSet.remoteCommit.CommitTx.TxHash() == commitHash:
log.Infof("Remote party broadcast base set, "+
"commit_num=%v", chainSet.remoteStateNum)
chainSet.commitSet.ConfCommitKey = fn.Some(RemoteHtlcSet)
err := c.dispatchRemoteForceClose(
commitSpend, chainSet.remoteCommit,
chainSet.commitSet,
c.cfg.chanState.RemoteCurrentRevocation,
)
if err != nil {
return false, fmt.Errorf("unable to handle remote "+
"close for chan_point=%v: %v",
c.cfg.chanState.FundingOutpoint, err)
}
return true, nil
// We'll also handle the case of the remote party broadcasting
// their commitment transaction which is one height above ours.
// This case can arise when we initiate a state transition, but
// the remote party crashes _after_ accepting the new
// state, but _before_ sending their signature to us.
case chainSet.remotePendingCommit != nil &&
chainSet.remotePendingCommit.CommitTx.TxHash() == commitHash:
log.Infof("Remote party broadcast pending set, "+
"commit_num=%v", chainSet.remoteStateNum+1)
chainSet.commitSet.ConfCommitKey = fn.Some(RemotePendingHtlcSet)
err := c.dispatchRemoteForceClose(
commitSpend, *chainSet.remotePendingCommit,
chainSet.commitSet,
c.cfg.chanState.RemoteNextRevocation,
)
if err != nil {
return false, fmt.Errorf("unable to handle remote "+
"close for chan_point=%v: %v",
c.cfg.chanState.FundingOutpoint, err)
}
return true, nil
}
// This is neither a remote force close nor a "future" commitment, so we
// now check whether it's a remote breach and properly handle it.
return c.handlePossibleBreach(commitSpend, broadcastStateNum, chainSet)
}
// handlePossibleBreach checks whether the remote has breached and dispatches a
// breach resolution to claim funds.
func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail,
broadcastStateNum uint64, chainSet *chainSet) (bool, error) {
// We check if we have a revoked state at this state num that matches
// the spend transaction.
spendHeight := uint32(commitSpend.SpendingHeight)
retribution, err := lnwallet.NewBreachRetribution(
c.cfg.chanState, broadcastStateNum, spendHeight,
commitSpend.SpendingTx, c.cfg.auxLeafStore, c.cfg.auxResolver,
)
switch {
// If we had no log entry at this height, this was not a revoked state.
case err == channeldb.ErrLogEntryNotFound:
return false, nil
case err == channeldb.ErrNoPastDeltas:
return false, nil
case err != nil:
return false, fmt.Errorf("unable to create breach "+
"retribution: %v", err)
}
// We found a revoked state at this height, but it could still be our
// own broadcasted state we are looking at. Therefore check that the
// commit matches before assuming it was a breach.
commitHash := commitSpend.SpendingTx.TxHash()
if retribution.BreachTxHash != commitHash {
return false, nil
}
// Create an AnchorResolution for the breached state.
anchorRes, err := lnwallet.NewAnchorResolution(
c.cfg.chanState, commitSpend.SpendingTx, retribution.KeyRing,
lntypes.Remote,
)
if err != nil {
return false, fmt.Errorf("unable to create anchor "+
"resolution: %v", err)
}
// We'll set the ConfCommitKey here as the remote htlc set. This is
// only used to ensure a nil-pointer-dereference doesn't occur and is
// not used otherwise. The HTLCs may not exist for the
// RemotePendingHtlcSet.
chainSet.commitSet.ConfCommitKey = fn.Some(RemoteHtlcSet)
// THEY'RE ATTEMPTING TO VIOLATE THE CONTRACT LAID OUT WITHIN THE
// PAYMENT CHANNEL. Therefore we dispatch the signal indicating a revoked
// broadcast to allow subscribers to swiftly dispatch justice!!!
err = c.dispatchContractBreach(
commitSpend, chainSet, broadcastStateNum, retribution,
anchorRes,
)
if err != nil {
return false, fmt.Errorf("unable to handle channel "+
"breach for chan_point=%v: %v",
c.cfg.chanState.FundingOutpoint, err)
}
return true, nil
}
// handleUnknownRemoteState is the last attempt we make at reclaiming funds
// from the closed channel, by checking whether the passed spend _could_ be a
// remote spend that is unknown to us (we lost state). We will try to initiate
// Data Loss Protection in order to restore our commit point and reclaim our
// funds from the channel. If we are not able to act on it, false is returned.
func (c *chainWatcher) handleUnknownRemoteState(
commitSpend *chainntnfs.SpendDetail, broadcastStateNum uint64,
chainSet *chainSet) (bool, error) {
log.Warnf("Remote node broadcast state #%v, "+
"which is more than 1 beyond best known "+
"state #%v!!! Attempting recovery...",
broadcastStateNum, chainSet.remoteStateNum)
// If this isn't a tweakless commitment, then we'll need to wait for
// the remote party's latest unrevoked commitment point to be presented
// to us as we need this to sweep. Otherwise, we can dispatch the
// remote close and sweep immediately using a fake commitPoint as it
// isn't actually needed for recovery anymore.
commitPoint := c.cfg.chanState.RemoteCurrentRevocation
tweaklessCommit := c.cfg.chanState.ChanType.IsTweakless()
if !tweaklessCommit {
commitPoint = c.waitForCommitmentPoint()
if commitPoint == nil {
return false, fmt.Errorf("unable to get commit point")
}
log.Infof("Recovered commit point(%x) for "+
"channel(%v)! Now attempting to use it to "+
"sweep our funds...",
commitPoint.SerializeCompressed(),
c.cfg.chanState.FundingOutpoint)
} else {
log.Infof("ChannelPoint(%v) is tweakless, "+
"moving to sweep directly on chain",
c.cfg.chanState.FundingOutpoint)
}
// Since we don't have the commitment stored for this state, we'll just
// pass an empty commitment within the commitment set. Note that this
// means we won't be able to recover any HTLC funds.
//
// TODO(halseth): can we try to recover some HTLCs?
chainSet.commitSet.ConfCommitKey = fn.Some(RemoteHtlcSet)
err := c.dispatchRemoteForceClose(
commitSpend, channeldb.ChannelCommitment{},
chainSet.commitSet, commitPoint,
)
if err != nil {
return false, fmt.Errorf("unable to handle remote "+
"close for chan_point=%v: %v",
c.cfg.chanState.FundingOutpoint, err)
}
return true, nil
}
// toSelfAmount takes a transaction and returns the sum of all outputs that pay
// to a script that the wallet controls or the channel defines as its delivery
// script. If no outputs pay to us (determined by these criteria), then we
// return zero. This is possible as our output may have been trimmed due to
// being dust.
func (c *chainWatcher) toSelfAmount(tx *wire.MsgTx) btcutil.Amount {
// There are two main cases we have to handle here. First, in the coop
// close case we will always have saved the delivery address we used
// whether it was from the upfront shutdown, from the delivery address
// requested at close time, or even an automatically generated one. All
// coop-close cases can be identified in the following manner:
shutdown, _ := c.cfg.chanState.ShutdownInfo()
oDeliveryAddr := fn.MapOption(
func(i channeldb.ShutdownInfo) lnwire.DeliveryAddress {
return i.DeliveryScript.Val
})(shutdown)
// Here we define a function capable of identifying whether an output
// corresponds with our local delivery script from a ShutdownInfo if we
// have a ShutdownInfo for this chainWatcher's underlying channel.
//
// isDeliveryOutput :: *TxOut -> bool
isDeliveryOutput := func(o *wire.TxOut) bool {
return fn.ElimOption(
oDeliveryAddr,
// If we don't have a delivery addr, then the output
// can't match it.
func() bool { return false },
// Otherwise if the PkScript of the TxOut matches our
// delivery script then this is a delivery output.
func(a lnwire.DeliveryAddress) bool {
return slices.Equal(a, o.PkScript)
},
)
}
// Here we define a function capable of identifying whether an output
// belongs to the LND wallet. We use this as a heuristic in the case
// where we might be looking for spendable force closure outputs.
//
// isWalletOutput :: *TxOut -> bool
isWalletOutput := func(out *wire.TxOut) bool {
_, addrs, _, err := txscript.ExtractPkScriptAddrs(
// Doesn't matter what net we actually pass in.
out.PkScript, &chaincfg.TestNet3Params,
)
if err != nil {
return false
}
return fn.Any(addrs, c.cfg.isOurAddr)
}
// Grab all of the outputs that correspond with our delivery address
// or that our wallet is aware of.
outs := fn.Filter(tx.TxOut, fn.PredOr(isDeliveryOutput, isWalletOutput))
// Grab the values for those outputs.
vals := fn.Map(outs, func(o *wire.TxOut) int64 { return o.Value })
// Return the sum.
return btcutil.Amount(fn.Sum(vals))
}
// dispatchCooperativeClose processes a detected cooperative channel closure.
// We'll use the spending transaction to locate our output within the
// transaction, then clean up the database state. We'll also dispatch a
// notification to all subscribers that the channel has been closed in this
// manner.
func (c *chainWatcher) dispatchCooperativeClose(commitSpend *chainntnfs.SpendDetail) error {
broadcastTx := commitSpend.SpendingTx
log.Infof("Cooperative closure for ChannelPoint(%v): %v",
c.cfg.chanState.FundingOutpoint, spew.Sdump(broadcastTx))
// If the input *is* final, then we'll check to see which output is
// ours.
localAmt := c.toSelfAmount(broadcastTx)
// Once this is known, we'll mark the state as fully closed in the
// database. We can do this as a cooperatively closed channel has all
// its outputs resolved after only one confirmation.
closeSummary := &channeldb.ChannelCloseSummary{
ChanPoint: c.cfg.chanState.FundingOutpoint,
ChainHash: c.cfg.chanState.ChainHash,
ClosingTXID: *commitSpend.SpenderTxHash,
RemotePub: c.cfg.chanState.IdentityPub,
Capacity: c.cfg.chanState.Capacity,
CloseHeight: uint32(commitSpend.SpendingHeight),
SettledBalance: localAmt,
CloseType: channeldb.CooperativeClose,
ShortChanID: c.cfg.chanState.ShortChanID(),
IsPending: true,
RemoteCurrentRevocation: c.cfg.chanState.RemoteCurrentRevocation,
RemoteNextRevocation: c.cfg.chanState.RemoteNextRevocation,
LocalChanConfig: c.cfg.chanState.LocalChanCfg,
}
// Attempt to add a channel sync message to the close summary.
chanSync, err := c.cfg.chanState.ChanSyncMsg()
if err != nil {
log.Errorf("ChannelPoint(%v): unable to create channel sync "+
"message: %v", c.cfg.chanState.FundingOutpoint, err)
} else {
closeSummary.LastChanSyncMsg = chanSync
}
// Create a summary of all the information needed to handle the
// cooperative closure.
closeInfo := &CooperativeCloseInfo{
ChannelCloseSummary: closeSummary,
}
// With the event processed, we'll now notify all subscribers of the
// event.
c.Lock()
for _, sub := range c.clientSubscriptions {
select {
case sub.CooperativeClosure <- closeInfo:
case <-c.quit:
c.Unlock()
return fmt.Errorf("exiting")
}
}
c.Unlock()
return nil
}
// dispatchLocalForceClose processes a unilateral close by us being confirmed.
func (c *chainWatcher) dispatchLocalForceClose(
commitSpend *chainntnfs.SpendDetail,
stateNum uint64, commitSet CommitSet) error {
log.Infof("Local unilateral close of ChannelPoint(%v) "+
"detected", c.cfg.chanState.FundingOutpoint)
forceClose, err := lnwallet.NewLocalForceCloseSummary(
c.cfg.chanState, c.cfg.signer, commitSpend.SpendingTx, stateNum,
c.cfg.auxLeafStore, c.cfg.auxResolver,
)
if err != nil {
return err
}
// As we've detected that the channel has been closed, we'll immediately
// create a close summary for future usage by related sub-systems.
chanSnapshot := forceClose.ChanSnapshot
closeSummary := &channeldb.ChannelCloseSummary{
ChanPoint: chanSnapshot.ChannelPoint,
ChainHash: chanSnapshot.ChainHash,
ClosingTXID: forceClose.CloseTx.TxHash(),
RemotePub: &chanSnapshot.RemoteIdentity,
Capacity: chanSnapshot.Capacity,
CloseType: channeldb.LocalForceClose,
IsPending: true,
ShortChanID: c.cfg.chanState.ShortChanID(),
CloseHeight: uint32(commitSpend.SpendingHeight),
RemoteCurrentRevocation: c.cfg.chanState.RemoteCurrentRevocation,
RemoteNextRevocation: c.cfg.chanState.RemoteNextRevocation,
LocalChanConfig: c.cfg.chanState.LocalChanCfg,
}
resolutions, err := forceClose.ContractResolutions.UnwrapOrErr(
fmt.Errorf("resolutions not found"),
)
if err != nil {
return err
}
// If our commitment output isn't dust or we have active HTLCs on the
// commitment transaction, then we'll populate the balances on the
// close channel summary.
if resolutions.CommitResolution != nil {
localBalance := chanSnapshot.LocalBalance.ToSatoshis()
closeSummary.SettledBalance = localBalance
closeSummary.TimeLockedBalance = localBalance
}
if resolutions.HtlcResolutions != nil {
for _, htlc := range resolutions.HtlcResolutions.OutgoingHTLCs {
htlcValue := btcutil.Amount(
htlc.SweepSignDesc.Output.Value,
)
closeSummary.TimeLockedBalance += htlcValue
}
}
// Attempt to add a channel sync message to the close summary.
chanSync, err := c.cfg.chanState.ChanSyncMsg()
if err != nil {
log.Errorf("ChannelPoint(%v): unable to create channel sync "+
"message: %v", c.cfg.chanState.FundingOutpoint, err)
} else {
closeSummary.LastChanSyncMsg = chanSync
}
// With the event processed, we'll now notify all subscribers of the
// event.
closeInfo := &LocalUnilateralCloseInfo{
SpendDetail: commitSpend,
LocalForceCloseSummary: forceClose,
ChannelCloseSummary: closeSummary,
CommitSet: commitSet,
}
c.Lock()
for _, sub := range c.clientSubscriptions {
select {
case sub.LocalUnilateralClosure <- closeInfo:
case <-c.quit:
c.Unlock()
return fmt.Errorf("exiting")
}
}
c.Unlock()
return nil
}
// dispatchRemoteForceClose processes a detected unilateral channel closure by
// the remote party. This function will prepare a UnilateralCloseSummary which
// will then be sent to any subscribers allowing them to resolve all our funds
// in the channel on chain. Once this close summary is prepared, all registered
// subscribers will receive a notification of this event. The commitPoint
// argument should be set to the per_commitment_point corresponding to the
// spending commitment.
//
// NOTE: The remoteCommit argument should be set to the stored commitment for
// this particular state. If we don't have the commitment stored (should only
// happen in case we have lost state) it should be set to an empty struct, in
// which case we will attempt to sweep the non-HTLC output using the passed
// commitPoint.
func (c *chainWatcher) dispatchRemoteForceClose(
commitSpend *chainntnfs.SpendDetail,
remoteCommit channeldb.ChannelCommitment,
commitSet CommitSet, commitPoint *btcec.PublicKey) error {
log.Infof("Unilateral close of ChannelPoint(%v) "+
"detected", c.cfg.chanState.FundingOutpoint)
// First, we'll create a closure summary that contains all the
// materials required to let each subscriber sweep the funds in the
// channel on-chain.
uniClose, err := lnwallet.NewUnilateralCloseSummary(
c.cfg.chanState, c.cfg.signer, commitSpend, remoteCommit,
commitPoint, c.cfg.auxLeafStore, c.cfg.auxResolver,
)
if err != nil {
return err
}
// With the event processed, we'll now notify all subscribers of the
// event.
c.Lock()
for _, sub := range c.clientSubscriptions {
select {
case sub.RemoteUnilateralClosure <- &RemoteUnilateralCloseInfo{
UnilateralCloseSummary: uniClose,
CommitSet: commitSet,
}:
case <-c.quit:
c.Unlock()
return fmt.Errorf("exiting")
}
}
c.Unlock()
return nil
}
// dispatchContractBreach processes a detected contract breach by the remote
// party. This method is to be called once we detect that the remote party has
// broadcast a prior revoked commitment state. This method will prepare all the
// materials required to bring the cheater to justice, then notify all
// registered subscribers of this event.
func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail,
chainSet *chainSet, broadcastStateNum uint64,
retribution *lnwallet.BreachRetribution,
anchorRes *lnwallet.AnchorResolution) error {
log.Warnf("Remote peer has breached the channel contract for "+
"ChannelPoint(%v). Revoked state #%v was broadcast!!!",
c.cfg.chanState.FundingOutpoint, broadcastStateNum)
if err := c.cfg.chanState.MarkBorked(); err != nil {
return fmt.Errorf("unable to mark channel as borked: %w", err)
}
spendHeight := uint32(spendEvent.SpendingHeight)
log.Debugf("Punishment breach retribution created: %v",
lnutils.NewLogClosure(func() string {
retribution.KeyRing.LocalHtlcKey = nil
retribution.KeyRing.RemoteHtlcKey = nil
retribution.KeyRing.ToLocalKey = nil
retribution.KeyRing.ToRemoteKey = nil
retribution.KeyRing.RevocationKey = nil
return spew.Sdump(retribution)
}))
settledBalance := chainSet.remoteCommit.LocalBalance.ToSatoshis()
closeSummary := channeldb.ChannelCloseSummary{
ChanPoint: c.cfg.chanState.FundingOutpoint,
ChainHash: c.cfg.chanState.ChainHash,
ClosingTXID: *spendEvent.SpenderTxHash,
CloseHeight: spendHeight,
RemotePub: c.cfg.chanState.IdentityPub,
Capacity: c.cfg.chanState.Capacity,
SettledBalance: settledBalance,
CloseType: channeldb.BreachClose,
IsPending: true,
ShortChanID: c.cfg.chanState.ShortChanID(),
RemoteCurrentRevocation: c.cfg.chanState.RemoteCurrentRevocation,
RemoteNextRevocation: c.cfg.chanState.RemoteNextRevocation,
LocalChanConfig: c.cfg.chanState.LocalChanCfg,
}
// Attempt to add a channel sync message to the close summary.
chanSync, err := c.cfg.chanState.ChanSyncMsg()
if err != nil {
log.Errorf("ChannelPoint(%v): unable to create channel sync "+
"message: %v", c.cfg.chanState.FundingOutpoint, err)
} else {
closeSummary.LastChanSyncMsg = chanSync
}
// Hand the retribution info over to the BreachArbitrator. This function
// will wait for a response from the breach arbiter and then proceed to
// send a BreachCloseInfo to the channel arbitrator. The channel arb
// will then mark the channel as closed after resolutions and the
// commit set are logged in the arbitrator log.
if err := c.cfg.contractBreach(retribution); err != nil {
log.Errorf("unable to hand breached contract off to "+
"BreachArbitrator: %v", err)
return err
}
breachRes := &BreachResolution{
FundingOutPoint: c.cfg.chanState.FundingOutpoint,
}
breachInfo := &BreachCloseInfo{
CommitHash: spendEvent.SpendingTx.TxHash(),
BreachResolution: breachRes,
AnchorResolution: anchorRes,
CommitSet: chainSet.commitSet,
CloseSummary: closeSummary,
}
// With the event processed and channel closed, we'll now notify all
// subscribers of the event.
c.Lock()
for _, sub := range c.clientSubscriptions {
select {
case sub.ContractBreach <- breachInfo:
case <-c.quit:
c.Unlock()
return fmt.Errorf("quitting")
}
}
c.Unlock()
return nil
}
// waitForCommitmentPoint waits for the commitment point to be inserted into
// the local database. We'll use this method in the DLP case, to wait for the
// remote party to send us their point, as we can't proceed until we have that.
func (c *chainWatcher) waitForCommitmentPoint() *btcec.PublicKey {
// If we are lucky, the remote peer sent us the correct commitment
// point during channel sync, such that we can sweep our funds. If we
// cannot find the commit point, there's not much we can do other than
// wait until we can retrieve it. We will attempt to retrieve it from
// the peer each time we connect to it.
//
// TODO(halseth): actively initiate re-connection to the peer?
backoff := minCommitPointPollTimeout
for {
commitPoint, err := c.cfg.chanState.DataLossCommitPoint()
if err == nil {
return commitPoint
}
log.Errorf("Unable to retrieve commitment point for "+
"channel(%v) with lost state: %v. Retrying in %v.",
c.cfg.chanState.FundingOutpoint, err, backoff)
select {
// Wait before retrying, with an exponential backoff.
case <-time.After(backoff):
backoff = 2 * backoff
if backoff > maxCommitPointPollTimeout {
backoff = maxCommitPointPollTimeout
}
case <-c.quit:
return nil
}
}
}
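
// With minCommitPointPollTimeout=1s and maxCommitPointPollTimeout=10m, the
// retry schedule above doubles on every attempt: 1s, 2s, 4s, ..., 512s, and
// then stays clamped at 10m until the commit point is found or the watcher
// shuts down.
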
// deriveFundingPkScript derives the script used in the funding output.
func deriveFundingPkScript(chanState *channeldb.OpenChannel) ([]byte, error) {
localKey := chanState.LocalChanCfg.MultiSigKey.PubKey
remoteKey := chanState.RemoteChanCfg.MultiSigKey.PubKey
var (
err error
fundingPkScript []byte
)
if chanState.ChanType.IsTaproot() {
fundingPkScript, _, err = input.GenTaprootFundingScript(
localKey, remoteKey, 0, chanState.TapscriptRoot,
)
if err != nil {
return nil, err
}
} else {
multiSigScript, err := input.GenMultiSigScript(
localKey.SerializeCompressed(),
remoteKey.SerializeCompressed(),
)
if err != nil {
return nil, err
}
fundingPkScript, err = input.WitnessScriptHash(multiSigScript)
if err != nil {
return nil, err
}
}
return fundingPkScript, nil
}
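
// For reference: for a non-taproot channel the derived script is the P2WSH
// of the 2-of-2 multisig witness script, while for a taproot channel it is
// the P2TR script for the aggregated MuSig2 key, tweaked by the channel's
// tapscript root when one is set.
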
// handleCommitSpend takes a spending tx of the funding output and handles the
// channel close based on the closure type. Classification is attempted in
// order: known local state, known remote state (including breaches),
// cooperative close, unknown local state, and finally unknown remote state.
func (c *chainWatcher) handleCommitSpend(
commitSpend *chainntnfs.SpendDetail) error {
commitTxBroadcast := commitSpend.SpendingTx
// First, we'll construct the chainset which includes all the data we
// need to dispatch an event to our subscribers about this possible
// channel close event.
chainSet, err := newChainSet(c.cfg.chanState)
if err != nil {
return fmt.Errorf("create commit set: %w", err)
}
// Decode the state hint encoded within the commitment transaction to
// determine if this is a revoked state or not.
obfuscator := c.stateHintObfuscator
broadcastStateNum := c.cfg.extractStateNumHint(
commitTxBroadcast, obfuscator,
)
// We'll go on to check whether it could be our own commitment that was
// published and is now confirmed.
ok, err := c.handleKnownLocalState(
commitSpend, broadcastStateNum, chainSet,
)
if err != nil {
return fmt.Errorf("handle known local state: %w", err)
}
if ok {
return nil
}
// Now that we know it is not a local close with the latest state, we
// check whether it was the remote party that closed with any prior or
// current state.
ok, err = c.handleKnownRemoteState(
commitSpend, broadcastStateNum, chainSet,
)
if err != nil {
return fmt.Errorf("handle known remote state: %w", err)
}
if ok {
return nil
}
// Next, we'll check to see if this is a cooperative channel closure or
// not. This is characterized by having an input sequence number that's
// finalized. This won't happen with regular commitment transactions
// due to the state hint encoding scheme.
switch commitTxBroadcast.TxIn[0].Sequence {
case wire.MaxTxInSequenceNum:
fallthrough
case mempool.MaxRBFSequence:
// TODO(roasbeef): rare but possible, need itest case for
err := c.dispatchCooperativeClose(commitSpend)
if err != nil {
return fmt.Errorf("handle coop close: %w", err)
}
return nil
}
log.Warnf("Unknown commitment broadcast for ChannelPoint(%v) ",
c.cfg.chanState.FundingOutpoint)
// We'll try to recover as best as possible from losing state. We
// first check if this was a local unknown state. This could happen if
// we force close, then lose state or attempt recovery before the
// commitment confirms.
ok, err = c.handleUnknownLocalState(
commitSpend, broadcastStateNum, chainSet,
)
if err != nil {
return fmt.Errorf("handle known local state: %w", err)
}
if ok {
return nil
}
// Since it was neither a known remote state, nor a local state that
// was published, it most likely means we lost state and the remote node
// closed. In this case we must start the DLP protocol in hopes of
// getting our money back.
ok, err = c.handleUnknownRemoteState(
commitSpend, broadcastStateNum, chainSet,
)
if err != nil {
return fmt.Errorf("handle unknown remote state: %w", err)
}
if ok {
return nil
}
log.Errorf("Unable to handle spending tx %v of channel point %v",
commitTxBroadcast.TxHash(), c.cfg.chanState.FundingOutpoint)
return nil
}
// checkFundingSpend performs a non-blocking read on the spendNtfn channel to
// check whether there's a commit spend already. Returns the spend details if
// found.
func (c *chainWatcher) checkFundingSpend() *chainntnfs.SpendDetail {
select {
// We've detected a spend of the channel onchain! Depending on the type
// of spend, we'll act accordingly, so we'll examine the spending
// transaction to determine what we should do.
//
// TODO(Roasbeef): need to be able to ensure this only triggers
// on confirmation, to ensure if multiple txns are broadcast, we
// act on the one that's timestamped
case spend, ok := <-c.fundingSpendNtfn.Spend:
// If the channel was closed, then this means that the notifier
// exited, so we will as well.
if !ok {
return nil
}
log.Debugf("Found spend details for funding output: %v",
spend.SpenderTxHash)
return spend
default:
}
return nil
}
// chanPointConfirmed checks whether the given channel point has confirmed.
// This is used to ensure that the funding output has confirmed on chain before
// we proceed with the rest of the close observer logic for taproot channels.
// Check the docs in `fundingConfirmedNtfn` for details.
func (c *chainWatcher) chanPointConfirmed() bool {
op := c.cfg.chanState.FundingOutpoint
select {
case _, ok := <-c.fundingConfirmedNtfn.Confirmed:
// If the channel was closed, then this means that the notifier
// exited, so we will as well.
if !ok {
return false
}
log.Debugf("Taproot ChannelPoint(%v) confirmed", op)
// The channel point has confirmed on chain. We now cancel the
// subscription.
c.fundingConfirmedNtfn.Cancel()
return true
default:
log.Infof("Taproot ChannelPoint(%v) not confirmed yet", op)
return false
}
}
// handleBlockbeat takes a blockbeat and queries for a spending tx for the
// funding output. If the spending tx is found, it will be handled based on the
// closure type.
func (c *chainWatcher) handleBlockbeat(beat chainio.Blockbeat) {
// Notify that the chain watcher has processed the block.
defer c.NotifyBlockProcessed(beat, nil)
// If we have a fundingConfirmedNtfn, it means this is a taproot
// channel that is pending, before we proceed, we want to ensure that
// the expected funding output has confirmed on chain. Check the docs
// in `fundingConfirmedNtfn` for details.
if c.fundingConfirmedNtfn != nil {
// If the funding output hasn't confirmed in this block, we
// will check it again in the next block.
if !c.chanPointConfirmed() {
return
}
}
// Perform a non-blocking read to check whether the funding output was
// spent.
spend := c.checkFundingSpend()
if spend == nil {
log.Tracef("No spend found for ChannelPoint(%v) in block %v",
c.cfg.chanState.FundingOutpoint, beat.Height())
return
}
// The funding output was spent, we now handle it by sending a close
// event to the channel arbitrator.
err := c.handleCommitSpend(spend)
if err != nil {
log.Errorf("Failed to handle commit spend: %v", err)
}
}
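// A minimal sketch of the driver loop that typically feeds handleBlockbeat.
// It assumes the chainWatcher embeds chainio.BeatConsumer (which provides
// BlockbeatChan) and owns a quit channel; the exact loop shape is
// illustrative rather than the verbatim implementation:
//
//	for {
//		select {
//		case beat := <-c.BlockbeatChan:
//			c.handleBlockbeat(beat)
//		case <-c.quit:
//			return
//		}
//	}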
package contractcourt
import (
"bytes"
"context"
"errors"
"fmt"
"math"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/sweep"
)
var (
// errAlreadyForceClosed is an error returned when we attempt to force
// close a channel that's already in the process of doing so.
errAlreadyForceClosed = errors.New("channel is already in the " +
"process of being force closed")
)
const (
// arbitratorBlockBufferSize is the size of the buffer we give to each
// channel arbitrator.
arbitratorBlockBufferSize = 20
// AnchorOutputValue is the output value for the anchor output of an
// anchor channel.
// See BOLT 03 for more details:
// https://github.com/lightning/bolts/blob/master/03-transactions.md
AnchorOutputValue = btcutil.Amount(330)
)
// WitnessSubscription represents an intent to be notified once new witnesses
// are discovered by various active contract resolvers. A contract resolver may
// use this to be notified of when it can satisfy an incoming contract after we
// discover the witness for an outgoing contract.
type WitnessSubscription struct {
// WitnessUpdates is a channel that newly discovered witnesses will be
// sent over.
//
// TODO(roasbeef): couple with WitnessType?
WitnessUpdates <-chan lntypes.Preimage
// CancelSubscription is a function closure that should be used by a
// client to cancel the subscription once they are no longer interested
// in receiving new updates.
CancelSubscription func()
}
// WitnessBeacon is a global beacon of witnesses. Contract resolvers will use
// this interface to lookup witnesses (preimages typically) of contracts
// they're trying to resolve, add new preimages they resolve, and finally
// receive new updates each new time a preimage is discovered.
//
// TODO(roasbeef): need to delete the pre-images once we've used them
// and have been sufficiently confirmed?
type WitnessBeacon interface {
// SubscribeUpdates returns a channel that will receive a notification
// *each* time a new preimage is discovered.
SubscribeUpdates(chanID lnwire.ShortChannelID, htlc *channeldb.HTLC,
payload *hop.Payload,
nextHopOnionBlob []byte) (*WitnessSubscription, error)
// LookupPreimage attempts to lookup a preimage in the global cache.
// True is returned for the second argument if the preimage is found.
LookupPreimage(payhash lntypes.Hash) (lntypes.Preimage, bool)
// AddPreimages adds a batch of newly discovered preimages to the global
// cache, and also signals any subscribers of the newly discovered
// witness.
AddPreimages(preimages ...lntypes.Preimage) error
}
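// An illustrative use of a WitnessBeacon from a resolver's perspective. The
// claimWithPreimage helper, quit channel, and errShuttingDown value are
// hypothetical; the key points are checking the preimage cache first and
// always cancelling the subscription:
//
//	if preimage, ok := beacon.LookupPreimage(payHash); ok {
//		return claimWithPreimage(preimage)
//	}
//	sub, err := beacon.SubscribeUpdates(chanID, &htlc, payload, onionBlob)
//	if err != nil {
//		return err
//	}
//	defer sub.CancelSubscription()
//	select {
//	case preimage := <-sub.WitnessUpdates:
//		return claimWithPreimage(preimage)
//	case <-quit:
//		return errShuttingDown
//	}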
// ArbChannel is an abstraction that allows the channel arbitrator to interact
// with an open channel.
type ArbChannel interface {
// ForceCloseChan should force close the contract that this attendant
// is watching over. We'll use this when we decide that we need to go
// to chain. It should in addition tell the switch to remove the
// corresponding link, such that we won't accept any new updates. The
// returned summary contains all items needed to eventually resolve all
// outputs on chain.
ForceCloseChan() (*wire.MsgTx, error)
// NewAnchorResolutions returns the anchor resolutions for currently
// valid commitment transactions.
NewAnchorResolutions() (*lnwallet.AnchorResolutions, error)
}
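// mockArbChannel is a hypothetical test double for ArbChannel, sketched here
// to show how the interface can be satisfied with canned values when
// exercising the arbitrator's state machine in isolation; it is not part of
// the production code.
type mockArbChannel struct {
closeTx *wire.MsgTx
anchors *lnwallet.AnchorResolutions
}

// ForceCloseChan returns the canned force close transaction.
func (m *mockArbChannel) ForceCloseChan() (*wire.MsgTx, error) {
return m.closeTx, nil
}

// NewAnchorResolutions returns the canned anchor resolutions.
func (m *mockArbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions,
error) {

return m.anchors, nil
}

// Compile-time check that the mock satisfies ArbChannel.
var _ ArbChannel = (*mockArbChannel)(nil)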
// ChannelArbitratorConfig contains all the functionality that the
// ChannelArbitrator needs in order to properly arbitrate any contract dispute
// on chain.
type ChannelArbitratorConfig struct {
// ChanPoint is the channel point that uniquely identifies this
// channel.
ChanPoint wire.OutPoint
// Channel is the full channel data structure. For legacy channels, this
// field may not always be set after a restart.
Channel ArbChannel
// ShortChanID describes the exact location of the channel within the
// chain. We'll use this to address any messages that we need to send
// to the switch during contract resolution.
ShortChanID lnwire.ShortChannelID
// ChainEvents is an active subscription to the chain watcher for this
// channel to be notified of any on-chain activity related to this
// channel.
ChainEvents *ChainEventSubscription
// MarkCommitmentBroadcasted should mark the channel's commitment
// transaction as broadcast, indicating that we are waiting for it to
// confirm.
MarkCommitmentBroadcasted func(*wire.MsgTx, lntypes.ChannelParty) error
// MarkChannelClosed marks the channel closed in the database, with the
// passed close summary. After this method successfully returns we can
// no longer expect to receive chain events for this channel, and must
// be able to recover from a failure without getting the close event
// again. It takes an optional channel status which will update the
// channel status in the record that we keep of historical channels.
MarkChannelClosed func(*channeldb.ChannelCloseSummary,
...channeldb.ChannelStatus) error
// IsPendingClose is a boolean indicating whether the channel is marked
// as pending close in the database.
IsPendingClose bool
// ClosingHeight is the height at which the channel was closed. Note
// that this value is only valid if IsPendingClose is true.
ClosingHeight uint32
// CloseType is the type of the close event in case IsPendingClose is
// true. Otherwise this value is unset.
CloseType channeldb.ClosureType
// MarkChannelResolved is a function closure that serves to mark a
// channel as "fully resolved". A channel itself can be considered
// fully resolved once all active contracts have individually been
// fully resolved.
//
// TODO(roasbeef): need RPC's to combine for pendingchannels RPC
MarkChannelResolved func() error
// PutResolverReport records a resolver report for the channel. If the
// transaction provided is nil, the function should write the report
// in a new transaction.
PutResolverReport func(tx kvdb.RwTx,
report *channeldb.ResolverReport) error
// FetchHistoricalChannel retrieves the historical state of a channel.
// This is mostly used to supplement the ContractResolvers with
// additional information required for proper contract resolution.
FetchHistoricalChannel func() (*channeldb.OpenChannel, error)
// FindOutgoingHTLCDeadline returns the deadline in absolute block
// height for the specified outgoing HTLC. For an outgoing HTLC, its
// deadline is defined by the timeout height of its corresponding
// incoming HTLC - this is the expiry height at which the remote peer
// can spend their outgoing HTLC via the timeout path.
FindOutgoingHTLCDeadline func(htlc channeldb.HTLC) fn.Option[int32]
ChainArbitratorConfig
}
// ReportOutputType describes the type of output that is being reported
// on.
type ReportOutputType uint8
const (
// ReportOutputIncomingHtlc is an incoming hash time locked contract on
// the commitment tx.
ReportOutputIncomingHtlc ReportOutputType = iota
// ReportOutputOutgoingHtlc is an outgoing hash time locked contract on
// the commitment tx.
ReportOutputOutgoingHtlc
// ReportOutputUnencumbered is an uncontested output on the commitment
// transaction paying to us directly.
ReportOutputUnencumbered
// ReportOutputAnchor is an anchor output on the commitment tx.
ReportOutputAnchor
)
// ContractReport provides a summary of a commitment tx output.
type ContractReport struct {
// Outpoint is the final output that will be swept back to the wallet.
Outpoint wire.OutPoint
// Type indicates the type of the reported output.
Type ReportOutputType
// Amount is the final value that will be swept back to the wallet.
Amount btcutil.Amount
// MaturityHeight is the absolute block height that this output will
// mature at.
MaturityHeight uint32
// Stage indicates whether the htlc is in the CLTV-timeout stage (1) or
// the CSV-delay stage (2). A stage 1 htlc's maturity height will be set
// to its expiry height, while a stage 2 htlc's maturity height will be
// set to its confirmation height plus the maturity requirement.
Stage uint32
// LimboBalance is the total number of frozen coins within this
// contract.
LimboBalance btcutil.Amount
// RecoveredBalance is the total value that has been successfully swept
// back to the user's wallet.
RecoveredBalance btcutil.Amount
}
// resolverReport creates a resolve report using some of the information in the
// contract report.
func (c *ContractReport) resolverReport(spendTx *chainhash.Hash,
resolverType channeldb.ResolverType,
outcome channeldb.ResolverOutcome) *channeldb.ResolverReport {
return &channeldb.ResolverReport{
OutPoint: c.Outpoint,
Amount: c.Amount,
ResolverType: resolverType,
ResolverOutcome: outcome,
SpendTxID: spendTx,
}
}
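// For example, once a sweep transaction confirms, a report could be built
// directly from the contract report. A sketch, assuming channeldb's anchor
// resolver type and claimed outcome constants; sweepTxHash is illustrative:
//
//	report := c.resolverReport(
//		&sweepTxHash, channeldb.ResolverTypeAnchor,
//		channeldb.ResolverOutcomeClaimed,
//	)
//	// The report can then be persisted via the arbitrator's
//	// PutResolverReport closure.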
// htlcSet represents the set of active HTLCs on a given commitment
// transaction.
type htlcSet struct {
// incomingHTLCs is a map of all incoming HTLCs on the target
// commitment transaction. We may potentially go onchain to claim the
// funds sent to us within this set.
incomingHTLCs map[uint64]channeldb.HTLC
// outgoingHTLCs is a map of all outgoing HTLCs on the target
// commitment transaction. We may potentially go onchain to reclaim the
// funds that are currently in limbo.
outgoingHTLCs map[uint64]channeldb.HTLC
}
// newHtlcSet constructs a new HTLC set from a slice of HTLC's.
func newHtlcSet(htlcs []channeldb.HTLC) htlcSet {
outHTLCs := make(map[uint64]channeldb.HTLC)
inHTLCs := make(map[uint64]channeldb.HTLC)
for _, htlc := range htlcs {
if htlc.Incoming {
inHTLCs[htlc.HtlcIndex] = htlc
continue
}
outHTLCs[htlc.HtlcIndex] = htlc
}
return htlcSet{
incomingHTLCs: inHTLCs,
outgoingHTLCs: outHTLCs,
}
}
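// For instance, given one incoming and one outgoing HTLC, the set splits
// them by direction, keyed by HTLC index:
//
//	set := newHtlcSet([]channeldb.HTLC{
//		{HtlcIndex: 0, Incoming: true},
//		{HtlcIndex: 1, Incoming: false},
//	})
//	// set.incomingHTLCs[0] and set.outgoingHTLCs[1] are now populated.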
// HtlcSetKey is a two-tuple that uniquely identifies a set of HTLCs on a
// commitment transaction.
type HtlcSetKey struct {
// IsRemote denotes if the HTLCs are on the remote commitment
// transaction.
IsRemote bool
// IsPending denotes if the commitment transaction that the HTLCs are
// on is pending (the higher of two unrevoked commitments).
IsPending bool
}
var (
// LocalHtlcSet is the HtlcSetKey used for local commitments.
LocalHtlcSet = HtlcSetKey{IsRemote: false, IsPending: false}
// RemoteHtlcSet is the HtlcSetKey used for remote commitments.
RemoteHtlcSet = HtlcSetKey{IsRemote: true, IsPending: false}
// RemotePendingHtlcSet is the HtlcSetKey used for dangling remote
// commitment transactions.
RemotePendingHtlcSet = HtlcSetKey{IsRemote: true, IsPending: true}
)
// String returns a human readable string describing the target HtlcSetKey.
func (h HtlcSetKey) String() string {
switch h {
case LocalHtlcSet:
return "LocalHtlcSet"
case RemoteHtlcSet:
return "RemoteHtlcSet"
case RemotePendingHtlcSet:
return "RemotePendingHtlcSet"
default:
return "unknown HtlcSetKey"
}
}
// ChannelArbitrator is the on-chain arbitrator for a particular channel. The
// struct will stay in sync with the current set of HTLCs on the commitment
// transaction. The job of the attendant is to go on-chain to either settle or
// cancel an HTLC as necessary iff: an HTLC times out, or we know the
// pre-image to an HTLC, but it wasn't settled by the link off-chain. The
// ChannelArbitrator will factor in an expected confirmation delta when
// broadcasting to ensure that we avoid any possibility of race conditions, and
// sweep the output(s) without contest.
type ChannelArbitrator struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.
// Embed the blockbeat consumer struct to get access to the method
// `NotifyBlockProcessed` and the `BlockbeatChan`.
chainio.BeatConsumer
// startTimestamp is the time when this ChannelArbitrator was started.
startTimestamp time.Time
// log is a persistent log that the attendant will use to checkpoint
// its next action, and the state of any unresolved contracts.
log ArbitratorLog
// activeHTLCs is the set of active incoming/outgoing HTLC's on all
// currently valid commitment transactions.
activeHTLCs map[HtlcSetKey]htlcSet
// unmergedSet is used to update the activeHTLCs map in two callsites:
// checkLocalChainActions and sweepAnchors. It contains the latest
// updates from the link. It is never deleted from; its entries may be
// replaced on subsequent calls to notifyContractUpdate.
unmergedSet map[HtlcSetKey]htlcSet
unmergedMtx sync.RWMutex
// cfg contains all the functionality that the ChannelArbitrator requires
// to do its duty.
cfg ChannelArbitratorConfig
// signalUpdates is a channel that any new live signals for the channel
// we're watching over will be sent.
signalUpdates chan *signalUpdateMsg
// activeResolvers is a slice of any active resolvers. This is used to
// be able to signal them for shutdown in the case that we shutdown.
activeResolvers []ContractResolver
// activeResolversLock prevents simultaneous read and write to the
// resolvers slice.
activeResolversLock sync.RWMutex
// resolutionSignal is a channel that will be sent upon by contract
// resolvers once their contract has been fully resolved. With each
// send, we'll check to see if the contract is fully resolved.
resolutionSignal chan struct{}
// forceCloseReqs is a channel that requests to forcibly close the
// contract will be sent over.
forceCloseReqs chan *forceCloseReq
// state is the current state of the arbitrator. This state is examined
// upon start up to decide which actions to take.
state ArbitratorState
wg sync.WaitGroup
quit chan struct{}
}
// NewChannelArbitrator returns a new instance of a ChannelArbitrator backed by
// the passed config struct.
func NewChannelArbitrator(cfg ChannelArbitratorConfig,
htlcSets map[HtlcSetKey]htlcSet, log ArbitratorLog) *ChannelArbitrator {
// Create a new map for unmerged HTLC's as we will overwrite the values
// and want to avoid modifying activeHTLCs directly. This shallow copy
// ensures that activeHTLCs isn't later reset to an empty map.
unmerged := make(map[HtlcSetKey]htlcSet)
unmerged[LocalHtlcSet] = htlcSets[LocalHtlcSet]
unmerged[RemoteHtlcSet] = htlcSets[RemoteHtlcSet]
// If the pending set exists, write that as well.
if _, ok := htlcSets[RemotePendingHtlcSet]; ok {
unmerged[RemotePendingHtlcSet] = htlcSets[RemotePendingHtlcSet]
}
c := &ChannelArbitrator{
log: log,
signalUpdates: make(chan *signalUpdateMsg),
resolutionSignal: make(chan struct{}),
forceCloseReqs: make(chan *forceCloseReq),
activeHTLCs: htlcSets,
unmergedSet: unmerged,
cfg: cfg,
quit: make(chan struct{}),
}
// Mount the block consumer.
c.BeatConsumer = chainio.NewBeatConsumer(c.quit, c.Name())
return c
}
// Compile-time check for the chainio.Consumer interface.
var _ chainio.Consumer = (*ChannelArbitrator)(nil)
// chanArbStartState contains the information from disk that we need to start
// up a channel arbitrator.
type chanArbStartState struct {
currentState ArbitratorState
commitSet *CommitSet
}
// getStartState retrieves the information from disk that our channel arbitrator
// requires to start.
func (c *ChannelArbitrator) getStartState(tx kvdb.RTx) (*chanArbStartState,
error) {
// First, we'll read our last state from disk, so our internal state
// machine can act accordingly.
state, err := c.log.CurrentState(tx)
if err != nil {
return nil, err
}
// Next we'll fetch our confirmed commitment set. This will only exist
// if the channel has been closed out on chain for modern nodes. For
// older nodes, this won't be found at all, and will rely on the
// existing written chain actions. Additionally, if this channel hasn't
// logged any actions in the log, then this field won't be present.
commitSet, err := c.log.FetchConfirmedCommitSet(tx)
if err != nil && err != errNoCommitSet && err != errScopeBucketNoExist {
return nil, err
}
return &chanArbStartState{
currentState: state,
commitSet: commitSet,
}, nil
}
// Start starts all the goroutines that the ChannelArbitrator needs to operate.
// It takes a start state, which will be looked up on disk if it is not
// provided.
func (c *ChannelArbitrator) Start(state *chanArbStartState,
beat chainio.Blockbeat) error {
if !atomic.CompareAndSwapInt32(&c.started, 0, 1) {
return nil
}
c.startTimestamp = c.cfg.Clock.Now()
// If the state passed in is nil, we look it up now.
if state == nil {
var err error
state, err = c.getStartState(nil)
if err != nil {
return err
}
}
log.Tracef("Starting ChannelArbitrator(%v), htlc_set=%v, state=%v",
c.cfg.ChanPoint, lnutils.SpewLogClosure(c.activeHTLCs),
state.currentState)
// Set our state from our starting state.
c.state = state.currentState
// Get the starting height.
bestHeight := beat.Height()
c.wg.Add(1)
go c.channelAttendant(bestHeight, state.commitSet)
return nil
}
// progressStateMachineAfterRestart attempts to progress the state machine
// after a restart. This makes sure that if the state transition failed, we
// will try to progress the state machine again. Moreover it will relaunch
// resolvers if the channel is still in the pending close state and has not
// been fully resolved yet.
func (c *ChannelArbitrator) progressStateMachineAfterRestart(bestHeight int32,
commitSet *CommitSet) error {
// If the channel has been marked pending close in the database, and we
// haven't transitioned the state machine to StateContractClosed (or a
// succeeding state), then a state transition most likely failed. We'll
// try to recover from this by manually advancing the state by setting
// the corresponding close trigger.
trigger := chainTrigger
triggerHeight := uint32(bestHeight)
if c.cfg.IsPendingClose {
switch c.state {
case StateDefault:
fallthrough
case StateBroadcastCommit:
fallthrough
case StateCommitmentBroadcasted:
switch c.cfg.CloseType {
case channeldb.CooperativeClose:
trigger = coopCloseTrigger
case channeldb.BreachClose:
trigger = breachCloseTrigger
case channeldb.LocalForceClose:
trigger = localCloseTrigger
case channeldb.RemoteForceClose:
trigger = remoteCloseTrigger
}
log.Warnf("ChannelArbitrator(%v): detected stalled "+
"state=%v for closed channel",
c.cfg.ChanPoint, c.state)
}
triggerHeight = c.cfg.ClosingHeight
}
log.Infof("ChannelArbitrator(%v): starting state=%v, trigger=%v, "+
"triggerHeight=%v", c.cfg.ChanPoint, c.state, trigger,
triggerHeight)
// We'll now attempt to advance our state forward based on the current
// on-chain state, and our set of active contracts.
startingState := c.state
nextState, _, err := c.advanceState(
triggerHeight, trigger, commitSet,
)
if err != nil {
switch err {
// If we detect that we tried to fetch resolutions, but failed,
// this channel was marked closed in the database before the
// resolutions were successfully written. In this case there is
// not much we can do, so we don't return the error.
case errScopeBucketNoExist:
fallthrough
case errNoResolutions:
log.Warnf("ChannelArbitrator(%v): detected closed"+
"channel with no contract resolutions written.",
c.cfg.ChanPoint)
default:
return err
}
}
// If we started and ended at the awaiting full resolution state, then
// we'll relaunch our set of unresolved contracts.
if startingState == StateWaitingFullResolution &&
nextState == StateWaitingFullResolution {
// In order to relaunch the resolvers, we'll need to fetch the
// set of HTLCs that were present in the commitment transaction
// at the time it was confirmed. commitSet.ConfCommitKey can't
// be nil at this point since we're in
// StateWaitingFullResolution. We can only be in
// StateWaitingFullResolution after we've transitioned from
// StateContractClosed which can only be triggered by the local
// or remote close trigger. This trigger is only fired when we
// receive a chain event from the chain watcher that the
// commitment has been confirmed on chain, and before we
// advance our state step, we call InsertConfirmedCommitSet.
err := c.relaunchResolvers(commitSet, triggerHeight)
if err != nil {
return err
}
}
return nil
}
// maybeAugmentTaprootResolvers will update the contract resolution information
// for taproot channels. This ensures that all the resolvers have the latest
// resolution, which may also include data such as the control block and tap
// tweaks.
func maybeAugmentTaprootResolvers(chanType channeldb.ChannelType,
resolver ContractResolver,
contractResolutions *ContractResolutions) {
if !chanType.IsTaproot() {
return
}
// The on disk resolutions contains all the ctrl block
// information, so we'll set that now for the relevant
// resolvers.
switch r := resolver.(type) {
case *commitSweepResolver:
if contractResolutions.CommitResolution != nil {
//nolint:ll
r.commitResolution = *contractResolutions.CommitResolution
}
case *htlcOutgoingContestResolver:
//nolint:ll
htlcResolutions := contractResolutions.HtlcResolutions.OutgoingHTLCs
for _, htlcRes := range htlcResolutions {
htlcRes := htlcRes
if r.htlcResolution.ClaimOutpoint ==
htlcRes.ClaimOutpoint {
r.htlcResolution = htlcRes
}
}
case *htlcTimeoutResolver:
//nolint:ll
htlcResolutions := contractResolutions.HtlcResolutions.OutgoingHTLCs
for _, htlcRes := range htlcResolutions {
htlcRes := htlcRes
if r.htlcResolution.ClaimOutpoint ==
htlcRes.ClaimOutpoint {
r.htlcResolution = htlcRes
}
}
case *htlcIncomingContestResolver:
//nolint:ll
htlcResolutions := contractResolutions.HtlcResolutions.IncomingHTLCs
for _, htlcRes := range htlcResolutions {
htlcRes := htlcRes
if r.htlcResolution.ClaimOutpoint ==
htlcRes.ClaimOutpoint {
r.htlcResolution = htlcRes
}
}
case *htlcSuccessResolver:
//nolint:ll
htlcResolutions := contractResolutions.HtlcResolutions.IncomingHTLCs
for _, htlcRes := range htlcResolutions {
htlcRes := htlcRes
if r.htlcResolution.ClaimOutpoint ==
htlcRes.ClaimOutpoint {
r.htlcResolution = htlcRes
}
}
}
}
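// The four HTLC cases above run the same matching loop over different
// resolution slices. A small generic helper along the following lines (a
// refactoring sketch, not part of the original code) could collapse that
// duplication, given a closure that extracts the claim outpoint from each
// resolution type:

// matchResolution returns the first resolution whose claim outpoint equals
// the target outpoint, along with whether a match was found.
func matchResolution[T any](resolutions []T,
claimOutpoint func(T) wire.OutPoint, target wire.OutPoint) (T, bool) {

var zero T
for _, res := range resolutions {
if claimOutpoint(res) == target {
return res, true
}
}
return zero, false
}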
// relaunchResolvers relaunches the set of resolvers for unresolved contracts
// in order to provide them with information that's not immediately available
// upon starting the ChannelArbitrator. This information should ideally be
// stored in the database, so this only serves as an intermediate workaround
// to prevent a migration.
func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet,
heightHint uint32) error {
// We'll now query our log to see if there are any active unresolved
// contracts. If this is the case, then we'll relaunch all contract
// resolvers.
unresolvedContracts, err := c.log.FetchUnresolvedContracts()
if err != nil {
return err
}
// Retrieve the commitment tx hash from the log.
contractResolutions, err := c.log.FetchContractResolutions()
if err != nil {
log.Errorf("unable to fetch contract resolutions: %v",
err)
return err
}
commitHash := contractResolutions.CommitHash
// In prior versions of lnd, the information needed to supplement the
// resolvers (in most cases, the full amount of the HTLC) was found in
// the chain action map, which is now deprecated. As a result, if the
// commitSet is nil (an older node with unresolved HTLCs at time of
upgrade), then we'll use the chain action information in its place. The
// chain actions may exclude some information, but we cannot recover it
// for these older nodes at the moment.
var confirmedHTLCs []channeldb.HTLC
if commitSet != nil && commitSet.ConfCommitKey.IsSome() {
confCommitKey, err := commitSet.ConfCommitKey.UnwrapOrErr(
fmt.Errorf("no commitKey available"),
)
if err != nil {
return err
}
confirmedHTLCs = commitSet.HtlcSets[confCommitKey]
} else {
chainActions, err := c.log.FetchChainActions()
if err != nil {
log.Errorf("unable to fetch chain actions: %v", err)
return err
}
for _, htlcs := range chainActions {
confirmedHTLCs = append(confirmedHTLCs, htlcs...)
}
}
// Reconstruct the htlc outpoints and data from the chain action log.
// The purpose of the constructed htlc map is to supplement the
// resolvers restored from the database with extra data. Ideally this
// data is stored as part of the resolver in the log. This is a
// workaround to prevent a db migration. We use all available htlc sets
// here in order to ensure we have complete coverage.
htlcMap := make(map[wire.OutPoint]*channeldb.HTLC)
for _, htlc := range confirmedHTLCs {
htlc := htlc
outpoint := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
}
htlcMap[outpoint] = &htlc
}
// We'll also fetch the historical state of this channel, as it should
// have been marked as closed by now, and provide it to each resolver
// such that we can properly resolve our pending contracts.
var chanState *channeldb.OpenChannel
chanState, err = c.cfg.FetchHistoricalChannel()
switch {
// If we don't find this channel, then it may be the case that it
// was closed before we started to retain the final state
// information for open channels.
case err == channeldb.ErrNoHistoricalBucket:
fallthrough
case err == channeldb.ErrChannelNotFound:
log.Warnf("ChannelArbitrator(%v): unable to fetch historical "+
"state", c.cfg.ChanPoint)
case err != nil:
return err
}
log.Infof("ChannelArbitrator(%v): relaunching %v contract "+
"resolvers", c.cfg.ChanPoint, len(unresolvedContracts))
for i := range unresolvedContracts {
resolver := unresolvedContracts[i]
if chanState != nil {
resolver.SupplementState(chanState)
// For taproot channels, we'll need to also make sure
// the control block information was set properly.
maybeAugmentTaprootResolvers(
chanState.ChanType, resolver,
contractResolutions,
)
}
unresolvedContracts[i] = resolver
htlcResolver, ok := resolver.(htlcContractResolver)
if !ok {
continue
}
htlcPoint := htlcResolver.HtlcPoint()
htlc, ok := htlcMap[htlcPoint]
if !ok {
return fmt.Errorf(
"htlc resolver %T unavailable", resolver,
)
}
htlcResolver.Supplement(*htlc)
// If this is an outgoing HTLC, we will also need to supplement
// the resolver with the expiry block height of its
// corresponding incoming HTLC.
if !htlc.Incoming {
deadline := c.cfg.FindOutgoingHTLCDeadline(*htlc)
htlcResolver.SupplementDeadline(deadline)
}
}
// The anchor resolver is stateless and can always be re-instantiated.
if contractResolutions.AnchorResolution != nil {
anchorResolver := newAnchorResolver(
contractResolutions.AnchorResolution.AnchorSignDescriptor,
contractResolutions.AnchorResolution.CommitAnchor,
heightHint, c.cfg.ChanPoint,
ResolverConfig{
ChannelArbitratorConfig: c.cfg,
},
)
anchorResolver.SupplementState(chanState)
unresolvedContracts = append(unresolvedContracts, anchorResolver)
// TODO(roasbeef): this isn't re-launched?
}
c.resolveContracts(unresolvedContracts)
return nil
}
// Report returns htlc reports for the active resolvers.
func (c *ChannelArbitrator) Report() []*ContractReport {
c.activeResolversLock.RLock()
defer c.activeResolversLock.RUnlock()
var reports []*ContractReport
for _, resolver := range c.activeResolvers {
r, ok := resolver.(reportingContractResolver)
if !ok {
continue
}
report := r.report()
if report == nil {
continue
}
reports = append(reports, report)
}
return reports
}
// Stop signals the ChannelArbitrator for a graceful shutdown.
func (c *ChannelArbitrator) Stop() error {
if !atomic.CompareAndSwapInt32(&c.stopped, 0, 1) {
return nil
}
log.Debugf("Stopping ChannelArbitrator(%v)", c.cfg.ChanPoint)
if c.cfg.ChainEvents.Cancel != nil {
go c.cfg.ChainEvents.Cancel()
}
c.activeResolversLock.RLock()
for _, activeResolver := range c.activeResolvers {
activeResolver.Stop()
}
c.activeResolversLock.RUnlock()
close(c.quit)
c.wg.Wait()
return nil
}
// transitionTrigger is an enum that denotes exactly *why* a state transition
// was initiated. This is useful as depending on the initial trigger, we may
// skip certain states as those actions are expected to have already taken
// place as a result of the external trigger.
type transitionTrigger uint8
const (
// chainTrigger is a transition trigger that has been attempted due to
// changing on-chain conditions, such as a new block being attached
// which times out HTLC's.
chainTrigger transitionTrigger = iota
// userTrigger is a transition trigger driven by user action. Examples
// of such a trigger include a user requesting a force closure of the
// channel.
userTrigger
// remoteCloseTrigger is a transition trigger driven by the remote
// peer's commitment being confirmed.
remoteCloseTrigger
// localCloseTrigger is a transition trigger driven by our commitment
// being confirmed.
localCloseTrigger
// coopCloseTrigger is a transition trigger driven by a cooperative
// close transaction being confirmed.
coopCloseTrigger
// breachCloseTrigger is a transition trigger driven by a remote breach
// being confirmed. In this case the channel arbitrator will wait for
// the BreachArbitrator to finish and then clean up gracefully.
breachCloseTrigger
)
// String returns a human readable string describing the passed
// transitionTrigger.
func (t transitionTrigger) String() string {
switch t {
case chainTrigger:
return "chainTrigger"
case remoteCloseTrigger:
return "remoteCloseTrigger"
case userTrigger:
return "userTrigger"
case localCloseTrigger:
return "localCloseTrigger"
case coopCloseTrigger:
return "coopCloseTrigger"
case breachCloseTrigger:
return "breachCloseTrigger"
default:
return "unknown trigger"
}
}
// stateStep is a helper method that examines our internal state, and attempts
// the appropriate state transition if necessary. The next state we transition
// to is returned. Additionally, if the next transition results in a commitment
// broadcast, the commitment transaction itself is returned.
func (c *ChannelArbitrator) stateStep(
triggerHeight uint32, trigger transitionTrigger,
confCommitSet *CommitSet) (ArbitratorState, *wire.MsgTx, error) {
var (
nextState ArbitratorState
closeTx *wire.MsgTx
)
switch c.state {
// If we're in the default state, then we'll check our set of actions
// to see if while we were down, conditions have changed.
case StateDefault:
log.Debugf("ChannelArbitrator(%v): examining active HTLCs in "+
"block %v, confCommitSet: %v", c.cfg.ChanPoint,
triggerHeight, lnutils.LogClosure(confCommitSet.String))
// As a new block has been connected to the end of the main
// chain, we'll check to see if we need to make any on-chain
// claims on behalf of the channel contract that we're
// arbitrating for. If a commitment has confirmed, then we'll
// use the set snapshot from the chain, otherwise we'll use our
// current set.
var (
chainActions ChainActionMap
err error
)
// Normally, if we force close the channel locally, we will have
// no confCommitSet. However, when the remote commitment confirms
// without us ever broadcasting our local commitment, we need to
// make sure we cancel all upstream HTLCs for outgoing dust HTLCs
// as well, hence we fetch the chain actions here too.
if confCommitSet == nil {
// Update the set of activeHTLCs so
// checkLocalChainActions has an up-to-date view of the
// commitments.
c.updateActiveHTLCs()
htlcs := c.activeHTLCs
chainActions, err = c.checkLocalChainActions(
triggerHeight, trigger, htlcs, false,
)
if err != nil {
return StateDefault, nil, err
}
} else {
chainActions, err = c.constructChainActions(
confCommitSet, triggerHeight, trigger,
)
if err != nil {
return StateDefault, nil, err
}
}
// If there are no actions to be taken, then we'll remain in the
// default state. If this isn't a self-initiated event (we're
// checking due to a chain update), then we'll exit now.
if len(chainActions) == 0 && trigger == chainTrigger {
log.Debugf("ChannelArbitrator(%v): no actions for "+
"chain trigger, terminating", c.cfg.ChanPoint)
return StateDefault, closeTx, nil
}
// Otherwise, we'll log that we checked the HTLC actions as the
// commitment transaction has already been broadcast.
log.Tracef("ChannelArbitrator(%v): logging chain_actions=%v",
c.cfg.ChanPoint, lnutils.SpewLogClosure(chainActions))
// Cancel upstream HTLCs for all outgoing dust HTLCs available
// either on the local or the remote/remote pending commitment
// transaction.
dustHTLCs := chainActions[HtlcFailDustAction]
if len(dustHTLCs) > 0 {
log.Debugf("ChannelArbitrator(%v): canceling %v dust "+
"HTLCs backwards", c.cfg.ChanPoint,
len(dustHTLCs))
getIdx := func(htlc channeldb.HTLC) uint64 {
return htlc.HtlcIndex
}
dustHTLCSet := fn.NewSet(fn.Map(dustHTLCs, getIdx)...)
err = c.abandonForwards(dustHTLCSet)
if err != nil {
return StateError, closeTx, err
}
}
// Depending on the type of trigger, we'll either "tunnel"
// through to a farther state, or just proceed linearly to the
// next state.
switch trigger {
// If this is a chain trigger, then we'll go straight to the
// next state, as we still need to broadcast the commitment
// transaction.
case chainTrigger:
fallthrough
case userTrigger:
nextState = StateBroadcastCommit
// If the trigger is a cooperative close being confirmed, then
// we can go straight to StateFullyResolved, as there won't be
// any contracts to resolve.
case coopCloseTrigger:
nextState = StateFullyResolved
// Otherwise, if this state advance was triggered by a
// commitment being confirmed on chain, then we'll jump
// straight to the state where the contract has already been
// closed, and we will inspect the set of unresolved contracts.
case localCloseTrigger:
log.Errorf("ChannelArbitrator(%v): unexpected local "+
"commitment confirmed while in StateDefault",
c.cfg.ChanPoint)
fallthrough
case remoteCloseTrigger:
nextState = StateContractClosed
case breachCloseTrigger:
nextContractState, err := c.checkLegacyBreach()
if nextContractState == StateError {
return nextContractState, nil, err
}
nextState = nextContractState
}
// If we're in this state, then we've decided to broadcast the
// commitment transaction. We enter this state either due to an outside
// sub-system, or because an on-chain action has been triggered.
case StateBroadcastCommit:
// Under normal operation, we can only enter
// StateBroadcastCommit via a user or chain trigger. On restart,
// this state may be reexecuted after closing the channel, but
// failing to commit to StateContractClosed or
// StateFullyResolved. In that case, one of the four close
// triggers will be presented, signifying that we should skip
// rebroadcasting, and go straight to resolving the on-chain
// contract or marking the channel resolved.
switch trigger {
case localCloseTrigger, remoteCloseTrigger:
log.Infof("ChannelArbitrator(%v): detected %s "+
"close after closing channel, fast-forwarding "+
"to %s to resolve contract",
c.cfg.ChanPoint, trigger, StateContractClosed)
return StateContractClosed, closeTx, nil
case breachCloseTrigger:
nextContractState, err := c.checkLegacyBreach()
if nextContractState == StateError {
log.Infof("ChannelArbitrator(%v): unable to "+
"advance breach close resolution: %v",
c.cfg.ChanPoint, nextContractState)
return StateError, closeTx, err
}
log.Infof("ChannelArbitrator(%v): detected %s close "+
"after closing channel, fast-forwarding to %s"+
" to resolve contract", c.cfg.ChanPoint,
trigger, nextContractState)
return nextContractState, closeTx, nil
case coopCloseTrigger:
log.Infof("ChannelArbitrator(%v): detected %s "+
"close after closing channel, fast-forwarding "+
"to %s to resolve contract",
c.cfg.ChanPoint, trigger, StateFullyResolved)
return StateFullyResolved, closeTx, nil
}
log.Infof("ChannelArbitrator(%v): force closing "+
"chan", c.cfg.ChanPoint)
// Now that we have all the actions decided for the set of
// HTLC's, we'll broadcast the commitment transaction, and
// signal the link to exit.
// We'll tell the switch that it should remove the link for
// this channel, in addition to fetching the force close
// summary needed to close this channel on chain.
forceCloseTx, err := c.cfg.Channel.ForceCloseChan()
if err != nil {
log.Errorf("ChannelArbitrator(%v): unable to "+
"force close: %v", c.cfg.ChanPoint, err)
// We tried to force close (HTLC may be expiring from
// our PoV, etc), but we think we've lost data. In this
// case, we'll not force close, but terminate the state
// machine here to wait to see what confirms on chain.
if errors.Is(err, lnwallet.ErrForceCloseLocalDataLoss) {
log.Error("ChannelArbitrator(%v): broadcast "+
"failed due to local data loss, "+
"waiting for on chain confimation...",
c.cfg.ChanPoint)
return StateBroadcastCommit, nil, nil
}
return StateError, closeTx, err
}
closeTx = forceCloseTx
// Before publishing the transaction, we store it to the
// database, such that we can re-publish later in case it
// didn't propagate. We initiated the force close, so we
// mark broadcast with local initiator set to true.
err = c.cfg.MarkCommitmentBroadcasted(closeTx, lntypes.Local)
if err != nil {
log.Errorf("ChannelArbitrator(%v): unable to "+
"mark commitment broadcasted: %v",
c.cfg.ChanPoint, err)
return StateError, closeTx, err
}
// With the close transaction in hand, broadcast the
// transaction to the network, thereby entering the post
// channel resolution state.
log.Infof("Broadcasting force close transaction %v, "+
"ChannelPoint(%v): %v", closeTx.TxHash(),
c.cfg.ChanPoint, lnutils.SpewLogClosure(closeTx))
// At this point, we'll now broadcast the commitment
// transaction itself.
label := labels.MakeLabel(
labels.LabelTypeChannelClose, &c.cfg.ShortChanID,
)
if err := c.cfg.PublishTx(closeTx, label); err != nil {
log.Errorf("ChannelArbitrator(%v): unable to broadcast "+
"close tx: %v", c.cfg.ChanPoint, err)
// This makes sure we don't fail at startup if the
// commitment transaction has too low fees to make it
// into mempool. The rebroadcaster makes sure this
// transaction is republished regularly until confirmed
// or replaced.
if !errors.Is(err, lnwallet.ErrDoubleSpend) &&
!errors.Is(err, lnwallet.ErrMempoolFee) {
return StateError, closeTx, err
}
}
// We go to the StateCommitmentBroadcasted state, where we'll
// be waiting for the commitment to be confirmed.
nextState = StateCommitmentBroadcasted
// In this state we have broadcasted our own commitment, and will need
// to wait for a commitment (not necessarily the one we broadcasted!)
// to be confirmed.
case StateCommitmentBroadcasted:
switch trigger {
// We are waiting for a commitment to be confirmed.
case chainTrigger, userTrigger:
// The commitment transaction has been broadcast, but it
// doesn't necessarily need to be the commitment
// transaction version that is going to be confirmed. To
// be sure that any of those versions can be anchored
// down, we now submit all anchor resolutions to the
// sweeper. The sweeper will keep trying to sweep all of
// them.
//
// Note that the sweeper is idempotent. If we ever
// happen to end up at this point in the code again, no
// harm is done by re-offering the anchors to the
// sweeper.
anchors, err := c.cfg.Channel.NewAnchorResolutions()
if err != nil {
return StateError, closeTx, err
}
err = c.sweepAnchors(anchors, triggerHeight)
if err != nil {
return StateError, closeTx, err
}
nextState = StateCommitmentBroadcasted
// If this state advance was triggered by any of the
// commitments being confirmed, then we'll jump to the state
// where the contract has been closed.
case localCloseTrigger, remoteCloseTrigger:
nextState = StateContractClosed
// If a coop close was confirmed, jump straight to the fully
// resolved state.
case coopCloseTrigger:
nextState = StateFullyResolved
case breachCloseTrigger:
nextContractState, err := c.checkLegacyBreach()
if nextContractState == StateError {
return nextContractState, closeTx, err
}
nextState = nextContractState
}
log.Infof("ChannelArbitrator(%v): trigger %v moving from "+
"state %v to %v", c.cfg.ChanPoint, trigger, c.state,
nextState)
// If we're in this state, then the contract has been fully closed to
// outside sub-systems, so we'll process the prior set of on-chain
// contract actions and launch a set of resolvers.
case StateContractClosed:
// First, we'll fetch our chain actions, and both sets of
// resolutions so we can process them.
contractResolutions, err := c.log.FetchContractResolutions()
if err != nil {
log.Errorf("unable to fetch contract resolutions: %v",
err)
return StateError, closeTx, err
}
// If the resolutions are empty, and we have no HTLCs at all to
// act on, then we're done here. We don't need to launch any
// resolvers, and can go straight to our final state.
if contractResolutions.IsEmpty() && confCommitSet.IsEmpty() {
log.Infof("ChannelArbitrator(%v): contract "+
"resolutions empty, marking channel as fully resolved!",
c.cfg.ChanPoint)
nextState = StateFullyResolved
break
}
// First, we'll reconstruct a fresh set of chain actions as the
// set of actions we need to act on may differ based on whether it
// was our commitment, or their commitment, that hit the chain.
htlcActions, err := c.constructChainActions(
confCommitSet, triggerHeight, trigger,
)
if err != nil {
return StateError, closeTx, err
}
// In case it's a breach transaction, we fail back all upstream
// HTLCs for their corresponding outgoing HTLCs on the remote
// commitment set (remote and remote pending set).
if contractResolutions.BreachResolution != nil {
// cancelBreachedHTLCs is a set which holds HTLCs whose
// corresponding incoming HTLCs will be failed back
// because the peer broadcasted an old state.
cancelBreachedHTLCs := fn.NewSet[uint64]()
// Using the CommitSet, we'll fail back all upstream HTLCs
// for their corresponding outgoing HTLCs that exist on
// either of the remote commitments. The map is used to
// deduplicate any shared HTLC's.
for htlcSetKey, htlcs := range confCommitSet.HtlcSets {
if !htlcSetKey.IsRemote {
continue
}
for _, htlc := range htlcs {
// Only outgoing HTLCs have a
// corresponding incoming HTLC.
if htlc.Incoming {
continue
}
cancelBreachedHTLCs.Add(htlc.HtlcIndex)
}
}
err := c.abandonForwards(cancelBreachedHTLCs)
if err != nil {
return StateError, closeTx, err
}
} else {
// If it's not a breach, we resolve all incoming dust
// HTLCs immediately after the commitment is confirmed.
err = c.failIncomingDust(
htlcActions[HtlcIncomingDustFinalAction],
)
if err != nil {
return StateError, closeTx, err
}
// We fail the upstream HTLCs for all remote pending
// outgoing HTLCs as soon as the commitment is
// confirmed. The upstream HTLCs for outgoing dust
// HTLCs have already been resolved before we reach
// this point.
getIdx := func(htlc channeldb.HTLC) uint64 {
return htlc.HtlcIndex
}
remoteDangling := fn.NewSet(fn.Map(
htlcActions[HtlcFailDanglingAction], getIdx,
)...)
err := c.abandonForwards(remoteDangling)
if err != nil {
return StateError, closeTx, err
}
}
// Now that we know we'll need to act, we'll process all the
// resolvers, then create the structures we need to resolve all
// outstanding contracts.
resolvers, err := c.prepContractResolutions(
contractResolutions, triggerHeight, htlcActions,
)
if err != nil {
log.Errorf("ChannelArbitrator(%v): unable to "+
"resolve contracts: %v", c.cfg.ChanPoint, err)
return StateError, closeTx, err
}
log.Debugf("ChannelArbitrator(%v): inserting %v contract "+
"resolvers", c.cfg.ChanPoint, len(resolvers))
err = c.log.InsertUnresolvedContracts(nil, resolvers...)
if err != nil {
return StateError, closeTx, err
}
// Finally, we'll launch all the required contract resolvers.
// Once they're all resolved, we're no longer needed.
c.resolveContracts(resolvers)
nextState = StateWaitingFullResolution
// In this state we're waiting for all contracts to be fully
// resolved. We'll keep returning this state until they are.
case StateWaitingFullResolution:
log.Infof("ChannelArbitrator(%v): still awaiting contract "+
"resolution", c.cfg.ChanPoint)
unresolved, err := c.log.FetchUnresolvedContracts()
if err != nil {
return StateError, closeTx, err
}
// If we have no unresolved contracts, then we can move to the
// final state.
if len(unresolved) == 0 {
nextState = StateFullyResolved
break
}
// Otherwise, we still have unresolved contracts, so we'll
// stay alive to oversee their resolution.
nextState = StateWaitingFullResolution
// Add debug logs.
for _, r := range unresolved {
log.Debugf("ChannelArbitrator(%v): still have "+
"unresolved contract: %T", c.cfg.ChanPoint, r)
}
// If we start as fully resolved, then we'll end as fully resolved.
case StateFullyResolved:
// To ensure that the state of the contract in persistent
// storage is properly reflected, we'll mark the contract as
// fully resolved now.
nextState = StateFullyResolved
log.Infof("ChannelPoint(%v) has been fully resolved "+
"on-chain at height=%v", c.cfg.ChanPoint, triggerHeight)
if err := c.cfg.MarkChannelResolved(); err != nil {
log.Errorf("unable to mark channel resolved: %v", err)
return StateError, closeTx, err
}
}
log.Tracef("ChannelArbitrator(%v): next_state=%v", c.cfg.ChanPoint,
nextState)
return nextState, closeTx, nil
}
// sweepAnchors offers all given anchor resolutions to the sweeper. It requests
// sweeping at the minimum fee rate. This fee rate can be upped manually by the
// user via the BumpFee rpc.
func (c *ChannelArbitrator) sweepAnchors(anchors *lnwallet.AnchorResolutions,
heightHint uint32) error {
// Update the set of activeHTLCs so that the sweeping routine has an
// up-to-date view of the set of commitments.
c.updateActiveHTLCs()
// Prepare the sweeping requests for all possible versions of
// commitments.
sweepReqs, err := c.prepareAnchorSweeps(heightHint, anchors)
if err != nil {
return err
}
// Send out the sweeping requests to the sweeper.
for _, req := range sweepReqs {
_, err = c.cfg.Sweeper.SweepInput(req.input, req.params)
if err != nil {
return err
}
}
return nil
}
// findCommitmentDeadlineAndValue finds the deadline (relative block height)
// for a commitment transaction by extracting the minimum CLTV from its HTLCs.
// From our PoV, the deadline delta is defined to be the smaller of,
// - half of the least CLTV from outgoing HTLCs' corresponding incoming
// HTLCs, or,
// - half of the least CLTV from incoming HTLCs if the preimage is available.
//
// We use half of the CLTV value to ensure that we have enough time to sweep
// the second-level HTLCs.
//
// It also finds the total value that is time-sensitive, which is the sum of
// all the outgoing HTLCs plus incoming HTLCs whose preimages are known. It
// then returns the value left after subtracting the budget used for sweeping
// the time-sensitive HTLCs.
//
// NOTE: when the deadline turns out to be 0 blocks, we will replace it with 1
// block because our fee estimator doesn't allow a 0 conf target. This also
// means we've fallen behind and should increase our fee to make the
// transaction confirm asap.
func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32,
htlcs htlcSet) (fn.Option[int32], btcutil.Amount, error) {
deadlineMinHeight := uint32(math.MaxUint32)
totalValue := btcutil.Amount(0)
// First, iterate through the outgoingHTLCs to find the lowest CLTV
// value.
for _, htlc := range htlcs.outgoingHTLCs {
// Skip if the HTLC is dust.
if htlc.OutputIndex < 0 {
log.Debugf("ChannelArbitrator(%v): skipped deadline "+
"for dust htlc=%x",
c.cfg.ChanPoint, htlc.RHash[:])
continue
}
value := htlc.Amt.ToSatoshis()
// Find the expiry height for this outgoing HTLC's incoming
// HTLC.
deadlineOpt := c.cfg.FindOutgoingHTLCDeadline(htlc)
// The deadline defaults to the current deadlineMinHeight, and
// is overwritten when the option is not None.
deadline := deadlineMinHeight
deadlineOpt.WhenSome(func(d int32) {
deadline = uint32(d)
// We only consider the value to be under protection when
// it's time-sensitive.
totalValue += value
})
if deadline < deadlineMinHeight {
deadlineMinHeight = deadline
log.Tracef("ChannelArbitrator(%v): outgoing HTLC has "+
"deadline=%v, value=%v", c.cfg.ChanPoint,
deadlineMinHeight, value)
}
}
// Then we go through the incomingHTLCs and update the minHeight when
// the conditions are met.
for _, htlc := range htlcs.incomingHTLCs {
// Skip if the HTLC is dust.
if htlc.OutputIndex < 0 {
log.Debugf("ChannelArbitrator(%v): skipped deadline "+
"for dust htlc=%x",
c.cfg.ChanPoint, htlc.RHash[:])
continue
}
// Since it's an HTLC sent to us, check if we have the preimage for
// this HTLC.
preimageAvailable, err := c.isPreimageAvailable(htlc.RHash)
if err != nil {
return fn.None[int32](), 0, err
}
if !preimageAvailable {
continue
}
value := htlc.Amt.ToSatoshis()
totalValue += value
if htlc.RefundTimeout < deadlineMinHeight {
deadlineMinHeight = htlc.RefundTimeout
log.Tracef("ChannelArbitrator(%v): incoming HTLC has "+
"deadline=%v, amt=%v", c.cfg.ChanPoint,
deadlineMinHeight, value)
}
}
// Calculate the deadline. There are two cases to be handled here,
// - when the deadlineMinHeight never gets updated, which could
// happen when we have no outgoing HTLCs, and, for incoming HTLCs,
// * either we have none, or,
// * none of the HTLCs are preimageAvailable.
// - when our deadlineMinHeight is no greater than the heightHint,
// which means we are behind schedule.
var deadline uint32
switch {
// When we couldn't find a deadline height from our HTLCs, we will fall
// back to the default value as there's no time pressure here.
case deadlineMinHeight == math.MaxUint32:
return fn.None[int32](), 0, nil
// When the deadline is passed, we will fall back to the smallest conf
// target (1 block).
case deadlineMinHeight <= heightHint:
log.Warnf("ChannelArbitrator(%v): deadline is passed with "+
"deadlineMinHeight=%d, heightHint=%d",
c.cfg.ChanPoint, deadlineMinHeight, heightHint)
deadline = 1
// Use half of the deadline delta, and leave the other half to be used
// to sweep the HTLCs.
default:
deadline = (deadlineMinHeight - heightHint) / 2
}
// Calculate the value left after subtracting the budget used for
// sweeping the time-sensitive HTLCs.
valueLeft := totalValue - calculateBudget(
totalValue, c.cfg.Budget.DeadlineHTLCRatio,
c.cfg.Budget.DeadlineHTLC,
)
log.Debugf("ChannelArbitrator(%v): calculated valueLeft=%v, "+
"deadline=%d, using deadlineMinHeight=%d, heightHint=%d",
c.cfg.ChanPoint, valueLeft, deadline, deadlineMinHeight,
heightHint)
return fn.Some(int32(deadline)), valueLeft, nil
}
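// As a concrete example of the arithmetic above: with heightHint=1000 and a
// single time-sensitive HTLC whose deadline height is 1144, we get
// deadlineMinHeight=1144 and return a deadline of (1144-1000)/2 = 72 blocks,
// leaving the other half of the delta for sweeping the resulting
// second-level HTLC.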
// resolveContracts updates the activeResolvers list, launches the given
// resolvers, and then resolves each contract concurrently.
func (c *ChannelArbitrator) resolveContracts(resolvers []ContractResolver) {
c.activeResolversLock.Lock()
c.activeResolvers = resolvers
c.activeResolversLock.Unlock()
// Launch all resolvers.
c.launchResolvers()
for _, contract := range resolvers {
c.wg.Add(1)
go c.resolveContract(contract)
}
}
// launchResolvers launches all the active resolvers concurrently.
func (c *ChannelArbitrator) launchResolvers() {
c.activeResolversLock.Lock()
resolvers := c.activeResolvers
c.activeResolversLock.Unlock()
// errChans is a map of channels that will be used to receive errors
// returned from launching the resolvers.
errChans := make(map[ContractResolver]chan error, len(resolvers))
// Launch each resolver in goroutines.
for _, r := range resolvers {
// If the contract is already resolved, there's no need to
// launch it again.
if r.IsResolved() {
log.Debugf("ChannelArbitrator(%v): skipping resolver "+
"%T as it's already resolved", c.cfg.ChanPoint,
r)
continue
}
// Create a signal chan.
errChan := make(chan error, 1)
errChans[r] = errChan
go func() {
err := r.Launch()
errChan <- err
}()
}
// Wait for all resolvers to finish launching.
for r, errChan := range errChans {
select {
case err := <-errChan:
if err == nil {
continue
}
log.Errorf("ChannelArbitrator(%v): unable to launch "+
"contract resolver(%T): %v", c.cfg.ChanPoint, r,
err)
case <-c.quit:
log.Debugf("ChannelArbitrator quit signal received, " +
"exit launchResolvers")
return
}
}
}
// advanceState is the main driver of our state machine. This method is an
// iterative function which repeatedly attempts to advance the internal state
// of the channel arbitrator. The state will be advanced until we reach a
// redundant transition, meaning that the state transition is a noop. The final
// param is the confirmed commitment set, which is used to construct the chain
// actions once a commitment has confirmed on chain.
func (c *ChannelArbitrator) advanceState(
triggerHeight uint32, trigger transitionTrigger,
confCommitSet *CommitSet) (ArbitratorState, *wire.MsgTx, error) {
var (
priorState ArbitratorState
forceCloseTx *wire.MsgTx
)
// We'll continue to advance our state forward until the state we
// transition to is that same state that we started at.
for {
priorState = c.state
log.Debugf("ChannelArbitrator(%v): attempting state step with "+
"trigger=%v from state=%v at height=%v",
c.cfg.ChanPoint, trigger, priorState, triggerHeight)
nextState, closeTx, err := c.stateStep(
triggerHeight, trigger, confCommitSet,
)
if err != nil {
log.Errorf("ChannelArbitrator(%v): unable to advance "+
"state: %v", c.cfg.ChanPoint, err)
return priorState, nil, err
}
if forceCloseTx == nil && closeTx != nil {
forceCloseTx = closeTx
}
// Our termination transition is a noop transition. If we get
// our prior state back as the next state, then we'll
// terminate.
if nextState == priorState {
log.Debugf("ChannelArbitrator(%v): terminating at "+
"state=%v", c.cfg.ChanPoint, nextState)
return nextState, forceCloseTx, nil
}
// As the prior state was successfully executed, we can now
// commit the next state. This ensures that we will re-execute
// the prior state if anything fails.
if err := c.log.CommitState(nextState); err != nil {
log.Errorf("ChannelArbitrator(%v): unable to commit "+
"next state(%v): %v", c.cfg.ChanPoint,
nextState, err)
return priorState, nil, err
}
c.state = nextState
}
}
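// Stripped of logging and channel-specific detail, advanceState is a
// fixpoint iteration: keep stepping and committing until a step returns the
// state it was given. A generic sketch of the pattern, where step and commit
// are illustrative stand-ins for stateStep and the log's CommitState:
//
//	for {
//		next, err := step(state)
//		if err != nil {
//			return state, err
//		}
//		if next == state {
//			return state, nil // Reached the fixpoint.
//		}
//		if err := commit(next); err != nil {
//			return state, err
//		}
//		state = next
//	}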
// ChainAction is an enum that encompasses all possible on-chain actions
// we'll take for a set of HTLC's.
type ChainAction uint8
const (
// NoAction is the min chainAction type, indicating that no action
// needs to be taken for a given HTLC.
NoAction ChainAction = 0
// HtlcTimeoutAction indicates that the HTLC will timeout soon. As a
// result, we should get ready to sweep it on chain after the timeout.
HtlcTimeoutAction = 1
// HtlcClaimAction indicates that we should claim the HTLC on chain
// before its timeout period.
HtlcClaimAction = 2
// HtlcFailDustAction indicates that we should fail the upstream HTLC
// for an outgoing dust HTLC immediately (even before the commitment
// transaction is confirmed) because it has no output on the commitment
// transaction. This also includes remote pending outgoing dust HTLCs.
HtlcFailDustAction = 3
// HtlcOutgoingWatchAction indicates that we can't yet timeout this
// HTLC, but we had to go to chain in order to resolve an existing
// HTLC. In this case, we'll either: time it out once it expires, or
// learn the pre-image if the remote party claims the output. In
// that case, we'll add the pre-image to our global store.
HtlcOutgoingWatchAction = 4
// HtlcIncomingWatchAction indicates that we don't yet have the
// pre-image to claim the incoming HTLC, but we had to go to chain in
// order to resolve an existing HTLC. In this case, we'll either: let
// the other party time it out, or eventually learn of the pre-image,
// in which case we'll claim on chain.
HtlcIncomingWatchAction = 5
// HtlcIncomingDustFinalAction indicates that we should mark an incoming
// dust htlc as final because it can't be claimed on-chain.
HtlcIncomingDustFinalAction = 6
// HtlcFailDanglingAction indicates that we should fail the upstream
// HTLC for an outgoing HTLC immediately after the commitment
// transaction has confirmed because it has no corresponding output on
// the commitment transaction. This category does NOT include any dust
// HTLCs which are mapped in the "HtlcFailDustAction" category.
HtlcFailDanglingAction = 7
)
// String returns a human readable string describing a chain action.
func (c ChainAction) String() string {
switch c {
case NoAction:
return "NoAction"
case HtlcTimeoutAction:
return "HtlcTimeoutAction"
case HtlcClaimAction:
return "HtlcClaimAction"
case HtlcFailDustAction:
return "HtlcFailDustAction"
case HtlcOutgoingWatchAction:
return "HtlcOutgoingWatchAction"
case HtlcIncomingWatchAction:
return "HtlcIncomingWatchAction"
case HtlcIncomingDustFinalAction:
return "HtlcIncomingDustFinalAction"
case HtlcFailDanglingAction:
return "HtlcFailDanglingAction"
default:
return "<unknown action>"
}
}
// ChainActionMap is a map from a chain action to the set of HTLC's that need
// to be acted upon for a given action type.
type ChainActionMap map[ChainAction][]channeldb.HTLC
// Merge merges the passed chain actions with the target chain action map.
func (c ChainActionMap) Merge(actions ChainActionMap) {
for chainAction, htlcs := range actions {
c[chainAction] = append(c[chainAction], htlcs...)
}
}
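// For example, merging one map into another appends per action and keeps
// actions present in only one of them (htlcA, htlcB and htlcC are
// placeholder channeldb.HTLC values):
//
//	local := ChainActionMap{HtlcTimeoutAction: {htlcA}}
//	local.Merge(ChainActionMap{
//		HtlcTimeoutAction: {htlcB},
//		HtlcClaimAction:   {htlcC},
//	})
//	// local[HtlcTimeoutAction] now holds htlcA and htlcB, and
//	// local[HtlcClaimAction] holds htlcC.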
// shouldGoOnChain takes into account the absolute timeout of the HTLC and the
// confirmation delta that we need, and returns a bool indicating if we should
// go on chain to claim. We do this rather than waiting until the last minute,
// as we want to ensure that when we *need* to sweep (the HTLC has timed out),
// the commitment is already confirmed.
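//
// As an illustrative sketch of the height gate alone (the incoming and
// grace-period checks below further refine the decision):
//
//	cutoff := htlc.RefundTimeout - broadcastDelta // e.g. 800_010 - 10
//	goOnChain := currentHeight >= cutoff          // true at 800_000+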
func (c *ChannelArbitrator) shouldGoOnChain(htlc channeldb.HTLC,
broadcastDelta, currentHeight uint32) bool {
// We'll calculate the broadcast cut off for this HTLC. This is the
// height that (based on our current fee estimation) we should
// broadcast in order to ensure the commitment transaction is confirmed
// before the HTLC fully expires.
broadcastCutOff := htlc.RefundTimeout - broadcastDelta
log.Tracef("ChannelArbitrator(%v): examining outgoing contract: "+
"expiry=%v, cutoff=%v, height=%v", c.cfg.ChanPoint, htlc.RefundTimeout,
broadcastCutOff, currentHeight)
// TODO(roasbeef): take into account default HTLC delta, don't need to
// broadcast immediately
// * can then batch with SINGLE | ANYONECANPAY
// We should go on-chain for this HTLC iff we're within our broadcast
// cutoff window.
if currentHeight < broadcastCutOff {
return false
}
// In the case of an incoming HTLC, we should go to chain.
if htlc.Incoming {
return true
}
// For HTLCs that are the result of payments we initiated, we give a
// grace period before force closing the channel. During this time we
// expect both nodes to connect, giving the other node a chance to send
// its updates and cancel the HTLC.
// This shouldn't add any security risk, as there is no incoming HTLC to
// fulfill in this case, and the expectation is that when the channel is
// active the other node will send update_fail_htlc to remove the HTLC
// without closing the channel. It is up to the user to force close the
// channel if the peer misbehaves and doesn't send the update_fail_htlc.
// This is useful when this node is offline most of the time and is
// likely to miss the time slot in which the HTLC may be cancelled.
isForwarded := c.cfg.IsForwardedHTLC(c.cfg.ShortChanID, htlc.HtlcIndex)
upTime := c.cfg.Clock.Now().Sub(c.startTimestamp)
return isForwarded || upTime > c.cfg.PaymentsExpirationGracePeriod
}
// checkCommitChainActions is called for each new block connected to the end of
// the main chain. Given the new block height, this method will examine all
// active HTLC's, and determine if we need to go on-chain to claim any of them.
// A map of action -> []htlc is returned, detailing what action (if any) should
// be performed for each HTLC. For timed out HTLC's, once the commitment has
// been sufficiently confirmed, the HTLC's should be canceled backwards. For
// redeemed HTLC's, we should send the pre-image back to the incoming link.
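//
// A caller might consume the returned map as in this illustrative
// sketch (height, trigger, and htlcs are assumed in scope):
//
//	actions, err := c.checkCommitChainActions(height, trigger, htlcs)
//	if err == nil {
//		for action, htlcs := range actions {
//			log.Debugf("action=%v covers %d htlcs", action, len(htlcs))
//		}
//	}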
func (c *ChannelArbitrator) checkCommitChainActions(height uint32,
trigger transitionTrigger, htlcs htlcSet) (ChainActionMap, error) {
// TODO(roasbeef): would need to lock channel? channel totem?
// * race condition if adding and we broadcast, etc
// * or would make each instance sync?
log.Debugf("ChannelArbitrator(%v): checking commit chain actions at "+
"height=%v, in_htlc_count=%v, out_htlc_count=%v",
c.cfg.ChanPoint, height,
len(htlcs.incomingHTLCs), len(htlcs.outgoingHTLCs))
actionMap := make(ChainActionMap)
// First, we'll make an initial pass over the set of incoming and
// outgoing HTLC's to decide if we need to go on chain at all.
haveChainActions := false
for _, htlc := range htlcs.outgoingHTLCs {
// We'll need to go on-chain for an outgoing HTLC if it was
// never resolved downstream, and it's "close" to timing out.
//
// TODO(yy): If there's no corresponding incoming HTLC, it
// means we are the first hop, hence the payer. This is a
// tricky case - unlike a forwarding hop, we don't have an
// incoming HTLC that will time out, which means as long as we
// can learn the preimage, we can settle the invoice (before it
// expires?).
toChain := c.shouldGoOnChain(
htlc, c.cfg.OutgoingBroadcastDelta, height,
)
if toChain {
// Convert to int64 in case of overflow.
remainingBlocks := int64(htlc.RefundTimeout) -
int64(height)
log.Infof("ChannelArbitrator(%v): go to chain for "+
"outgoing htlc %x: timeout=%v, amount=%v, "+
"blocks_until_expiry=%v, broadcast_delta=%v",
c.cfg.ChanPoint, htlc.RHash[:],
htlc.RefundTimeout, htlc.Amt, remainingBlocks,
c.cfg.OutgoingBroadcastDelta,
)
}
haveChainActions = haveChainActions || toChain
}
for _, htlc := range htlcs.incomingHTLCs {
// We'll need to go on-chain to pull an incoming HTLC iff we
// know the pre-image and it's close to timing out. We need to
// ensure that we claim the funds that are rightfully ours
// on-chain.
preimageAvailable, err := c.isPreimageAvailable(htlc.RHash)
if err != nil {
return nil, err
}
if !preimageAvailable {
continue
}
toChain := c.shouldGoOnChain(
htlc, c.cfg.IncomingBroadcastDelta, height,
)
if toChain {
// Convert to int64 in case of overflow.
remainingBlocks := int64(htlc.RefundTimeout) -
int64(height)
log.Infof("ChannelArbitrator(%v): go to chain for "+
"incoming htlc %x: timeout=%v, amount=%v, "+
"blocks_until_expiry=%v, broadcast_delta=%v",
c.cfg.ChanPoint, htlc.RHash[:],
htlc.RefundTimeout, htlc.Amt, remainingBlocks,
c.cfg.IncomingBroadcastDelta,
)
}
haveChainActions = haveChainActions || toChain
}
// If we don't have any actions to make, then we'll return an empty
// action map. We only do this if this was a chain trigger though, as
// if we're going to broadcast the commitment (or the remote party did)
// we're *forced* to act on each HTLC.
if !haveChainActions && trigger == chainTrigger {
log.Tracef("ChannelArbitrator(%v): no actions to take at "+
"height=%v", c.cfg.ChanPoint, height)
return actionMap, nil
}
// Now that we know we'll need to go on-chain, we'll examine all of our
// active outgoing HTLC's to see if we either need to: sweep them after
// a timeout (then cancel backwards), cancel them backwards
// immediately, or watch them as they're still active contracts.
for _, htlc := range htlcs.outgoingHTLCs {
switch {
// If the HTLC is dust, then we can cancel it backwards
// immediately as there's no matching contract to arbitrate
// on-chain. We know the HTLC is dust if the OutputIndex is
// negative.
case htlc.OutputIndex < 0:
log.Tracef("ChannelArbitrator(%v): immediately "+
"failing dust htlc=%x", c.cfg.ChanPoint,
htlc.RHash[:])
actionMap[HtlcFailDustAction] = append(
actionMap[HtlcFailDustAction], htlc,
)
// If we don't need to immediately act on this HTLC, then we'll
// mark it still "live". After we broadcast, we'll monitor it
// until the HTLC times out to see if we can also redeem it
// on-chain.
case !c.shouldGoOnChain(htlc, c.cfg.OutgoingBroadcastDelta,
height,
):
// TODO(roasbeef): also need to be able to query
// circuit map to see if HTLC hasn't been fully
// resolved
//
// * can't fail incoming until the outgoing HTLC has
// failed
log.Tracef("ChannelArbitrator(%v): watching chain to "+
"decide action for outgoing htlc=%x",
c.cfg.ChanPoint, htlc.RHash[:])
actionMap[HtlcOutgoingWatchAction] = append(
actionMap[HtlcOutgoingWatchAction], htlc,
)
// Otherwise, we'll update our actionMap to mark that we need
// to sweep this HTLC on-chain.
default:
log.Tracef("ChannelArbitrator(%v): going on-chain to "+
"timeout htlc=%x", c.cfg.ChanPoint, htlc.RHash[:])
actionMap[HtlcTimeoutAction] = append(
actionMap[HtlcTimeoutAction], htlc,
)
}
}
// Similarly, for each incoming HTLC, now that we need to go on-chain,
// we'll either: sweep it immediately if we know the pre-image, or
// observe the output on-chain if we don't. In this last case, we'll
// either learn of it eventually from the outgoing HTLC, or the sender
// will timeout the HTLC.
for _, htlc := range htlcs.incomingHTLCs {
// If the HTLC is dust, there is no action to be taken.
if htlc.OutputIndex < 0 {
log.Debugf("ChannelArbitrator(%v): no resolution "+
"needed for incoming dust htlc=%x",
c.cfg.ChanPoint, htlc.RHash[:])
actionMap[HtlcIncomingDustFinalAction] = append(
actionMap[HtlcIncomingDustFinalAction], htlc,
)
continue
}
log.Tracef("ChannelArbitrator(%v): watching chain to decide "+
"action for incoming htlc=%x", c.cfg.ChanPoint,
htlc.RHash[:])
actionMap[HtlcIncomingWatchAction] = append(
actionMap[HtlcIncomingWatchAction], htlc,
)
}
return actionMap, nil
}
// isPreimageAvailable returns whether the hash preimage is available in either
// the preimage cache or the invoice database.
func (c *ChannelArbitrator) isPreimageAvailable(hash lntypes.Hash) (bool,
error) {
// Start by checking the preimage cache for preimages of
// forwarded HTLCs.
_, preimageAvailable := c.cfg.PreimageDB.LookupPreimage(
hash,
)
if preimageAvailable {
return true, nil
}
// Then check if we have an invoice that can be settled by this HTLC.
//
// TODO(joostjager): Check that there are still more blocks remaining
// than the invoice cltv delta. We don't want to go to chain only to
// have the incoming contest resolver decide that we don't want to
// settle this invoice.
invoice, err := c.cfg.Registry.LookupInvoice(context.Background(), hash)
switch {
case err == nil:
case errors.Is(err, invoices.ErrInvoiceNotFound) ||
errors.Is(err, invoices.ErrNoInvoicesCreated):
return false, nil
default:
return false, err
}
preimageAvailable = invoice.Terms.PaymentPreimage != nil
return preimageAvailable, nil
}
// checkLocalChainActions is similar to checkCommitChainActions, but it also
// examines the set of HTLCs on the remote party's commitment. This allows us
// to ensure we're able to satisfy the HTLC timeout constraints for incoming vs
// outgoing HTLCs.
func (c *ChannelArbitrator) checkLocalChainActions(
height uint32, trigger transitionTrigger,
activeHTLCs map[HtlcSetKey]htlcSet,
commitsConfirmed bool) (ChainActionMap, error) {
// First, we'll check our local chain actions as normal. This will only
// examine HTLCs on our local commitment (timeout or settle).
localCommitActions, err := c.checkCommitChainActions(
height, trigger, activeHTLCs[LocalHtlcSet],
)
if err != nil {
return nil, err
}
// Next, we'll examine the remote commitment (and maybe a dangling one)
// to see if the set difference of our HTLCs is non-empty. If so, then
// we may need to cancel back some HTLCs if we decide to go to chain.
remoteDanglingActions := c.checkRemoteDanglingActions(
height, activeHTLCs, commitsConfirmed,
)
// Finally, we'll merge the two sets of chain actions.
localCommitActions.Merge(remoteDanglingActions)
return localCommitActions, nil
}
// checkRemoteDanglingActions examines the set of remote commitments for any
// HTLCs that are close to timing out. If we find any, then we'll return a set
// of chain actions for HTLCs that are on the remote commitment, but not on
// ours, to cancel immediately.
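//
// Conceptually, this is a set difference over HTLC indexes, e.g. with
// illustrative values:
//
//	remote := map[uint64]struct{}{1: {}, 2: {}}
//	local := map[uint64]struct{}{1: {}}
//	// Index 2 is "dangling": on the remote commitment but not ours.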
func (c *ChannelArbitrator) checkRemoteDanglingActions(
height uint32, activeHTLCs map[HtlcSetKey]htlcSet,
commitsConfirmed bool) ChainActionMap {
var (
pendingRemoteHTLCs []channeldb.HTLC
localHTLCs = make(map[uint64]struct{})
remoteHTLCs = make(map[uint64]channeldb.HTLC)
actionMap = make(ChainActionMap)
)
// First, we'll construct two sets of the outgoing HTLCs: those on our
// local commitment, and those that are on the remote commitment(s).
for htlcSetKey, htlcs := range activeHTLCs {
if htlcSetKey.IsRemote {
for _, htlc := range htlcs.outgoingHTLCs {
remoteHTLCs[htlc.HtlcIndex] = htlc
}
} else {
for _, htlc := range htlcs.outgoingHTLCs {
localHTLCs[htlc.HtlcIndex] = struct{}{}
}
}
}
// With both sets constructed, we'll now compute the set difference of
// our two sets of HTLCs. This'll give us the HTLCs that exist on the
// remote commitment transaction, but not on ours.
for htlcIndex, htlc := range remoteHTLCs {
if _, ok := localHTLCs[htlcIndex]; ok {
continue
}
pendingRemoteHTLCs = append(pendingRemoteHTLCs, htlc)
}
// Finally, we'll examine all the pending remote HTLCs for those that
// have expired. If we find any, then we'll recommend that they be
// failed now so we can free up the incoming HTLC.
for _, htlc := range pendingRemoteHTLCs {
// We'll now check if we need to go to chain in order to cancel
// the incoming HTLC.
goToChain := c.shouldGoOnChain(htlc, c.cfg.OutgoingBroadcastDelta,
height,
)
// If we don't need to go to chain, and no commitments have
// been confirmed, then we can move on. Otherwise, if
// commitments have been confirmed, then we need to cancel back
// *all* of the pending remote HTLCs.
if !goToChain && !commitsConfirmed {
continue
}
preimageAvailable, err := c.isPreimageAvailable(htlc.RHash)
if err != nil {
log.Errorf("ChannelArbitrator(%v): failed to query "+
"preimage for dangling htlc=%x from remote "+
"commitments diff", c.cfg.ChanPoint,
htlc.RHash[:])
continue
}
if preimageAvailable {
continue
}
// Dust HTLCs can be canceled back even before the commitment
// transaction confirms, as dust HTLCs are not enforceable
// on-chain. If another version of the commit tx confirms, we
// either gain or lose those dust amounts, but there is no
// option other than cancelling the incoming HTLC back, because
// we will never learn the preimage.
if htlc.OutputIndex < 0 {
log.Infof("ChannelArbitrator(%v): fail dangling dust "+
"htlc=%x from local/remote commitments diff",
c.cfg.ChanPoint, htlc.RHash[:])
actionMap[HtlcFailDustAction] = append(
actionMap[HtlcFailDustAction], htlc,
)
continue
}
log.Infof("ChannelArbitrator(%v): fail dangling htlc=%x from "+
"local/remote commitments diff",
c.cfg.ChanPoint, htlc.RHash[:])
actionMap[HtlcFailDanglingAction] = append(
actionMap[HtlcFailDanglingAction], htlc,
)
}
return actionMap
}
// checkRemoteChainActions examines the two possible remote commitment chains
// and returns the set of chain actions we need to carry out if the remote
// commitment (non pending) confirms. The pendingConf indicates if the pending
// remote commitment confirmed. This is similar to checkCommitChainActions, but
// we'll immediately fail any HTLCs on the pending remote commit, but not the
// remote commit (or the other way around).
func (c *ChannelArbitrator) checkRemoteChainActions(
height uint32, trigger transitionTrigger,
activeHTLCs map[HtlcSetKey]htlcSet,
pendingConf bool) (ChainActionMap, error) {
// First, we'll examine all the normal chain actions on the remote
// commitment that confirmed.
confHTLCs := activeHTLCs[RemoteHtlcSet]
if pendingConf {
confHTLCs = activeHTLCs[RemotePendingHtlcSet]
}
remoteCommitActions, err := c.checkCommitChainActions(
height, trigger, confHTLCs,
)
if err != nil {
return nil, err
}
// With these actions computed, we'll now check the diff of the HTLCs on
// the commitments, and cancel back any that are on the pending but not
// the non-pending.
remoteDiffActions := c.checkRemoteDiffActions(
activeHTLCs, pendingConf,
)
// Finally, we'll merge the diff actions into the final set of
// chain actions.
remoteCommitActions.Merge(remoteDiffActions)
return remoteCommitActions, nil
}
// checkRemoteDiffActions checks the set difference of the HTLCs on the remote
// confirmed commit and remote pending commit for HTLCs that we need to cancel
// back. If we find any HTLCs on the remote pending but not the remote, then
// we'll mark them to be failed immediately.
func (c *ChannelArbitrator) checkRemoteDiffActions(
activeHTLCs map[HtlcSetKey]htlcSet,
pendingConf bool) ChainActionMap {
// First, we'll partition the HTLCs into those that are present on the
// confirmed commitment, and those on the dangling commitment.
confHTLCs := activeHTLCs[RemoteHtlcSet]
danglingHTLCs := activeHTLCs[RemotePendingHtlcSet]
if pendingConf {
confHTLCs = activeHTLCs[RemotePendingHtlcSet]
danglingHTLCs = activeHTLCs[RemoteHtlcSet]
}
// Next, we'll create a set of all the HTLCs on the confirmed commitment.
remoteHtlcs := make(map[uint64]struct{})
for _, htlc := range confHTLCs.outgoingHTLCs {
remoteHtlcs[htlc.HtlcIndex] = struct{}{}
}
// With the remote HTLCs assembled, we'll mark any HTLCs only on the
// remote pending commitment to be failed asap.
actionMap := make(ChainActionMap)
for _, htlc := range danglingHTLCs.outgoingHTLCs {
if _, ok := remoteHtlcs[htlc.HtlcIndex]; ok {
continue
}
preimageAvailable, err := c.isPreimageAvailable(htlc.RHash)
if err != nil {
log.Errorf("ChannelArbitrator(%v): failed to query "+
"preimage for dangling htlc=%x from remote "+
"commitments diff", c.cfg.ChanPoint,
htlc.RHash[:])
continue
}
if preimageAvailable {
continue
}
// Dust HTLCs on the remote commitment can be failed back.
if htlc.OutputIndex < 0 {
log.Infof("ChannelArbitrator(%v): fail dangling dust "+
"htlc=%x from remote commitments diff",
c.cfg.ChanPoint, htlc.RHash[:])
actionMap[HtlcFailDustAction] = append(
actionMap[HtlcFailDustAction], htlc,
)
continue
}
actionMap[HtlcFailDanglingAction] = append(
actionMap[HtlcFailDanglingAction], htlc,
)
log.Infof("ChannelArbitrator(%v): fail dangling htlc=%x from "+
"remote commitments diff",
c.cfg.ChanPoint, htlc.RHash[:])
}
return actionMap
}
// constructChainActions returns the set of actions that should be taken for
// confirmed HTLCs at the specified height. Our actions will depend on the set
// of HTLCs that were active across all channels at the time of channel
// closure.
func (c *ChannelArbitrator) constructChainActions(confCommitSet *CommitSet,
height uint32, trigger transitionTrigger) (ChainActionMap, error) {
// If we've reached this point and have no confirmed commitment set,
// then this is an older node that had a pending close channel before
// the CommitSet was introduced. In this case, we'll just return the
// existing ChainActionMap they had on disk.
if confCommitSet == nil || confCommitSet.ConfCommitKey.IsNone() {
return c.log.FetchChainActions()
}
// Otherwise, we have the full commitment set written to disk, and can
// proceed as normal.
htlcSets := confCommitSet.toActiveHTLCSets()
confCommitKey, err := confCommitSet.ConfCommitKey.UnwrapOrErr(
fmt.Errorf("no commitKey available"),
)
if err != nil {
return nil, err
}
switch confCommitKey {
// If the local commitment transaction confirmed, then we'll examine
// that as well as their commitments to determine the set of chain
// actions.
case LocalHtlcSet:
return c.checkLocalChainActions(
height, trigger, htlcSets, true,
)
// If the remote commitment confirmed, then we'll grab all the chain
// actions for the remote commit, and check the pending commit for any
// HTLCS we need to handle immediately (dust).
case RemoteHtlcSet:
return c.checkRemoteChainActions(
height, trigger, htlcSets, false,
)
// Otherwise, the remote pending commitment confirmed, so we'll examine
// the HTLCs on that unrevoked dangling commitment.
case RemotePendingHtlcSet:
return c.checkRemoteChainActions(
height, trigger, htlcSets, true,
)
}
return nil, fmt.Errorf("unable to locate chain actions")
}
// prepContractResolutions is called either in the case that we decide we need
// to go to chain, or the remote party goes to chain. Given a set of actions we
// need to take for each HTLC, this method will return a set of contract
// resolvers that will resolve the contracts on-chain if needed.
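//
// As a rough summary of the switch below (not an exhaustive contract),
// each chain action maps to a resolver type:
//
//	HtlcClaimAction         -> successResolver
//	HtlcTimeoutAction       -> timeoutResolver
//	HtlcIncomingWatchAction -> incomingContestResolver
//	HtlcOutgoingWatchAction -> outgoingContestResolver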
func (c *ChannelArbitrator) prepContractResolutions(
contractResolutions *ContractResolutions, height uint32,
htlcActions ChainActionMap) ([]ContractResolver, error) {
// We'll also fetch the historical state of this channel, as it should
// have been marked as closed by now, and supplement it to each resolver
// such that we can properly resolve our pending contracts.
var chanState *channeldb.OpenChannel
chanState, err := c.cfg.FetchHistoricalChannel()
switch {
// If we don't find this channel, then it may be the case that it
// was closed before we started to retain the final state
// information for open channels.
case err == channeldb.ErrNoHistoricalBucket:
fallthrough
case err == channeldb.ErrChannelNotFound:
log.Warnf("ChannelArbitrator(%v): unable to fetch historical "+
"state", c.cfg.ChanPoint)
case err != nil:
return nil, err
}
incomingResolutions := contractResolutions.HtlcResolutions.IncomingHTLCs
outgoingResolutions := contractResolutions.HtlcResolutions.OutgoingHTLCs
// We'll use these two maps to quickly look up an active HTLC with its
// matching HTLC resolution.
outResolutionMap := make(map[wire.OutPoint]lnwallet.OutgoingHtlcResolution)
inResolutionMap := make(map[wire.OutPoint]lnwallet.IncomingHtlcResolution)
for i := 0; i < len(incomingResolutions); i++ {
inRes := incomingResolutions[i]
inResolutionMap[inRes.HtlcPoint()] = inRes
}
for i := 0; i < len(outgoingResolutions); i++ {
outRes := outgoingResolutions[i]
outResolutionMap[outRes.HtlcPoint()] = outRes
}
// We'll create the resolver kit that we'll be cloning for each
// resolver so they each can do their duty.
resolverCfg := ResolverConfig{
ChannelArbitratorConfig: c.cfg,
Checkpoint: func(res ContractResolver,
reports ...*channeldb.ResolverReport) error {
return c.log.InsertUnresolvedContracts(reports, res)
},
}
commitHash := contractResolutions.CommitHash
var htlcResolvers []ContractResolver
// We instantiate an anchor resolver if the commitment tx has an
// anchor.
if contractResolutions.AnchorResolution != nil {
anchorResolver := newAnchorResolver(
contractResolutions.AnchorResolution.AnchorSignDescriptor,
contractResolutions.AnchorResolution.CommitAnchor,
height, c.cfg.ChanPoint, resolverCfg,
)
anchorResolver.SupplementState(chanState)
htlcResolvers = append(htlcResolvers, anchorResolver)
}
// If this is a breach close, we'll create a breach resolver, determine
// the htlc's to fail back, and exit. This is done because the other
// steps taken for non-breach-closes do not matter for breach-closes.
if contractResolutions.BreachResolution != nil {
breachResolver := newBreachResolver(resolverCfg)
htlcResolvers = append(htlcResolvers, breachResolver)
return htlcResolvers, nil
}
// For each HTLC, we'll either act immediately, meaning we'll instantly
// fail the HTLC, or we'll act only once the transaction has been
// confirmed, in which case we'll need an HTLC resolver.
for htlcAction, htlcs := range htlcActions {
switch htlcAction {
// If we can claim this HTLC, we'll create an HTLC resolver to
// claim the HTLC (second-level or directly).
case HtlcClaimAction:
for _, htlc := range htlcs {
htlc := htlc
htlcOp := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
}
resolution, ok := inResolutionMap[htlcOp]
if !ok {
// TODO(roasbeef): panic?
log.Errorf("ChannelArbitrator(%v) unable to find "+
"incoming resolution: %v",
c.cfg.ChanPoint, htlcOp)
continue
}
resolver := newSuccessResolver(
resolution, height, htlc, resolverCfg,
)
if chanState != nil {
resolver.SupplementState(chanState)
}
htlcResolvers = append(htlcResolvers, resolver)
}
// If we can timeout the HTLC directly, then we'll create the
// proper resolver to do so, who will then cancel the packet
// backwards.
case HtlcTimeoutAction:
for _, htlc := range htlcs {
htlc := htlc
htlcOp := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
}
resolution, ok := outResolutionMap[htlcOp]
if !ok {
log.Errorf("ChannelArbitrator(%v) unable to find "+
"outgoing resolution: %v", c.cfg.ChanPoint, htlcOp)
continue
}
resolver := newTimeoutResolver(
resolution, height, htlc, resolverCfg,
)
if chanState != nil {
resolver.SupplementState(chanState)
}
// For outgoing HTLCs, we will also need to
// supplement the resolver with the expiry
// block height of its corresponding incoming
// HTLC.
deadline := c.cfg.FindOutgoingHTLCDeadline(htlc)
resolver.SupplementDeadline(deadline)
htlcResolvers = append(htlcResolvers, resolver)
}
// If this is an incoming HTLC, but we can't act yet, then
// we'll create an incoming resolver to redeem the HTLC if we
// learn of the pre-image, or let the remote party time out.
case HtlcIncomingWatchAction:
for _, htlc := range htlcs {
htlc := htlc
htlcOp := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
}
// TODO(roasbeef): need to handle incoming dust...
// TODO(roasbeef): can't be negative!!!
resolution, ok := inResolutionMap[htlcOp]
if !ok {
log.Errorf("ChannelArbitrator(%v) unable to find "+
"incoming resolution: %v",
c.cfg.ChanPoint, htlcOp)
continue
}
resolver := newIncomingContestResolver(
resolution, height, htlc,
resolverCfg,
)
if chanState != nil {
resolver.SupplementState(chanState)
}
htlcResolvers = append(htlcResolvers, resolver)
}
// Finally, if this is an outgoing HTLC we've sent, then we'll
// launch a resolver to watch for the pre-image (and settle
// backwards), or just timeout.
case HtlcOutgoingWatchAction:
for _, htlc := range htlcs {
htlc := htlc
htlcOp := wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex),
}
resolution, ok := outResolutionMap[htlcOp]
if !ok {
log.Errorf("ChannelArbitrator(%v) "+
"unable to find outgoing "+
"resolution: %v",
c.cfg.ChanPoint, htlcOp)
continue
}
resolver := newOutgoingContestResolver(
resolution, height, htlc, resolverCfg,
)
if chanState != nil {
resolver.SupplementState(chanState)
}
// For outgoing HTLCs, we will also need to
// supplement the resolver with the expiry
// block height of its corresponding incoming
// HTLC.
deadline := c.cfg.FindOutgoingHTLCDeadline(htlc)
resolver.SupplementDeadline(deadline)
htlcResolvers = append(htlcResolvers, resolver)
}
}
}
// If this was a unilateral closure, then we'll also create a
// resolver to sweep our commitment output (but only if it wasn't
// trimmed).
if contractResolutions.CommitResolution != nil {
resolver := newCommitSweepResolver(
*contractResolutions.CommitResolution, height,
c.cfg.ChanPoint, resolverCfg,
)
if chanState != nil {
resolver.SupplementState(chanState)
}
htlcResolvers = append(htlcResolvers, resolver)
}
return htlcResolvers, nil
}
// replaceResolver replaces a resolver in the list of active resolvers. If the
// resolver to be replaced is not found, it returns an error.
func (c *ChannelArbitrator) replaceResolver(oldResolver,
newResolver ContractResolver) error {
c.activeResolversLock.Lock()
defer c.activeResolversLock.Unlock()
oldKey := oldResolver.ResolverKey()
for i, r := range c.activeResolvers {
if bytes.Equal(r.ResolverKey(), oldKey) {
c.activeResolvers[i] = newResolver
return nil
}
}
return errors.New("resolver to be replaced not found")
}
// resolveContract is a goroutine tasked with fully resolving an unresolved
// contract. Either the initial contract will be resolved after a single step,
// or the contract will itself create another contract to be resolved. In
// either case, once the contract has been fully resolved, we'll signal back to
// the main goroutine so it can properly keep track of the set of unresolved
// contracts.
//
// NOTE: This MUST be run as a goroutine.
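//
// A typical (hypothetical) launch site, matching the waitgroup contract
// honored by the deferred c.wg.Done() below:
//
//	c.wg.Add(1)
//	go c.resolveContract(resolver)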
func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver) {
defer c.wg.Done()
log.Tracef("ChannelArbitrator(%v): attempting to resolve %T",
c.cfg.ChanPoint, currentContract)
// Until the contract is fully resolved, we'll continue to iteratively
// resolve the contract one step at a time.
for !currentContract.IsResolved() {
log.Tracef("ChannelArbitrator(%v): contract %T not yet "+
"resolved", c.cfg.ChanPoint, currentContract)
select {
// If we've been signalled to quit, then we'll exit early.
case <-c.quit:
return
default:
// Otherwise, we'll attempt to resolve the current
// contract.
nextContract, err := currentContract.Resolve()
if err != nil {
if err == errResolverShuttingDown {
return
}
log.Errorf("ChannelArbitrator(%v): unable to "+
"progress %T: %v",
c.cfg.ChanPoint, currentContract, err)
return
}
switch {
// If this contract produced another, then this means
// the current contract was only able to be partially
// resolved in this step. So we'll do a contract swap
// within our logs: the new contract will take the
// place of the old one.
case nextContract != nil:
log.Debugf("ChannelArbitrator(%v): swapping "+
"out contract %T for %T ",
c.cfg.ChanPoint, currentContract,
nextContract)
// Swap contract in log.
err := c.log.SwapContract(
currentContract, nextContract,
)
if err != nil {
log.Errorf("unable to add recurse "+
"contract: %v", err)
}
// Swap contract in resolvers list. This is to
// make sure that reports are queried from the
// new resolver.
err = c.replaceResolver(
currentContract, nextContract,
)
if err != nil {
log.Errorf("unable to replace "+
"contract: %v", err)
}
// As this contract produced another, we'll
// re-assign, so we can continue our resolution
// loop.
currentContract = nextContract
// Launch the new contract.
err = currentContract.Launch()
if err != nil {
log.Errorf("Failed to launch %T: %v",
currentContract, err)
}
// If this contract is actually fully resolved, then
// we'll mark it as such within the database.
case currentContract.IsResolved():
log.Debugf("ChannelArbitrator(%v): marking "+
"contract %T fully resolved",
c.cfg.ChanPoint, currentContract)
err := c.log.ResolveContract(currentContract)
if err != nil {
log.Errorf("unable to resolve contract: %v",
err)
}
// Now that the contract has been resolved,
// we'll signal to the main goroutine.
select {
case c.resolutionSignal <- struct{}{}:
case <-c.quit:
return
}
}
}
}
}
// signalUpdateMsg is a struct that carries fresh signals to the
// ChannelArbitrator. We need to receive a message like this each time the
// channel becomes active, as its internal state may change.
type signalUpdateMsg struct {
// newSignals is the set of new active signals to be sent to the
// arbitrator.
newSignals *ContractSignals
// doneChan is a channel that will be closed once the arbitrator has
// attached the new signals.
doneChan chan struct{}
}
// UpdateContractSignals updates the set of signals the ChannelArbitrator needs
// to receive from a channel in real-time in order to keep in sync with the
// latest state of the contract.
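//
// A hedged usage sketch (only ShortChanID is shown; ContractSignals may
// carry additional fields):
//
//	c.UpdateContractSignals(&ContractSignals{
//		ShortChanID: newShortChanID,
//	})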
func (c *ChannelArbitrator) UpdateContractSignals(newSignals *ContractSignals) {
done := make(chan struct{})
select {
case c.signalUpdates <- &signalUpdateMsg{
newSignals: newSignals,
doneChan: done,
}:
case <-c.quit:
}
select {
case <-done:
case <-c.quit:
}
}
// notifyContractUpdate updates the ChannelArbitrator's unmerged mappings such
// that it can later be merged with activeHTLCs when calling
// checkLocalChainActions or sweepAnchors. These are the only two places that
// activeHTLCs is used.
func (c *ChannelArbitrator) notifyContractUpdate(upd *ContractUpdate) {
c.unmergedMtx.Lock()
defer c.unmergedMtx.Unlock()
// Update the mapping.
c.unmergedSet[upd.HtlcKey] = newHtlcSet(upd.Htlcs)
log.Tracef("ChannelArbitrator(%v): fresh set of htlcs=%v",
c.cfg.ChanPoint, lnutils.SpewLogClosure(upd))
}
// updateActiveHTLCs merges the unmerged set of HTLCs from the link with
// activeHTLCs.
func (c *ChannelArbitrator) updateActiveHTLCs() {
c.unmergedMtx.RLock()
defer c.unmergedMtx.RUnlock()
// Update the mapping.
c.activeHTLCs[LocalHtlcSet] = c.unmergedSet[LocalHtlcSet]
c.activeHTLCs[RemoteHtlcSet] = c.unmergedSet[RemoteHtlcSet]
// If the pending set exists, update that as well.
if _, ok := c.unmergedSet[RemotePendingHtlcSet]; ok {
pendingSet := c.unmergedSet[RemotePendingHtlcSet]
c.activeHTLCs[RemotePendingHtlcSet] = pendingSet
}
}
// channelAttendant is the primary goroutine that acts as the judicial
// arbitrator between our channel state, the remote channel peer, and the
// blockchain (our judge). This goroutine will ensure that we faithfully execute
// all clauses of our contract in the case that we need to go on-chain for a
// dispute. Currently, two such conditions warrant our intervention: when an
// outgoing HTLC is about to timeout, and when we know the pre-image for an
// incoming HTLC, but it hasn't yet been settled off-chain. In these cases,
// we'll: broadcast our commitment, cancel/settle any HTLC's backwards after
// sufficient confirmation, and finally send our set of outputs to the UTXO
// Nursery for incubation, and ultimate sweeping.
//
// NOTE: This MUST be run as a goroutine.
func (c *ChannelArbitrator) channelAttendant(bestHeight int32,
commitSet *CommitSet) {
// TODO(roasbeef): tell top chain arb we're done
defer func() {
c.wg.Done()
}()
err := c.progressStateMachineAfterRestart(bestHeight, commitSet)
if err != nil {
// In case of an error, we return early but we do not shut down
// LND, because there might be other channels that can still be
// resolved and we don't want to interfere with that.
// We continue to run the channel attendant in case the channel
// closes via other means, for example if the remote party force
// closes the channel. So we log the error and continue.
log.Errorf("Unable to progress state machine after "+
"restart: %v", err)
}
for {
select {
// A new block has arrived, we'll examine all the active HTLC's
// to see if any of them have expired, and also update our
// tracking of the best current height.
case beat := <-c.BlockbeatChan:
bestHeight = beat.Height()
log.Debugf("ChannelArbitrator(%v): received new "+
"block: height=%v, processing...",
c.cfg.ChanPoint, bestHeight)
err := c.handleBlockbeat(beat)
if err != nil {
log.Errorf("Handle block=%v got err: %v",
bestHeight, err)
}
// If as a result of this trigger, the contract is
// fully resolved, then we'll exit.
if c.state == StateFullyResolved {
return
}
// A new signal update was just sent. This indicates that the
// channel under watch is now live, and may modify its internal
// state, so we'll get the most up to date signals so we can
// properly do our job.
case signalUpdate := <-c.signalUpdates:
log.Tracef("ChannelArbitrator(%v): got new signal "+
"update!", c.cfg.ChanPoint)
// We'll update the ShortChannelID.
c.cfg.ShortChanID = signalUpdate.newSignals.ShortChanID
// Now that the signal has been updated, we'll now
// close the done channel to signal to the caller we've
// registered the new ShortChannelID.
close(signalUpdate.doneChan)
// We've cooperatively closed the channel, so we're no longer
// needed. We'll mark the channel as resolved and exit.
case closeInfo := <-c.cfg.ChainEvents.CooperativeClosure:
err := c.handleCoopCloseEvent(closeInfo)
if err != nil {
log.Errorf("Failed to handle coop close: %v",
err)
return
}
// We have broadcast our commitment, and it is now confirmed
// on-chain.
case closeInfo := <-c.cfg.ChainEvents.LocalUnilateralClosure:
if c.state != StateCommitmentBroadcasted {
log.Errorf("ChannelArbitrator(%v): unexpected "+
"local on-chain channel close",
c.cfg.ChanPoint)
}
err := c.handleLocalForceCloseEvent(closeInfo)
if err != nil {
log.Errorf("Failed to handle local force "+
"close: %v", err)
return
}
// The remote party has broadcast the commitment on-chain.
// We'll examine our state to determine if we need to act at
// all.
case uniClosure := <-c.cfg.ChainEvents.RemoteUnilateralClosure:
err := c.handleRemoteForceCloseEvent(uniClosure)
if err != nil {
log.Errorf("Failed to handle remote force "+
"close: %v", err)
return
}
// The remote has breached the channel. As this is handled by
// the ChainWatcher and BreachArbitrator, we don't have to do
// anything in particular, so just advance our state and
// gracefully exit.
case breachInfo := <-c.cfg.ChainEvents.ContractBreach:
err := c.handleContractBreach(breachInfo)
if err != nil {
log.Errorf("Failed to handle contract breach: "+
"%v", err)
return
}
// A new contract has just been resolved, we'll now check our
// log to see if all contracts have been resolved. If so, then
// we can exit as the contract is fully resolved.
case <-c.resolutionSignal:
log.Infof("ChannelArbitrator(%v): a contract has been "+
"fully resolved!", c.cfg.ChanPoint)
nextState, _, err := c.advanceState(
uint32(bestHeight), chainTrigger, nil,
)
if err != nil {
log.Errorf("Unable to advance state: %v", err)
}
// If we don't have anything further to do after
// advancing our state, then we'll exit.
if nextState == StateFullyResolved {
log.Infof("ChannelArbitrator(%v): all "+
"contracts fully resolved, exiting",
c.cfg.ChanPoint)
return
}
// We've just received a request to forcibly close out the
// channel. We'll attempt to advance our state machine to
// broadcast the commitment, unless a close is already in
// progress.
case closeReq := <-c.forceCloseReqs:
log.Infof("ChannelArbitrator(%v): received force "+
"close request", c.cfg.ChanPoint)
if c.state != StateDefault {
select {
case closeReq.closeTx <- nil:
case <-c.quit:
}
select {
case closeReq.errResp <- errAlreadyForceClosed:
case <-c.quit:
}
continue
}
nextState, closeTx, err := c.advanceState(
uint32(bestHeight), userTrigger, nil,
)
if err != nil {
log.Errorf("Unable to advance state: %v", err)
}
select {
case closeReq.closeTx <- closeTx:
case <-c.quit:
return
}
select {
case closeReq.errResp <- err:
case <-c.quit:
return
}
// If we don't have anything further to do after
// advancing our state, then we'll exit.
if nextState == StateFullyResolved {
log.Infof("ChannelArbitrator(%v): all "+
"contracts resolved, exiting",
c.cfg.ChanPoint)
return
}
case <-c.quit:
return
}
}
}
// handleBlockbeat processes a newly received blockbeat by advancing the
// arbitrator's internal state using the received block height.
func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Blockbeat) error {
// Notify we've processed the block.
defer c.NotifyBlockProcessed(beat, nil)
// If the state is StateContractClosed, StateWaitingFullResolution, or
// StateFullyResolved, there's no need to read the close event channel
// since the arbitrator can only get to this state after processing a
// previous close event and launching all its resolvers.
if c.state.IsContractClosed() {
log.Infof("ChannelArbitrator(%v): skipping reading close "+
"events in state=%v", c.cfg.ChanPoint, c.state)
// Launch all active resolvers when a new blockbeat is
// received. Even when the contract is closed, we still need
// this, as the resolvers may transform into new ones. For
// already-launched resolvers this will be a NOOP as they track
// their own `launched` states.
c.launchResolvers()
return nil
}
// Perform a non-blocking read on the close events in case the channel
// is closed in this blockbeat.
c.receiveAndProcessCloseEvent()
// Try to advance the state if we are in StateDefault.
if c.state == StateDefault {
// Now that a new block has arrived, we'll attempt to advance
// our state forward.
_, _, err := c.advanceState(
uint32(beat.Height()), chainTrigger, nil,
)
if err != nil {
return fmt.Errorf("unable to advance state: %w", err)
}
}
// Launch all active resolvers when a new blockbeat is received.
c.launchResolvers()
return nil
}
// receiveAndProcessCloseEvent does a non-blocking read on all the channel
// close event channels. If an event is received, it will be further processed.
func (c *ChannelArbitrator) receiveAndProcessCloseEvent() {
select {
// Received a coop close event, we now mark the channel as resolved and
// exit.
case closeInfo := <-c.cfg.ChainEvents.CooperativeClosure:
err := c.handleCoopCloseEvent(closeInfo)
if err != nil {
log.Errorf("Failed to handle coop close: %v", err)
return
}
// We have broadcast our commitment, and it is now confirmed on-chain.
case closeInfo := <-c.cfg.ChainEvents.LocalUnilateralClosure:
if c.state != StateCommitmentBroadcasted {
log.Errorf("ChannelArbitrator(%v): unexpected "+
"local on-chain channel close", c.cfg.ChanPoint)
}
err := c.handleLocalForceCloseEvent(closeInfo)
if err != nil {
log.Errorf("Failed to handle local force close: %v",
err)
return
}
// The remote party has broadcast the commitment. We'll examine our
// state to determine if we need to act at all.
case uniClosure := <-c.cfg.ChainEvents.RemoteUnilateralClosure:
err := c.handleRemoteForceCloseEvent(uniClosure)
if err != nil {
log.Errorf("Failed to handle remote force close: %v",
err)
return
}
// The remote has breached the channel! We now launch the breach
// contract resolvers.
case breachInfo := <-c.cfg.ChainEvents.ContractBreach:
err := c.handleContractBreach(breachInfo)
if err != nil {
log.Errorf("Failed to handle contract breach: %v", err)
return
}
default:
log.Infof("ChannelArbitrator(%v) no close event",
c.cfg.ChanPoint)
}
}
// Name returns a human-readable string for this subsystem.
//
// NOTE: Part of chainio.Consumer interface.
func (c *ChannelArbitrator) Name() string {
return fmt.Sprintf("ChannelArbitrator(%v)", c.cfg.ChanPoint)
}
// checkLegacyBreach returns StateFullyResolved if the channel was closed with
// a breach transaction before the channel arbitrator launched its own breach
// resolver. StateContractClosed is returned if this is a modern breach close
// with a breach resolver. StateError is returned if the log lookup failed.
func (c *ChannelArbitrator) checkLegacyBreach() (ArbitratorState, error) {
// A previous version of the channel arbitrator would make the breach
// close skip to StateFullyResolved. If there are no contract
// resolutions in the bolt arbitrator log, then this is an older breach
// close. Otherwise, if there are resolutions, the state should advance
// to StateContractClosed.
_, err := c.log.FetchContractResolutions()
if err == errNoResolutions {
// This is an older breach close still in the database.
return StateFullyResolved, nil
} else if err != nil {
return StateError, err
}
// This is a modern breach close with resolvers.
return StateContractClosed, nil
}
// sweepRequest wraps the arguments used when calling `SweepInput`.
type sweepRequest struct {
// input is the input to be swept.
input input.Input
// params holds the sweeping parameters.
params sweep.Params
}
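// A sweepRequest is typically handed off to the sweeper once assembled. A
// hedged sketch, assuming the configured sweeper exposes a SweepInput
// method with this shape:
//
//	req, err := c.createSweepRequest(anchor, htlcs, "local", heightHint)
//	if err == nil {
//		_, _ = c.cfg.Sweeper.SweepInput(req.input, req.params)
//	}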
// createSweepRequest creates an anchor sweeping request for a particular
// version (local/remote/remote pending) of the commitment.
func (c *ChannelArbitrator) createSweepRequest(
anchor *lnwallet.AnchorResolution, htlcs htlcSet, anchorPath string,
heightHint uint32) (sweepRequest, error) {
// Use the chan id as the exclusive group. This prevents any of the
// anchors from being batched together.
exclusiveGroup := c.cfg.ShortChanID.ToUint64()
// Find the deadline for this specific anchor.
deadline, value, err := c.findCommitmentDeadlineAndValue(
heightHint, htlcs,
)
if err != nil {
return sweepRequest{}, err
}
// If we cannot find a deadline, it means there are no HTLCs at stake,
// which means we can relax our anchor sweeping conditions as we don't
// have any time-sensitive outputs to sweep. However, we need to
// register the anchor output with the sweeper so we are later able to
// bump the close fee.
if deadline.IsNone() {
log.Infof("ChannelArbitrator(%v): no HTLCs at stake, "+
"sweeping anchor with default deadline",
c.cfg.ChanPoint)
}
witnessType := input.CommitmentAnchor
// For taproot channels, we need to use the proper witness type.
if txscript.IsPayToTaproot(
anchor.AnchorSignDescriptor.Output.PkScript,
) {
witnessType = input.TaprootAnchorSweepSpend
}
// Prepare anchor output for sweeping.
anchorInput := input.MakeBaseInput(
&anchor.CommitAnchor,
witnessType,
&anchor.AnchorSignDescriptor,
heightHint,
&input.TxInfo{
Fee: anchor.CommitFee,
Weight: anchor.CommitWeight,
},
)
// If we have a deadline, we'll use it to calculate the deadline
// height, otherwise default to none.
deadlineDesc := "None"
deadlineHeight := fn.MapOption(func(d int32) int32 {
deadlineDesc = fmt.Sprintf("%d", d)
return d + int32(heightHint)
})(deadline)
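// NOTE: fn.MapOption only invokes the closure above when deadline is
// Some, so deadlineDesc remains "None" and deadlineHeight remains None
// when no deadline was found.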
// Calculate the budget based on the value under protection, which is
// the sum of all HTLCs on this commitment. The anchor output itself
// has a small output value of 330 sats, so we also include it in the
// budget to pay for the CPFP transaction.
budget := calculateBudget(
value, c.cfg.Budget.AnchorCPFPRatio, c.cfg.Budget.AnchorCPFP,
) + AnchorOutputValue
log.Infof("ChannelArbitrator(%v): offering anchor from %s commitment "+
"%v to sweeper with deadline=%v, budget=%v", c.cfg.ChanPoint,
anchorPath, anchor.CommitAnchor, deadlineDesc, budget)
// Sweep the anchor output with a confirmation target fee preference.
// Because this is a CPFP operation, the anchor sweep will only be
// attempted when the current fee estimate for the confirmation target
// exceeds the commit fee rate.
return sweepRequest{
input: &anchorInput,
params: sweep.Params{
ExclusiveGroup: &exclusiveGroup,
Budget: budget,
DeadlineHeight: deadlineHeight,
},
}, nil
}
// prepareAnchorSweeps creates a list of requests to be used by the sweeper for
// all possible commitment versions.
func (c *ChannelArbitrator) prepareAnchorSweeps(heightHint uint32,
anchors *lnwallet.AnchorResolutions) ([]sweepRequest, error) {
// requests holds all the possible anchor sweep requests. We can have
// up to 3 different versions of commitments (local/remote/remote
// dangling) to be CPFPed by the anchors.
requests := make([]sweepRequest, 0, 3)
// remotePendingReq holds the request for sweeping the anchor output on
// the remote pending commitment. It's only set when there's an actual
// pending remote commitment and it's used to decide whether we need to
// update the fee budget when sweeping the anchor output on the local
// commitment.
remotePendingReq := fn.None[sweepRequest]()
// First we check on the remote pending commitment and optionally
// create an anchor sweeping request.
htlcs, ok := c.activeHTLCs[RemotePendingHtlcSet]
if ok && anchors.RemotePending != nil {
req, err := c.createSweepRequest(
anchors.RemotePending, htlcs, "remote pending",
heightHint,
)
if err != nil {
return nil, err
}
// Save the request.
requests = append(requests, req)
// Set the optional variable.
remotePendingReq = fn.Some(req)
}
// Check the local commitment and optionally create an anchor sweeping
// request. The params used in this request will be influenced by the
// anchor sweeping request made from the pending remote commitment.
htlcs, ok = c.activeHTLCs[LocalHtlcSet]
if ok && anchors.Local != nil {
req, err := c.createSweepRequest(
anchors.Local, htlcs, "local", heightHint,
)
if err != nil {
return nil, err
}
// If there's an anchor sweeping request from the pending
// remote commitment, we will compare its budget against the
// budget used here and choose the params with the larger
// budget. The deadline when choosing the remote pending budget
// instead of the local one will always be earlier than or equal
// to the local deadline, because outgoing HTLCs are resolved on
// the local commitment first before they are removed from the
// remote one.
remotePendingReq.WhenSome(func(s sweepRequest) {
if s.params.Budget <= req.params.Budget {
return
}
log.Infof("ChannelArbitrator(%v): replaced local "+
"anchor(%v) sweep params with pending remote "+
"anchor sweep params, \nold:[%v], \nnew:[%v]",
c.cfg.ChanPoint, anchors.Local.CommitAnchor,
req.params, s.params)
req.params = s.params
})
// Save the request.
requests = append(requests, req)
}
// Check the remote commitment and create an anchor sweeping request if
// needed.
htlcs, ok = c.activeHTLCs[RemoteHtlcSet]
if ok && anchors.Remote != nil {
req, err := c.createSweepRequest(
anchors.Remote, htlcs, "remote", heightHint,
)
if err != nil {
return nil, err
}
requests = append(requests, req)
}
return requests, nil
}
// failIncomingDust resolves the incoming dust HTLCs because they do not have
// an output on the commitment transaction and cannot be resolved onchain. We
// mark them as failed here.
func (c *ChannelArbitrator) failIncomingDust(
incomingDustHTLCs []channeldb.HTLC) error {
for _, htlc := range incomingDustHTLCs {
if !htlc.Incoming || htlc.OutputIndex >= 0 {
return fmt.Errorf("htlc with index %v is not incoming "+
"dust", htlc.OutputIndex)
}
key := models.CircuitKey{
ChanID: c.cfg.ShortChanID,
HtlcID: htlc.HtlcIndex,
}
// Mark this dust htlc as final failed.
chainArbCfg := c.cfg.ChainArbitratorConfig
err := chainArbCfg.PutFinalHtlcOutcome(
key.ChanID, key.HtlcID, false,
)
if err != nil {
return err
}
// Send notification.
chainArbCfg.HtlcNotifier.NotifyFinalHtlcEvent(
key,
channeldb.FinalHtlcInfo{
Settled: false,
Offchain: false,
},
)
}
return nil
}
// abandonForwards cancels back the incoming HTLCs for their corresponding
// outgoing HTLCs. We use a set here to avoid sending duplicate failure messages
// for the same HTLC. This also needs to be done for locally initiated outgoing
// HTLCs, as they are special-cased in the switch.
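//
// A hedged call-site sketch, assuming fn.NewSet is available for set
// construction (htlc is a hypothetical channeldb.HTLC):
//
//	indexes := fn.NewSet[uint64](htlc.HtlcIndex)
//	if err := c.abandonForwards(indexes); err != nil {
//		log.Errorf("unable to abandon forwards: %v", err)
//	}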
func (c *ChannelArbitrator) abandonForwards(htlcs fn.Set[uint64]) error {
log.Debugf("ChannelArbitrator(%v): cancelling back %v incoming "+
"HTLC(s)", c.cfg.ChanPoint,
len(htlcs))
msgsToSend := make([]ResolutionMsg, 0, len(htlcs))
failureMsg := &lnwire.FailPermanentChannelFailure{}
for idx := range htlcs {
failMsg := ResolutionMsg{
SourceChan: c.cfg.ShortChanID,
HtlcIndex: idx,
Failure: failureMsg,
}
msgsToSend = append(msgsToSend, failMsg)
}
// Send the messages to the switch, if there are any.
if len(msgsToSend) == 0 {
return nil
}
log.Debugf("ChannelArbitrator(%v): sending resolution message=%v",
c.cfg.ChanPoint, lnutils.SpewLogClosure(msgsToSend))
err := c.cfg.DeliverResolutionMsg(msgsToSend...)
if err != nil {
log.Errorf("Unable to send resolution msges to switch: %v", err)
return err
}
return nil
}
// handleCoopCloseEvent takes a coop close event from ChainEvents, marks the
// channel as closed and advances the state.
func (c *ChannelArbitrator) handleCoopCloseEvent(
closeInfo *CooperativeCloseInfo) error {
log.Infof("ChannelArbitrator(%v) marking channel cooperatively closed "+
"at height %v", c.cfg.ChanPoint, closeInfo.CloseHeight)
err := c.cfg.MarkChannelClosed(
closeInfo.ChannelCloseSummary,
channeldb.ChanStatusCoopBroadcasted,
)
if err != nil {
return fmt.Errorf("unable to mark channel closed: %w", err)
}
// We'll now advance our state machine until it reaches a terminal
// state, and the channel is marked resolved.
_, _, err = c.advanceState(closeInfo.CloseHeight, coopCloseTrigger, nil)
if err != nil {
log.Errorf("Unable to advance state: %v", err)
}
return nil
}
// handleLocalForceCloseEvent takes a local force close event from ChainEvents,
// saves the contract resolutions to disk, marks the channel as closed and
// advances the state.
func (c *ChannelArbitrator) handleLocalForceCloseEvent(
closeInfo *LocalUnilateralCloseInfo) error {
closeTx := closeInfo.CloseTx
resolutions, err := closeInfo.ContractResolutions.
UnwrapOrErr(
fmt.Errorf("resolutions not found"),
)
if err != nil {
return fmt.Errorf("unable to get resolutions: %w", err)
}
// We make sure that the htlc resolutions are present,
// otherwise we would panic when dereferencing the pointer.
//
// TODO(ziggie): Refactor ContractResolutions to use
// options.
if resolutions.HtlcResolutions == nil {
return fmt.Errorf("htlc resolutions is nil")
}
log.Infof("ChannelArbitrator(%v): local force close tx=%v confirmed",
c.cfg.ChanPoint, closeTx.TxHash())
contractRes := &ContractResolutions{
CommitHash: closeTx.TxHash(),
CommitResolution: resolutions.CommitResolution,
HtlcResolutions: *resolutions.HtlcResolutions,
AnchorResolution: resolutions.AnchorResolution,
}
// When processing a unilateral close event, we'll transition to the
// ContractClosed state. We'll log out the set of resolutions such that
// they are available to fetch in that state, and we'll also write the
// commit set so we can reconstruct our chain actions on restart.
err = c.log.LogContractResolutions(contractRes)
if err != nil {
return fmt.Errorf("unable to write resolutions: %w", err)
}
err = c.log.InsertConfirmedCommitSet(&closeInfo.CommitSet)
if err != nil {
return fmt.Errorf("unable to write commit set: %w", err)
}
// After the set of resolutions are successfully logged, we can safely
// close the channel. After this succeeds we won't be getting chain
// events anymore, so we must make sure we can recover on restart after
// it is marked closed. If the next state transition fails, we'll start
// up in the prior state again, and we'll no longer be getting chain
// events. In this case we must manually re-trigger the state
// transition into StateContractClosed based on the close status of the
// channel.
err = c.cfg.MarkChannelClosed(
closeInfo.ChannelCloseSummary,
channeldb.ChanStatusLocalCloseInitiator,
)
if err != nil {
return fmt.Errorf("unable to mark channel closed: %w", err)
}
// We'll now advance our state machine until it reaches a terminal
// state.
_, _, err = c.advanceState(
uint32(closeInfo.SpendingHeight),
localCloseTrigger, &closeInfo.CommitSet,
)
if err != nil {
log.Errorf("Unable to advance state: %v", err)
}
return nil
}
// handleRemoteForceCloseEvent takes a remote force close event from
// ChainEvents, saves the contract resolutions to disk, marks the channel as
// closed and advances the state.
func (c *ChannelArbitrator) handleRemoteForceCloseEvent(
closeInfo *RemoteUnilateralCloseInfo) error {
log.Infof("ChannelArbitrator(%v): remote party has force closed "+
"channel at height %v", c.cfg.ChanPoint,
closeInfo.SpendingHeight)
// If we don't have a self output, and there are no active HTLC's, then
// we can immediately mark the contract as fully resolved and exit.
contractRes := &ContractResolutions{
CommitHash: *closeInfo.SpenderTxHash,
CommitResolution: closeInfo.CommitResolution,
HtlcResolutions: *closeInfo.HtlcResolutions,
AnchorResolution: closeInfo.AnchorResolution,
}
// When processing a unilateral close event, we'll transition to the
// ContractClosed state. We'll log out the set of resolutions such that
// they are available to fetch in that state, and we'll also write the
// commit set so we can reconstruct our chain actions on restart.
err := c.log.LogContractResolutions(contractRes)
if err != nil {
return fmt.Errorf("unable to write resolutions: %w", err)
}
err = c.log.InsertConfirmedCommitSet(&closeInfo.CommitSet)
if err != nil {
return fmt.Errorf("unable to write commit set: %w", err)
}
// After the set of resolutions are successfully logged, we can safely
// close the channel. After this succeeds we won't be getting chain
// events anymore, so we must make sure we can recover on restart after
// it is marked closed. If the next state transition fails, we'll start
// up in the prior state again, and we'll no longer be getting chain
// events. In this case we must manually re-trigger the state
// transition into StateContractClosed based on the close status of the
// channel.
closeSummary := &closeInfo.ChannelCloseSummary
err = c.cfg.MarkChannelClosed(
closeSummary,
channeldb.ChanStatusRemoteCloseInitiator,
)
if err != nil {
return fmt.Errorf("unable to mark channel closed: %w", err)
}
// We'll now advance our state machine until it reaches a terminal
// state.
_, _, err = c.advanceState(
uint32(closeInfo.SpendingHeight),
remoteCloseTrigger, &closeInfo.CommitSet,
)
if err != nil {
log.Errorf("Unable to advance state: %v", err)
}
return nil
}
// handleContractBreach takes a breach close event from ChainEvents, saves the
// contract resolutions to disk, marks the channel as closed and advances the
// state.
func (c *ChannelArbitrator) handleContractBreach(
breachInfo *BreachCloseInfo) error {
closeSummary := &breachInfo.CloseSummary
log.Infof("ChannelArbitrator(%v): remote party has breached channel "+
"at height %v!", c.cfg.ChanPoint, closeSummary.CloseHeight)
// In the breach case, we'll only have anchor and breach resolutions.
contractRes := &ContractResolutions{
CommitHash: breachInfo.CommitHash,
BreachResolution: breachInfo.BreachResolution,
AnchorResolution: breachInfo.AnchorResolution,
}
// We'll transition to the ContractClosed state and log the set of
// resolutions such that they can be turned into resolvers later on.
// We'll also insert the CommitSet of the latest set of commitments.
err := c.log.LogContractResolutions(contractRes)
if err != nil {
return fmt.Errorf("unable to write resolutions: %w", err)
}
err = c.log.InsertConfirmedCommitSet(&breachInfo.CommitSet)
if err != nil {
return fmt.Errorf("unable to write commit set: %w", err)
}
// The channel is finally marked pending closed here as the
// BreachArbitrator and channel arbitrator have persisted the relevant
// states.
err = c.cfg.MarkChannelClosed(
closeSummary, channeldb.ChanStatusRemoteCloseInitiator,
)
if err != nil {
return fmt.Errorf("unable to mark channel closed: %w", err)
}
log.Infof("Breached channel=%v marked pending-closed",
breachInfo.BreachResolution.FundingOutPoint)
// We'll advance our state machine until it reaches a terminal state.
_, _, err = c.advanceState(
closeSummary.CloseHeight, breachCloseTrigger,
&breachInfo.CommitSet,
)
if err != nil {
log.Errorf("Unable to advance state: %v", err)
}
return nil
}
package contractcourt
import (
"encoding/binary"
"fmt"
"io"
"sync"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/sweep"
)
// commitSweepResolver is a resolver that will attempt to sweep the commitment
// output paying to us (local channel balance). In the case that the local
// party (we) broadcasts their version of the commitment transaction, we have
// to wait before sweeping it, as it has a CSV delay. For the anchor channel
// type, even if the remote party broadcasts the commitment transaction, we
// have to wait one block after the commitment transaction is confirmed,
// because a CSV of 1 is put into the script of the UTXO representing the
// local balance. Additionally, if the channel is a channel lease, we have to
// wait for the CLTV to expire.
// https://docs.lightning.engineering/lightning-network-tools/pool/overview
type commitSweepResolver struct {
// localChanCfg is used to provide the resolver with the keys required
// to identify whether the commitment transaction was broadcast by the
// local or remote party.
localChanCfg channeldb.ChannelConfig
// commitResolution contains all data required to successfully sweep
// this HTLC on-chain.
commitResolution lnwallet.CommitOutputResolution
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
broadcastHeight uint32
// chanPoint is the channel point of the original contract.
chanPoint wire.OutPoint
// channelInitiator denotes whether the party responsible for resolving
// the contract initiated the channel.
channelInitiator bool
// leaseExpiry denotes the additional waiting period the contract must
// hold until it can be resolved. This waiting period is known as the
// expiration of a script-enforced leased channel and only applies to
// the channel initiator.
//
// NOTE: This value should only be set when the contract belongs to a
// leased channel.
leaseExpiry uint32
// chanType denotes the type of channel the contract belongs to.
chanType channeldb.ChannelType
// currentReport stores the current state of the resolver for reporting
// over the rpc interface.
currentReport ContractReport
// reportLock prevents concurrent access to the resolver report.
reportLock sync.Mutex
contractResolverKit
}
// newCommitSweepResolver instantiates a new direct commit output resolver.
func newCommitSweepResolver(res lnwallet.CommitOutputResolution,
broadcastHeight uint32, chanPoint wire.OutPoint,
resCfg ResolverConfig) *commitSweepResolver {
r := &commitSweepResolver{
contractResolverKit: *newContractResolverKit(resCfg),
commitResolution: res,
broadcastHeight: broadcastHeight,
chanPoint: chanPoint,
}
r.initLogger(fmt.Sprintf("%T(%v)", r, r.commitResolution.SelfOutPoint))
r.initReport()
return r
}
// ResolverKey returns an identifier which should be globally unique for this
// particular resolver within the chain the original contract resides within.
func (c *commitSweepResolver) ResolverKey() []byte {
key := newResolverID(c.commitResolution.SelfOutPoint)
return key[:]
}
// waitForSpend waits for the given outpoint to be spent, and returns the
// details of the spending tx.
func waitForSpend(op *wire.OutPoint, pkScript []byte, heightHint uint32,
notifier chainntnfs.ChainNotifier, quit <-chan struct{}) (
*chainntnfs.SpendDetail, error) {
spendNtfn, err := notifier.RegisterSpendNtfn(
op, pkScript, heightHint,
)
if err != nil {
return nil, err
}
select {
case spendDetail, ok := <-spendNtfn.Spend:
if !ok {
return nil, errResolverShuttingDown
}
return spendDetail, nil
case <-quit:
return nil, errResolverShuttingDown
}
}
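// A minimal usage sketch (hypothetical op, pkScript, heightHint and
// notifier; the quit channel is owned by the caller):
//
//	quit := make(chan struct{})
//	spend, err := waitForSpend(&op, pkScript, heightHint, notifier, quit)
//	if err != nil {
//		// errResolverShuttingDown is returned if quit closes first.
//		return err
//	}
//	log.Infof("spent by tx %v", spend.SpenderTxHash)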
// getCommitTxConfHeight waits for confirmation of the commitment tx and
// returns the confirmation height.
func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) {
txID := c.commitResolution.SelfOutPoint.Hash
signDesc := c.commitResolution.SelfOutputSignDesc
pkScript := signDesc.Output.PkScript
const confDepth = 1
confChan, err := c.Notifier.RegisterConfirmationsNtfn(
&txID, pkScript, confDepth, c.broadcastHeight,
)
if err != nil {
return 0, err
}
defer confChan.Cancel()
select {
case txConfirmation, ok := <-confChan.Confirmed:
if !ok {
return 0, fmt.Errorf("cannot get confirmation "+
"for commit tx %v", txID)
}
return txConfirmation.BlockHeight, nil
case <-c.quit:
return 0, errResolverShuttingDown
}
}
// Resolve instructs the contract resolver to resolve the output on-chain. Once
// the output has been *fully* resolved, the function should return immediately
// with a nil ContractResolver value for the first return value. In the case
// that the contract requires further resolution, then another resolver is
// returned.
//
// NOTE: This function MUST be run as a goroutine.
// TODO(yy): fix the funlen in the next PR.
//
//nolint:funlen
func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
// If we're already resolved, then we can exit early.
if c.IsResolved() {
c.log.Errorf("already resolved")
return nil, nil
}
var sweepTxID chainhash.Hash
// Sweeper is going to join this input with other inputs if possible
// and publish the sweep tx. When the sweep tx confirms, it signals us
// through the result channel with the outcome. Wait for this to
// happen.
outcome := channeldb.ResolverOutcomeClaimed
select {
case sweepResult := <-c.sweepResultChan:
switch sweepResult.Err {
// If the remote party was able to sweep this output it's
// likely what we sent was actually a revoked commitment.
// Report the error and continue to wrap up the contract.
case sweep.ErrRemoteSpend:
c.log.Warnf("local commitment output was swept by "+
"remote party via %v", sweepResult.Tx.TxHash())
outcome = channeldb.ResolverOutcomeUnclaimed
// No errors, therefore continue processing.
case nil:
c.log.Infof("local commitment output fully resolved by "+
"sweep tx: %v", sweepResult.Tx.TxHash())
// Unknown errors.
default:
c.log.Errorf("unable to sweep input: %v",
sweepResult.Err)
return nil, sweepResult.Err
}
sweepTxID = sweepResult.Tx.TxHash()
case <-c.quit:
return nil, errResolverShuttingDown
}
// Funds have been swept and balance is no longer in limbo.
c.reportLock.Lock()
if outcome == channeldb.ResolverOutcomeClaimed {
// We only record the balance as recovered if it actually came
// back to us.
c.currentReport.RecoveredBalance = c.currentReport.LimboBalance
}
c.currentReport.LimboBalance = 0
c.reportLock.Unlock()
report := c.currentReport.resolverReport(
&sweepTxID, channeldb.ResolverTypeCommit, outcome,
)
c.markResolved()
// Checkpoint the resolver with a closure that will write the outcome
// of the resolver and its sweep transaction to disk.
return nil, c.Checkpoint(c, report)
}
// Stop signals the resolver to cancel any current resolution processes, and
// suspend.
//
// NOTE: Part of the ContractResolver interface.
func (c *commitSweepResolver) Stop() {
c.log.Debugf("stopping...")
defer c.log.Debugf("stopped")
close(c.quit)
}
// SupplementState allows the user of a ContractResolver to supplement it with
// state required for the proper resolution of a contract.
//
// NOTE: Part of the ContractResolver interface.
func (c *commitSweepResolver) SupplementState(state *channeldb.OpenChannel) {
if state.ChanType.HasLeaseExpiration() {
c.leaseExpiry = state.ThawHeight
}
c.localChanCfg = state.LocalChanCfg
c.channelInitiator = state.IsInitiator
c.chanType = state.ChanType
}
// hasCLTV denotes whether the resolver must wait for an additional CLTV to
// expire before resolving the contract.
func (c *commitSweepResolver) hasCLTV() bool {
return c.channelInitiator && c.leaseExpiry > 0
}
// Encode writes an encoded version of the ContractResolver into the passed
// Writer.
//
// NOTE: Part of the ContractResolver interface.
func (c *commitSweepResolver) Encode(w io.Writer) error {
if err := encodeCommitResolution(w, &c.commitResolution); err != nil {
return err
}
if err := binary.Write(w, endian, c.IsResolved()); err != nil {
return err
}
if err := binary.Write(w, endian, c.broadcastHeight); err != nil {
return err
}
if _, err := w.Write(c.chanPoint.Hash[:]); err != nil {
return err
}
err := binary.Write(w, endian, c.chanPoint.Index)
if err != nil {
return err
}
// Previously a sweep tx was serialized at this point. Refactoring
// removed this, but keep in mind that this data may still be present in
// the database.
return nil
}
// newCommitSweepResolverFromReader attempts to decode an encoded
// ContractResolver from the passed Reader instance, returning an active
// ContractResolver instance.
func newCommitSweepResolverFromReader(r io.Reader, resCfg ResolverConfig) (
*commitSweepResolver, error) {
c := &commitSweepResolver{
contractResolverKit: *newContractResolverKit(resCfg),
}
if err := decodeCommitResolution(r, &c.commitResolution); err != nil {
return nil, err
}
var resolved bool
if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
if resolved {
c.markResolved()
}
if err := binary.Read(r, endian, &c.broadcastHeight); err != nil {
return nil, err
}
_, err := io.ReadFull(r, c.chanPoint.Hash[:])
if err != nil {
return nil, err
}
err = binary.Read(r, endian, &c.chanPoint.Index)
if err != nil {
return nil, err
}
// Previously a sweep tx was deserialized at this point. Refactoring
// removed this, but keep in mind that this data may still be present in
// the database.
c.initLogger(fmt.Sprintf("%T(%v)", c, c.commitResolution.SelfOutPoint))
c.initReport()
return c, nil
}
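// A minimal round-trip sketch (assuming a populated resolver c and a valid
// resCfg; error handling elided):
//
//	var buf bytes.Buffer
//	if err := c.Encode(&buf); err != nil {
//		return err
//	}
//	restored, err := newCommitSweepResolverFromReader(&buf, resCfg)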
// report returns a report on the resolution state of the contract.
func (c *commitSweepResolver) report() *ContractReport {
c.reportLock.Lock()
defer c.reportLock.Unlock()
cpy := c.currentReport
return &cpy
}
// initReport initializes the pending channels report for this resolver.
func (c *commitSweepResolver) initReport() {
amt := btcutil.Amount(
c.commitResolution.SelfOutputSignDesc.Output.Value,
)
// Set the initial report. All fields are filled in, except for the
// maturity height which remains 0 until Resolve() is executed.
//
// TODO(joostjager): Resolvers only activate after the commit tx
// confirms. With more refactoring in channel arbitrator, it would be
// possible to make the confirmation height part of ResolverConfig and
// populate MaturityHeight here.
c.currentReport = ContractReport{
Outpoint: c.commitResolution.SelfOutPoint,
Type: ReportOutputUnencumbered,
Amount: amt,
LimboBalance: amt,
RecoveredBalance: 0,
}
}
// A compile time assertion to ensure commitSweepResolver meets the
// ContractResolver interface.
var _ reportingContractResolver = (*commitSweepResolver)(nil)
// Launch constructs a commit input and offers it to the sweeper.
func (c *commitSweepResolver) Launch() error {
if c.isLaunched() {
c.log.Tracef("already launched")
return nil
}
c.log.Debugf("launching resolver...")
c.markLaunched()
// If we're already resolved, then we can exit early.
if c.IsResolved() {
c.log.Errorf("already resolved")
return nil
}
confHeight, err := c.getCommitTxConfHeight()
if err != nil {
return err
}
	// Wait until the CSV expires, unless we also have a CLTV that
// expires after.
unlockHeight := confHeight + c.commitResolution.MaturityDelay
if c.hasCLTV() {
unlockHeight = max(unlockHeight, c.leaseExpiry)
}
// Update report now that we learned the confirmation height.
c.reportLock.Lock()
c.currentReport.MaturityHeight = unlockHeight
c.reportLock.Unlock()
// Derive the witness type for this input.
witnessType, err := c.decideWitnessType()
if err != nil {
return err
}
// We'll craft an input with all the information required for the
// sweeper to create a fully valid sweeping transaction to recover
// these coins.
var inp *input.BaseInput
if c.hasCLTV() {
inp = input.NewCsvInputWithCltv(
&c.commitResolution.SelfOutPoint, witnessType,
&c.commitResolution.SelfOutputSignDesc,
c.broadcastHeight, c.commitResolution.MaturityDelay,
c.leaseExpiry, input.WithResolutionBlob(
c.commitResolution.ResolutionBlob,
),
)
} else {
inp = input.NewCsvInput(
&c.commitResolution.SelfOutPoint, witnessType,
&c.commitResolution.SelfOutputSignDesc,
c.broadcastHeight, c.commitResolution.MaturityDelay,
input.WithResolutionBlob(
c.commitResolution.ResolutionBlob,
),
)
}
// TODO(roasbeef): instead of adding ctrl block to the sign desc, make
// new input type, have sweeper set it?
	// Calculate the budget for sweeping this input.
budget := calculateBudget(
btcutil.Amount(inp.SignDesc().Output.Value),
c.Budget.ToLocalRatio, c.Budget.ToLocal,
)
c.log.Infof("sweeping commit output %v using budget=%v", witnessType,
budget)
// With our input constructed, we'll now offer it to the sweeper.
resultChan, err := c.Sweeper.SweepInput(
inp, sweep.Params{
Budget: budget,
// Specify a nil deadline here as there's no time
// pressure.
DeadlineHeight: fn.None[int32](),
},
)
if err != nil {
c.log.Errorf("unable to sweep input: %v", err)
return err
}
c.sweepResultChan = resultChan
return nil
}
// decideWitnessType returns the witness type for the input.
func (c *commitSweepResolver) decideWitnessType() (input.WitnessType, error) {
var (
isLocalCommitTx bool
signDesc = c.commitResolution.SelfOutputSignDesc
)
switch {
	// For taproot channels, we can't inspect the witness script, so
	// we'll instead determine whether this is the local commit based on
	// which base point the signing key matches. Note that on remote
	// commitment transactions, the output script has a timelock of 1.
case c.chanType.IsTaproot():
delayKey := c.localChanCfg.DelayBasePoint.PubKey
nonDelayKey := c.localChanCfg.PaymentBasePoint.PubKey
signKey := c.commitResolution.SelfOutputSignDesc.KeyDesc.PubKey
// If the key in the script is neither of these, we shouldn't
// proceed. This should be impossible.
if !signKey.IsEqual(delayKey) && !signKey.IsEqual(nonDelayKey) {
return nil, fmt.Errorf("unknown sign key %v", signKey)
}
// The commitment transaction is ours iff the signing key is
// the delay key.
isLocalCommitTx = signKey.IsEqual(delayKey)
// The output is on our local commitment if the script starts with
// OP_IF for the revocation clause. On the remote commitment it will
// either be a regular P2WKH or a simple sig spend with a CSV delay.
default:
isLocalCommitTx = signDesc.WitnessScript[0] == txscript.OP_IF
}
isDelayedOutput := c.commitResolution.MaturityDelay != 0
c.log.Debugf("isDelayedOutput=%v, isLocalCommitTx=%v", isDelayedOutput,
isLocalCommitTx)
	// There are three types of commitments: those that have tweaks for the
// remote key (us in this case), those that don't, and a third where
// there is no tweak and the output is delayed. On the local commitment
// our output will always be delayed. We'll rely on the presence of the
// commitment tweak to discern which type of commitment this is.
var witnessType input.WitnessType
switch {
// The local delayed output for a taproot channel.
case isLocalCommitTx && c.chanType.IsTaproot():
witnessType = input.TaprootLocalCommitSpend
// The CSV 1 delayed output for a taproot channel.
case !isLocalCommitTx && c.chanType.IsTaproot():
witnessType = input.TaprootRemoteCommitSpend
// Delayed output to us on our local commitment for a channel lease in
// which we are the initiator.
case isLocalCommitTx && c.hasCLTV():
witnessType = input.LeaseCommitmentTimeLock
// Delayed output to us on our local commitment.
case isLocalCommitTx:
witnessType = input.CommitmentTimeLock
// A confirmed output to us on the remote commitment for a channel lease
// in which we are the initiator.
case isDelayedOutput && c.hasCLTV():
witnessType = input.LeaseCommitmentToRemoteConfirmed
// A confirmed output to us on the remote commitment.
case isDelayedOutput:
witnessType = input.CommitmentToRemoteConfirmed
// A non-delayed output on the remote commitment where the key is
// tweakless.
case c.commitResolution.SelfOutputSignDesc.SingleTweak == nil:
witnessType = input.CommitSpendNoDelayTweakless
// A non-delayed output on the remote commitment where the key is
// tweaked.
default:
witnessType = input.CommitmentNoDelay
}
return witnessType, nil
}
package contractcourt
import (
"fmt"
"github.com/btcsuite/btcd/btcutil"
)
const (
// MinBudgetValue is the minimal budget that we allow when configuring
// the budget used in sweeping outputs. The actual budget can be lower
// if the user decides to NOT set this value.
//
// NOTE: This value is chosen so the linear fee function can increase
// at least 1 sat/kw per block.
MinBudgetValue btcutil.Amount = 1008
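	//
	// For example, assuming a sweeping deadline of 1008 blocks (roughly
	// one week), a 1008 sat budget allows a linear fee function to step
	// up by 1 sat/kw on each new block until the deadline is hit.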
// MinBudgetRatio is the minimal ratio that we allow when configuring
// the budget ratio used in sweeping outputs.
MinBudgetRatio = 0.001
// DefaultBudgetRatio defines a default budget ratio to be used when
// sweeping inputs. This is a large value, which is fine as the final
// fee rate is capped at the max fee rate configured.
DefaultBudgetRatio = 0.5
)
// BudgetConfig is a struct that holds the configuration when offering outputs
// to the sweeper.
//
//nolint:ll
type BudgetConfig struct {
ToLocal btcutil.Amount `long:"tolocal" description:"The amount in satoshis to allocate as the budget to pay fees when sweeping the to_local output. If set, the budget calculated using the ratio (if set) will be capped at this value."`
ToLocalRatio float64 `long:"tolocalratio" description:"The ratio of the value in to_local output to allocate as the budget to pay fees when sweeping it."`
AnchorCPFP btcutil.Amount `long:"anchorcpfp" description:"The amount in satoshis to allocate as the budget to pay fees when CPFPing a force close tx using the anchor output. If set, the budget calculated using the ratio (if set) will be capped at this value."`
AnchorCPFPRatio float64 `long:"anchorcpfpratio" description:"The ratio of a special value to allocate as the budget to pay fees when CPFPing a force close tx using the anchor output. The special value is the sum of all time-sensitive HTLCs on this commitment subtracted by their budgets."`
DeadlineHTLC btcutil.Amount `long:"deadlinehtlc" description:"The amount in satoshis to allocate as the budget to pay fees when sweeping a time-sensitive (first-level) HTLC. If set, the budget calculated using the ratio (if set) will be capped at this value."`
DeadlineHTLCRatio float64 `long:"deadlinehtlcratio" description:"The ratio of the value in a time-sensitive (first-level) HTLC to allocate as the budget to pay fees when sweeping it."`
NoDeadlineHTLC btcutil.Amount `long:"nodeadlinehtlc" description:"The amount in satoshis to allocate as the budget to pay fees when sweeping a non-time-sensitive (second-level) HTLC. If set, the budget calculated using the ratio (if set) will be capped at this value."`
NoDeadlineHTLCRatio float64 `long:"nodeadlinehtlcratio" description:"The ratio of the value in a non-time-sensitive (second-level) HTLC to allocate as the budget to pay fees when sweeping it."`
}
// Validate checks the budget configuration for any invalid values.
func (b *BudgetConfig) Validate() error {
// Exit early if no budget config is set.
if b == nil {
return fmt.Errorf("no budget config set")
}
// Sanity check all fields.
if b.ToLocal != 0 && b.ToLocal < MinBudgetValue {
return fmt.Errorf("tolocal must be at least %v",
MinBudgetValue)
}
if b.ToLocalRatio != 0 && b.ToLocalRatio < MinBudgetRatio {
return fmt.Errorf("tolocalratio must be at least %v",
MinBudgetRatio)
}
if b.AnchorCPFP != 0 && b.AnchorCPFP < MinBudgetValue {
return fmt.Errorf("anchorcpfp must be at least %v",
MinBudgetValue)
}
if b.AnchorCPFPRatio != 0 && b.AnchorCPFPRatio < MinBudgetRatio {
return fmt.Errorf("anchorcpfpratio must be at least %v",
MinBudgetRatio)
}
if b.DeadlineHTLC != 0 && b.DeadlineHTLC < MinBudgetValue {
return fmt.Errorf("deadlinehtlc must be at least %v",
MinBudgetValue)
}
if b.DeadlineHTLCRatio != 0 && b.DeadlineHTLCRatio < MinBudgetRatio {
return fmt.Errorf("deadlinehtlcratio must be at least %v",
MinBudgetRatio)
}
if b.NoDeadlineHTLC != 0 && b.NoDeadlineHTLC < MinBudgetValue {
return fmt.Errorf("nodeadlinehtlc must be at least %v",
MinBudgetValue)
}
if b.NoDeadlineHTLCRatio != 0 &&
b.NoDeadlineHTLCRatio < MinBudgetRatio {
return fmt.Errorf("nodeadlinehtlcratio must be at least %v",
MinBudgetRatio)
}
return nil
}
// String returns a human-readable description of the budget configuration.
func (b *BudgetConfig) String() string {
return fmt.Sprintf("tolocal=%v tolocalratio=%v anchorcpfp=%v "+
"anchorcpfpratio=%v deadlinehtlc=%v deadlinehtlcratio=%v "+
"nodeadlinehtlc=%v nodeadlinehtlcratio=%v",
b.ToLocal, b.ToLocalRatio, b.AnchorCPFP, b.AnchorCPFPRatio,
b.DeadlineHTLC, b.DeadlineHTLCRatio, b.NoDeadlineHTLC,
b.NoDeadlineHTLCRatio)
}
// DefaultBudgetConfig returns the default budget configuration.
func DefaultBudgetConfig() *BudgetConfig {
return &BudgetConfig{
ToLocalRatio: DefaultBudgetRatio,
AnchorCPFPRatio: DefaultBudgetRatio,
DeadlineHTLCRatio: DefaultBudgetRatio,
NoDeadlineHTLCRatio: DefaultBudgetRatio,
}
}
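// An illustrative usage sketch (hypothetical caller code): start from the
// defaults, override a single budget value, and validate the result before
// using it:
//
//	cfg := DefaultBudgetConfig()
//	cfg.ToLocal = 5_000 // cap the to_local sweeping budget at 5k sats
//	if err := cfg.Validate(); err != nil {
//		// Values below MinBudgetValue or ratios below MinBudgetRatio
//		// are rejected here.
//	}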
// calculateBudget takes an output value, a configured ratio and budget value,
// and returns the budget to use for sweeping the output. If the budget
// value is set, it will be used as a cap.
func calculateBudget(value btcutil.Amount, ratio float64,
max btcutil.Amount) btcutil.Amount {
	// If the ratio is not set, use the default value.
if ratio == 0 {
ratio = DefaultBudgetRatio
}
budget := value.MulF64(ratio)
log.Tracef("Calculated budget=%v using value=%v, ratio=%v, cap=%v",
budget, value, ratio, max)
if max != 0 && budget > max {
log.Debugf("Calculated budget=%v is capped at %v", budget, max)
return max
}
return budget
}
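// For instance (illustrative arithmetic, not executed anywhere in this
// package): sweeping a 100_000 sat output with the default ratio of 0.5 and
// no cap yields a budget of 50_000 sats, while a cap of 10_000 sats clamps
// the same input's budget to 10_000 sats:
//
//	budget := calculateBudget(100_000, 0, 0)      // 50_000 sats
//	capped := calculateBudget(100_000, 0, 10_000) // 10_000 sats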
package contractcourt
import (
"encoding/binary"
"errors"
"fmt"
"io"
"sync/atomic"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/sweep"
)
var (
endian = binary.BigEndian
)
const (
// sweepConfTarget is the default number of blocks that we'll use as a
// confirmation target when sweeping.
sweepConfTarget = 6
)
// ContractResolver is an interface which packages a state machine which is
// able to carry out the necessary steps required to fully resolve a Bitcoin
// contract on-chain. Resolvers are fully encodable to ensure callers are able
// to persist them properly. A resolver may produce another resolver in the
// case that claiming an HTLC is a multi-stage process. In this case, we may
// partially resolve the contract, then persist, and set up for an additional
// resolution.
type ContractResolver interface {
// ResolverKey returns an identifier which should be globally unique
// for this particular resolver within the chain the original contract
// resides within.
ResolverKey() []byte
// Launch starts the resolver by constructing an input and offering it
// to the sweeper. Once offered, it's expected to monitor the sweeping
// result in a goroutine invoked by calling Resolve.
//
// NOTE: We can call `Resolve` inside a goroutine at the end of this
// method to avoid calling it in the ChannelArbitrator. However, there
	// are some DB-related operations such as SwapContract/ResolveContract
	// which need to be done inside the resolvers instead, and that
	// requires a deeper refactoring.
Launch() error
// Resolve instructs the contract resolver to resolve the output
// on-chain. Once the output has been *fully* resolved, the function
// should return immediately with a nil ContractResolver value for the
	// first return value. In the case that the contract requires further
	// resolution, then another resolver is returned.
//
// NOTE: This function MUST be run as a goroutine.
Resolve() (ContractResolver, error)
// SupplementState allows the user of a ContractResolver to supplement
// it with state required for the proper resolution of a contract.
SupplementState(*channeldb.OpenChannel)
	// IsResolved returns true if the stored state in the resolver is fully
// resolved. In this case the target output can be forgotten.
IsResolved() bool
// Encode writes an encoded version of the ContractResolver into the
// passed Writer.
Encode(w io.Writer) error
// Stop signals the resolver to cancel any current resolution
// processes, and suspend.
Stop()
}
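// The lifecycle above could be driven roughly as follows (an illustrative
// sketch only; in lnd the actual driver is the ChannelArbitrator, and the
// checkpointing and error handling are elided here, with launchAndResolve
// being a hypothetical helper):
//
//	if err := resolver.Launch(); err != nil {
//		return err
//	}
//	go func() {
//		next, err := resolver.Resolve()
//		if err != nil || next == nil {
//			return
//		}
//		// The contract needs another resolution stage.
//		launchAndResolve(next)
//	}()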
// htlcContractResolver is the required interface for htlc resolvers.
type htlcContractResolver interface {
ContractResolver
// HtlcPoint returns the htlc's outpoint on the commitment tx.
HtlcPoint() wire.OutPoint
// Supplement adds additional information to the resolver that is
// required before Resolve() is called.
Supplement(htlc channeldb.HTLC)
// SupplementDeadline gives the deadline height for the HTLC output.
// This is only useful for outgoing HTLCs.
SupplementDeadline(deadlineHeight fn.Option[int32])
}
// reportingContractResolver is a ContractResolver that also exposes a report on
// the resolution state of the contract.
type reportingContractResolver interface {
ContractResolver
report() *ContractReport
}
// ResolverConfig contains the externally supplied configuration items that are
// required by a ContractResolver implementation.
type ResolverConfig struct {
// ChannelArbitratorConfig contains all the interfaces and closures
// required for the resolver to interact with outside sub-systems.
ChannelArbitratorConfig
// Checkpoint allows a resolver to check point its state. This function
	// should write the state of the resolver to persistent storage, and
	// return a nil error upon success. It takes a resolver report,
// which contains information about the outcome and should be written
// to disk if non-nil.
Checkpoint func(ContractResolver, ...*channeldb.ResolverReport) error
}
// contractResolverKit is meant to be used as a mix-in struct to be
// embedded within a given ContractResolver implementation. It contains
// all the common items that a resolver requires to carry out its duties.
type contractResolverKit struct {
ResolverConfig
log btclog.Logger
quit chan struct{}
// sweepResultChan is the result chan returned from calling
// `SweepInput`. It should be mounted to the specific resolver once the
// input has been offered to the sweeper.
sweepResultChan chan sweep.Result
// launched specifies whether the resolver has been launched. Calling
// `Launch` will be a no-op if this is true. This value is not saved to
// db, as it's fine to relaunch a resolver after a restart. It's only
// used to avoid resending requests to the sweeper when a new blockbeat
// is received.
launched atomic.Bool
// resolved reflects if the contract has been fully resolved or not.
resolved atomic.Bool
}
// newContractResolverKit instantiates the mix-in struct.
func newContractResolverKit(cfg ResolverConfig) *contractResolverKit {
return &contractResolverKit{
ResolverConfig: cfg,
quit: make(chan struct{}),
}
}
// initLogger initializes the resolver-specific logger.
func (r *contractResolverKit) initLogger(prefix string) {
logPrefix := fmt.Sprintf("ChannelArbitrator(%v): %s:", r.ChanPoint,
prefix)
r.log = log.WithPrefix(logPrefix)
}
// IsResolved returns true if the stored state in the resolver is fully
// resolved. In this case the target output can be forgotten.
//
// NOTE: Part of the ContractResolver interface.
func (r *contractResolverKit) IsResolved() bool {
return r.resolved.Load()
}
// markResolved marks the resolver as resolved.
func (r *contractResolverKit) markResolved() {
r.resolved.Store(true)
}
// isLaunched returns true if the resolver has been launched.
func (r *contractResolverKit) isLaunched() bool {
return r.launched.Load()
}
// markLaunched marks the resolver as launched.
func (r *contractResolverKit) markLaunched() {
r.launched.Store(true)
}
var (
// errResolverShuttingDown is returned when the resolver stops
// progressing because it received the quit signal.
errResolverShuttingDown = errors.New("resolver shutting down")
)
package contractcourt
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/queue"
)
// htlcIncomingContestResolver is a ContractResolver that's able to resolve an
// incoming HTLC that is still contested. An HTLC is still contested if, at
// the time of commitment broadcast, we don't yet know the preimage for it
// and it hasn't expired. In this case, we can resolve the HTLC if we learn
// of the preimage, otherwise the remote party will sweep it after it expires.
//
// TODO(roasbeef): just embed the other resolver?
type htlcIncomingContestResolver struct {
	// htlcExpiry is the absolute expiry of this incoming HTLC. We use
	// this value to determine if we can exit early: if the HTLC times
	// out before we learn of the preimage, we can't claim it on-chain
	// successfully.
htlcExpiry uint32
// htlcSuccessResolver is the inner resolver that may be utilized if we
// learn of the preimage.
*htlcSuccessResolver
}
// newIncomingContestResolver instantiates a new incoming htlc contest resolver.
func newIncomingContestResolver(
res lnwallet.IncomingHtlcResolution, broadcastHeight uint32,
htlc channeldb.HTLC, resCfg ResolverConfig) *htlcIncomingContestResolver {
success := newSuccessResolver(
res, broadcastHeight, htlc, resCfg,
)
return &htlcIncomingContestResolver{
htlcExpiry: htlc.RefundTimeout,
htlcSuccessResolver: success,
}
}
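// processFinalHtlcFail marks the htlc as failed on disk with a final outcome
// and notifies the HtlcNotifier about the final resolution.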
func (h *htlcIncomingContestResolver) processFinalHtlcFail() error {
// Mark the htlc as final failed.
err := h.ChainArbitratorConfig.PutFinalHtlcOutcome(
h.ChannelArbitratorConfig.ShortChanID, h.htlc.HtlcIndex, false,
)
if err != nil {
return err
}
// Send notification.
h.ChainArbitratorConfig.HtlcNotifier.NotifyFinalHtlcEvent(
models.CircuitKey{
ChanID: h.ShortChanID,
HtlcID: h.htlc.HtlcIndex,
},
channeldb.FinalHtlcInfo{
Settled: false,
Offchain: false,
},
)
return nil
}
// Launch will call the inner resolver's launch method if the preimage can be
// found, otherwise it's a no-op.
func (h *htlcIncomingContestResolver) Launch() error {
// NOTE: we don't mark this resolver as launched as the inner resolver
// will set it when it's launched.
if h.isLaunched() {
h.log.Tracef("already launched")
return nil
}
h.log.Debugf("launching contest resolver...")
// Query the preimage and apply it if we already know it.
applied, err := h.findAndapplyPreimage()
if err != nil {
return err
}
// No preimage found, leave it to be handled by the resolver.
if !applied {
return nil
}
h.log.Debugf("found preimage for htlc=%x, transforming into success "+
"resolver and launching it", h.htlc.RHash)
// Once we've applied the preimage, we'll launch the inner resolver to
// attempt to claim the HTLC.
return h.htlcSuccessResolver.Launch()
}
// Resolve attempts to resolve this contract. As we don't yet know of the
// preimage for the contract, we'll wait for one of two things to happen:
//
// 1. We learn of the preimage! In this case, we can sweep the HTLC incoming
// and ensure that if this was a multi-hop HTLC we are made whole. In this
// case, an additional ContractResolver will be returned to finish the
// job.
//
// 2. The HTLC expires. If this happens, then the contract is fully resolved
// as we have no remaining actions left at our disposal.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
	// If we're already fully resolved, then we don't have anything further
// to do.
if h.IsResolved() {
h.log.Errorf("already resolved")
return nil, nil
}
// First try to parse the payload. If that fails, we can stop resolution
// now.
payload, nextHopOnionBlob, err := h.decodePayload()
if err != nil {
h.log.Debugf("cannot decode payload of htlc %v", h.HtlcPoint())
// If we've locked in an htlc with an invalid payload on our
// commitment tx, we don't need to resolve it. The other party
// will time it out and get their funds back. This situation
// can present itself when we crash before processRemoteAdds in
		// the link has run.
h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
}
// We write a report to disk that indicates we could not decode
// the htlc.
resReport := h.report().resolverReport(
nil, channeldb.ResolverTypeIncomingHtlc,
channeldb.ResolverOutcomeAbandoned,
)
return nil, h.PutResolverReport(nil, resReport)
}
// Register for block epochs. After registration, the current height
// will be sent on the channel immediately.
blockEpochs, err := h.Notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return nil, err
}
defer blockEpochs.Cancel()
var currentHeight int32
select {
case newBlock, ok := <-blockEpochs.Epochs:
if !ok {
return nil, errResolverShuttingDown
}
currentHeight = newBlock.Height
case <-h.quit:
return nil, errResolverShuttingDown
}
log.Debugf("%T(%v): Resolving incoming HTLC(expiry=%v, height=%v)", h,
h.htlcResolution.ClaimOutpoint, h.htlcExpiry, currentHeight)
// We'll first check if this HTLC has been timed out, if so, we can
// return now and mark ourselves as resolved. If we're past the point of
// expiry of the HTLC, then at this point the sender can sweep it, so
// we'll end our lifetime. Here we deliberately forego the chance that
// the sender doesn't sweep and we already have or will learn the
// preimage. Otherwise the resolver could potentially stay active
// indefinitely and the channel will never close properly.
if uint32(currentHeight) >= h.htlcExpiry {
// TODO(roasbeef): should also somehow check if outgoing is
// resolved or not
// * may need to hook into the circuit map
// * can't timeout before the outgoing has been
log.Infof("%T(%v): HTLC has timed out (expiry=%v, height=%v), "+
"abandoning", h, h.htlcResolution.ClaimOutpoint,
h.htlcExpiry, currentHeight)
h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
}
// Finally, get our report and checkpoint our resolver with a
// timeout outcome report.
report := h.report().resolverReport(
nil, channeldb.ResolverTypeIncomingHtlc,
channeldb.ResolverOutcomeTimeout,
)
return nil, h.Checkpoint(h, report)
}
// Define a closure to process htlc resolutions either directly or
// triggered by future notifications.
processHtlcResolution := func(e invoices.HtlcResolution) (
ContractResolver, error) {
// Take action based on the type of resolution we have
// received.
switch resolution := e.(type) {
// If the htlc resolution was a settle, apply the
// preimage and return a success resolver.
case *invoices.HtlcSettleResolution:
err := h.applyPreimage(resolution.Preimage)
if err != nil {
return nil, err
}
return h.htlcSuccessResolver, nil
// If the htlc was failed, mark the htlc as
// resolved.
case *invoices.HtlcFailResolution:
log.Infof("%T(%v): Exit hop HTLC canceled "+
"(expiry=%v, height=%v), abandoning", h,
h.htlcResolution.ClaimOutpoint,
h.htlcExpiry, currentHeight)
h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
}
// Checkpoint our resolver with an abandoned outcome
// because we take no further action on this htlc.
report := h.report().resolverReport(
nil, channeldb.ResolverTypeIncomingHtlc,
channeldb.ResolverOutcomeAbandoned,
)
return nil, h.Checkpoint(h, report)
// Error if the resolution type is unknown, we are only
// expecting settles and fails.
default:
return nil, fmt.Errorf("unknown resolution"+
" type: %v", e)
}
}
var (
hodlChan <-chan interface{}
witnessUpdates <-chan lntypes.Preimage
)
if payload.FwdInfo.NextHop == hop.Exit {
// Create a buffered hodl chan to prevent deadlock.
hodlQueue := queue.NewConcurrentQueue(10)
hodlQueue.Start()
hodlChan = hodlQueue.ChanOut()
// Notify registry that we are potentially resolving as an exit
// hop on-chain. If this HTLC indeed pays to an existing
// invoice, the invoice registry will tell us what to do with
// the HTLC. This is identical to HTLC resolution in the link.
circuitKey := models.CircuitKey{
ChanID: h.ShortChanID,
HtlcID: h.htlc.HtlcIndex,
}
resolution, err := h.Registry.NotifyExitHopHtlc(
h.htlc.RHash, h.htlc.Amt, h.htlcExpiry, currentHeight,
circuitKey, hodlQueue.ChanIn(), h.htlc.CustomRecords,
payload,
)
if err != nil {
return nil, err
}
h.log.Debugf("received resolution from registry: %v",
resolution)
defer func() {
h.Registry.HodlUnsubscribeAll(hodlQueue.ChanIn())
hodlQueue.Stop()
}()
// Take action based on the resolution we received. If the htlc
		// was settled, or an htlc for a known invoice failed, we can
// resolve it directly. If the resolution is nil, the htlc was
// neither accepted nor failed, so we cannot take action yet.
switch res := resolution.(type) {
case *invoices.HtlcFailResolution:
// In the case where the htlc failed, but the invoice
// was known to the registry, we can directly resolve
// the htlc.
if res.Outcome != invoices.ResultInvoiceNotFound {
return processHtlcResolution(resolution)
}
// If we settled the htlc, we can resolve it.
case *invoices.HtlcSettleResolution:
return processHtlcResolution(resolution)
// If the resolution is nil, the htlc was neither settled nor
// failed so we cannot take action at present.
case nil:
default:
return nil, fmt.Errorf("unknown htlc resolution type: %T",
resolution)
}
} else {
// If the HTLC hasn't expired yet, then we may still be able to
// claim it if we learn of the pre-image, so we'll subscribe to
// the preimage database to see if it turns up, or the HTLC
// times out.
//
// NOTE: This is done BEFORE opportunistically querying the db,
// to ensure the preimage can't be delivered between querying
// and registering for the preimage subscription.
preimageSubscription, err := h.PreimageDB.SubscribeUpdates(
h.htlcSuccessResolver.ShortChanID, &h.htlc,
payload, nextHopOnionBlob,
)
if err != nil {
return nil, err
}
defer preimageSubscription.CancelSubscription()
// With the epochs and preimage subscriptions initialized, we'll
// query to see if we already know the preimage.
preimage, ok := h.PreimageDB.LookupPreimage(h.htlc.RHash)
if ok {
			// If we do, then this means we can claim the HTLC!
			// However, this resolver can't claim it itself, so
			// we'll return our inner resolver, which has the
			// knowledge to do so.
h.log.Debugf("Found preimage for htlc=%x", h.htlc.RHash)
if err := h.applyPreimage(preimage); err != nil {
return nil, err
}
return h.htlcSuccessResolver, nil
}
witnessUpdates = preimageSubscription.WitnessUpdates
}
for {
select {
case preimage := <-witnessUpdates:
// We received a new preimage, but we need to ignore
// all except the preimage we are waiting for.
if !preimage.Matches(h.htlc.RHash) {
continue
}
h.log.Debugf("Received preimage for htlc=%x",
h.htlc.RHash)
if err := h.applyPreimage(preimage); err != nil {
return nil, err
}
// We've learned of the preimage and this information
// has been added to our inner resolver. We return it so
// it can continue contract resolution.
return h.htlcSuccessResolver, nil
case hodlItem := <-hodlChan:
htlcResolution := hodlItem.(invoices.HtlcResolution)
return processHtlcResolution(htlcResolution)
case newBlock, ok := <-blockEpochs.Epochs:
if !ok {
return nil, errResolverShuttingDown
}
// If this new height expires the HTLC, then this means
// we never found out the preimage, so we can mark
// resolved and exit.
newHeight := uint32(newBlock.Height)
if newHeight >= h.htlcExpiry {
log.Infof("%T(%v): HTLC has timed out "+
"(expiry=%v, height=%v), abandoning", h,
h.htlcResolution.ClaimOutpoint,
					h.htlcExpiry, newHeight)
h.markResolved()
if err := h.processFinalHtlcFail(); err != nil {
return nil, err
}
report := h.report().resolverReport(
nil,
channeldb.ResolverTypeIncomingHtlc,
channeldb.ResolverOutcomeTimeout,
)
return nil, h.Checkpoint(h, report)
}
case <-h.quit:
return nil, errResolverShuttingDown
}
}
}
// applyPreimage is a helper function that will populate our internal resolver
// with the preimage we learn of. This should be called once the preimage is
// revealed so the inner resolver can properly complete its duties. A
// non-nil error is returned if the preimage can't be applied.
func (h *htlcIncomingContestResolver) applyPreimage(
preimage lntypes.Preimage) error {
// Sanity check to see if this preimage matches our htlc. At this point
// it should never happen that it does not match.
if !preimage.Matches(h.htlc.RHash) {
return errors.New("preimage does not match hash")
}
// We may already have the preimage since both the `Launch` and
// `Resolve` methods will look for it.
if h.htlcResolution.Preimage != lntypes.ZeroHash {
h.log.Debugf("already applied preimage for htlc=%x",
h.htlc.RHash)
return nil
}
// Update htlcResolution with the matching preimage.
h.htlcResolution.Preimage = preimage
log.Infof("%T(%v): applied preimage=%v", h,
h.htlcResolution.ClaimOutpoint, preimage)
isSecondLevel := h.htlcResolution.SignedSuccessTx != nil
// If we didn't have to go to the second level to claim (this
// is the remote commitment transaction), then we don't need to
// modify our canned witness.
if !isSecondLevel {
return nil
}
isTaproot := txscript.IsPayToTaproot(
h.htlcResolution.SignedSuccessTx.TxOut[0].PkScript,
)
// If this is our commitment transaction, then we'll need to
// populate the witness for the second-level HTLC transaction.
switch {
// For taproot channels, the witness for sweeping with success
// looks like:
// - <sender sig> <receiver sig> <preimage> <success_script>
// <control_block>
//
// So we'll insert it at the 3rd index of the witness.
case isTaproot:
//nolint:ll
h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[2] = preimage[:]
// Within the witness for the success transaction, the
// preimage is the 4th element as it looks like:
//
// * <0> <sender sig> <recvr sig> <preimage> <witness script>
//
// We'll populate it within the witness, as since this
// was a "contest" resolver, we didn't yet know of the
// preimage.
case !isTaproot:
//nolint:ll
h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[3] = preimage[:]
}
return nil
}
// report returns a report on the resolution state of the contract.
func (h *htlcIncomingContestResolver) report() *ContractReport {
// No locking needed as these values are read-only.
finalAmt := h.htlc.Amt.ToSatoshis()
if h.htlcResolution.SignedSuccessTx != nil {
finalAmt = btcutil.Amount(
h.htlcResolution.SignedSuccessTx.TxOut[0].Value,
)
}
return &ContractReport{
Outpoint: h.htlcResolution.ClaimOutpoint,
Type: ReportOutputIncomingHtlc,
Amount: finalAmt,
MaturityHeight: h.htlcExpiry,
LimboBalance: finalAmt,
Stage: 1,
}
}
// Stop signals the resolver to cancel any current resolution processes, and
// suspend.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcIncomingContestResolver) Stop() {
h.log.Debugf("stopping...")
defer h.log.Debugf("stopped")
close(h.quit)
}
// Encode writes an encoded version of the ContractResolver into the passed
// Writer.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcIncomingContestResolver) Encode(w io.Writer) error {
// We'll first write out the one field unique to this resolver.
if err := binary.Write(w, endian, h.htlcExpiry); err != nil {
return err
}
// Then we'll write out our internal resolver.
return h.htlcSuccessResolver.Encode(w)
}
// newIncomingContestResolverFromReader attempts to decode an encoded ContractResolver
// from the passed Reader instance, returning an active ContractResolver
// instance.
func newIncomingContestResolverFromReader(r io.Reader, resCfg ResolverConfig) (
*htlcIncomingContestResolver, error) {
h := &htlcIncomingContestResolver{}
// We'll first read the one field unique to this resolver.
if err := binary.Read(r, endian, &h.htlcExpiry); err != nil {
return nil, err
}
// Then we'll decode our internal resolver.
successResolver, err := newSuccessResolverFromReader(r, resCfg)
if err != nil {
return nil, err
}
h.htlcSuccessResolver = successResolver
return h, nil
}
// Supplement adds additional information to the resolver that is required
// before Resolve() is called.
//
// NOTE: Part of the htlcContractResolver interface.
func (h *htlcIncomingContestResolver) Supplement(htlc channeldb.HTLC) {
h.htlc = htlc
}
// SupplementDeadline does nothing for an incoming htlc resolver.
//
// NOTE: Part of the htlcContractResolver interface.
func (h *htlcIncomingContestResolver) SupplementDeadline(_ fn.Option[int32]) {
}
// decodePayload (re)decodes the hop payload of a received htlc.
func (h *htlcIncomingContestResolver) decodePayload() (*hop.Payload,
[]byte, error) {
blindingInfo := hop.ReconstructBlindingInfo{
IncomingAmt: h.htlc.Amt,
IncomingExpiry: h.htlc.RefundTimeout,
BlindingKey: h.htlc.BlindingPoint,
}
onionReader := bytes.NewReader(h.htlc.OnionBlob[:])
iterator, err := h.OnionProcessor.ReconstructHopIterator(
onionReader, h.htlc.RHash[:], blindingInfo,
)
if err != nil {
return nil, nil, err
}
payload, _, err := iterator.HopPayload()
if err != nil {
return nil, nil, err
}
// Transform onion blob for the next hop.
var onionBlob [lnwire.OnionPacketSize]byte
buf := bytes.NewBuffer(onionBlob[0:0])
err = iterator.EncodeNextHop(buf)
if err != nil {
return nil, nil, err
}
return payload, onionBlob[:], nil
}
// A compile time assertion to ensure htlcIncomingContestResolver meets the
// ContractResolver interface.
var _ htlcContractResolver = (*htlcIncomingContestResolver)(nil)
// findAndapplyPreimage performs a non-blocking read to find the preimage for
// the incoming HTLC. If found, it will be applied to the resolver. This method
// is used by the resolver to decide whether it wants to transform into a
// success resolver during launching.
//
// NOTE: Since we have two places to query the preimage, we need to check both
// the preimage db and the invoice db to look up the preimage.
func (h *htlcIncomingContestResolver) findAndapplyPreimage() (bool, error) {
// Query to see if we already know the preimage.
preimage, ok := h.PreimageDB.LookupPreimage(h.htlc.RHash)
// If the preimage is known, we'll apply it.
if ok {
if err := h.applyPreimage(preimage); err != nil {
return false, err
}
// Successfully applied the preimage, we can now return.
return true, nil
}
// First try to parse the payload.
payload, _, err := h.decodePayload()
if err != nil {
h.log.Errorf("Cannot decode payload of htlc %v", h.HtlcPoint())
		// If we cannot decode the payload, we will return a nil error
		// and let it be handled in `Resolve`.
return false, nil
}
// Exit early if this is not the exit hop, which means we are not the
	// payment receiver and don't have the preimage.
if payload.FwdInfo.NextHop != hop.Exit {
return false, nil
}
// Notify registry that we are potentially resolving as an exit hop
// on-chain. If this HTLC indeed pays to an existing invoice, the
// invoice registry will tell us what to do with the HTLC. This is
// identical to HTLC resolution in the link.
circuitKey := models.CircuitKey{
ChanID: h.ShortChanID,
HtlcID: h.htlc.HtlcIndex,
}
	// Try to get the resolution - if it doesn't give us a resolution
// immediately, we'll assume we don't know it yet and let the `Resolve`
// handle the waiting.
//
// NOTE: we use a nil subscriber here and a zero current height as we
// are only interested in the settle resolution.
//
// TODO(yy): move this logic to link and let the preimage be accessed
// via the preimage beacon.
resolution, err := h.Registry.NotifyExitHopHtlc(
h.htlc.RHash, h.htlc.Amt, h.htlcExpiry, 0,
circuitKey, nil, h.htlc.CustomRecords, payload,
)
if err != nil {
return false, err
}
res, ok := resolution.(*invoices.HtlcSettleResolution)
// Exit early if it's not a settle resolution.
if !ok {
return false, nil
}
// Otherwise we have a settle resolution, apply the preimage.
err = h.applyPreimage(res.Preimage)
if err != nil {
return false, err
}
return true, nil
}
package contractcourt
import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/tlv"
)
// htlcLeaseResolver is a struct that houses the lease specific HTLC resolution
// logic. This includes deriving the _true_ waiting height, as well as the
// input to offer to the sweeper.
type htlcLeaseResolver struct {
// channelInitiator denotes whether the party responsible for resolving
// the contract initiated the channel.
channelInitiator bool
// leaseExpiry denotes the additional waiting period the contract must
// hold until it can be resolved. This waiting period is known as the
// expiration of a script-enforced leased channel and only applies to
// the channel initiator.
//
// NOTE: This value should only be set when the contract belongs to a
// leased channel.
leaseExpiry uint32
}
// hasCLTV denotes whether the resolver must wait for an additional CLTV to
// expire before resolving the contract.
func (h *htlcLeaseResolver) hasCLTV() bool {
return h.channelInitiator && h.leaseExpiry > 0
}
// deriveWaitHeight computes the height the resolver needs to wait until it can
// sweep the input.
func (h *htlcLeaseResolver) deriveWaitHeight(csvDelay uint32,
commitSpend *chainntnfs.SpendDetail) uint32 {
waitHeight := uint32(commitSpend.SpendingHeight) + csvDelay - 1
if h.hasCLTV() {
waitHeight = max(waitHeight, h.leaseExpiry)
}
return waitHeight
}
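// As a worked example with hypothetical numbers: given a csvDelay of 144 and
// a commitment spend confirming at height 800_000, the CSV path matures at
// height 800_143; if the resolver also carries a leaseExpiry of 820_000, the
// effective wait height becomes 820_000 instead.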
// makeSweepInput constructs the type of input (either just csv or csv+cltv)
// to send to the sweeper so the output can ultimately be swept.
func (h *htlcLeaseResolver) makeSweepInput(op *wire.OutPoint,
wType, cltvWtype input.StandardWitnessType,
signDesc *input.SignDescriptor, csvDelay, broadcastHeight uint32,
payHash [32]byte, resBlob fn.Option[tlv.Blob]) *input.BaseInput {
log.Infof("%T(%x): offering second-layer output to sweeper: %v", h,
payHash, op)
if h.hasCLTV() {
return input.NewCsvInputWithCltv(
op, cltvWtype, signDesc, broadcastHeight, csvDelay,
h.leaseExpiry, input.WithResolutionBlob(resBlob),
)
}
log.Infof("%T(%x): CSV lock expired, offering second-layer output to "+
"sweeper: %v", h, payHash, op)
return input.NewCsvInput(
op, wType, signDesc, broadcastHeight, csvDelay,
input.WithResolutionBlob(resBlob),
)
}
// SupplementState allows the user of a ContractResolver to supplement it with
// state required for the proper resolution of a contract.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcLeaseResolver) SupplementState(state *channeldb.OpenChannel) {
if state.ChanType.HasLeaseExpiration() {
h.leaseExpiry = state.ThawHeight
}
h.channelInitiator = state.IsInitiator
}
package contractcourt
import (
"io"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet"
)
// htlcOutgoingContestResolver is a ContractResolver that's able to resolve an
// outgoing HTLC that is still contested. An HTLC is still contested if, at
// the time that we broadcast the commitment transaction, it isn't able to be
// fully resolved. In this case, we'll either wait for the HTLC to time out,
// or for us to learn of the preimage.
type htlcOutgoingContestResolver struct {
// htlcTimeoutResolver is the inner solver that this resolver may turn
// into. This only happens if the HTLC expires on-chain.
*htlcTimeoutResolver
}
// newOutgoingContestResolver instantiates a new outgoing contested htlc
// resolver.
func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution,
broadcastHeight uint32, htlc channeldb.HTLC,
resCfg ResolverConfig) *htlcOutgoingContestResolver {
timeout := newTimeoutResolver(
res, broadcastHeight, htlc, resCfg,
)
return &htlcOutgoingContestResolver{
htlcTimeoutResolver: timeout,
}
}
// Launch will call the inner resolver's launch method if the expiry height has
// been reached, otherwise it's a no-op.
func (h *htlcOutgoingContestResolver) Launch() error {
// NOTE: we don't mark this resolver as launched as the inner resolver
// will set it when it's launched.
if h.isLaunched() {
h.log.Tracef("already launched")
return nil
}
h.log.Debugf("launching contest resolver...")
_, bestHeight, err := h.ChainIO.GetBestBlock()
if err != nil {
return err
}
if uint32(bestHeight) < h.htlcResolution.Expiry {
return nil
}
// If the current height is >= expiry, then a timeout path spend will
// be valid to be included in the next block, and we can immediately
// return the resolver.
h.log.Infof("expired (height=%v, expiry=%v), transforming into "+
"timeout resolver and launching it", bestHeight,
h.htlcResolution.Expiry)
return h.htlcTimeoutResolver.Launch()
}
// Resolve commences the resolution of this contract. As this contract hasn't
// yet timed out, we'll wait for one of two things to happen
//
// 1. The HTLC expires. In this case, we'll sweep the funds and send a clean
// up cancel message to outside sub-systems.
//
// 2. The remote party sweeps this HTLC on-chain, in which case we'll add the
// pre-image to our global cache, then send a clean up settle message
// backwards.
//
// When either of these two things happens, we'll create a new resolver which
// is able to handle the final resolution of the contract. We're only the pivot
// point.
func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) {
	// If we're already fully resolved, then we don't have anything further
// to do.
if h.IsResolved() {
h.log.Errorf("already resolved")
return nil, nil
}
// Otherwise, we'll watch for two external signals to decide if we'll
// morph into another resolver, or fully resolve the contract.
//
// The output we'll be watching for is the *direct* spend from the HTLC
// output. If this isn't our commitment transaction, it'll be right on
// the resolution. Otherwise, we fetch this pointer from the input of
// the time out transaction.
outPointToWatch, scriptToWatch, err := h.chainDetailsToWatch()
if err != nil {
return nil, err
}
// First, we'll register for a spend notification for this output. If
// the remote party sweeps with the pre-image, we'll be notified.
spendNtfn, err := h.Notifier.RegisterSpendNtfn(
outPointToWatch, scriptToWatch, h.broadcastHeight,
)
if err != nil {
return nil, err
}
// We'll quickly check to see if the output has already been spent.
select {
// If the output has already been spent, then we can stop early and
// sweep the pre-image from the output.
case commitSpend, ok := <-spendNtfn.Spend:
if !ok {
return nil, errResolverShuttingDown
}
return nil, h.claimCleanUp(commitSpend)
	// If it hasn't, then we'll watch for both the expiration, and the
	// sweeping of this output.
default:
}
// If we reach this point, then we can't fully act yet, so we'll await
// either of our signals triggering: the HTLC expires, or we learn of
// the preimage.
blockEpochs, err := h.Notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return nil, err
}
defer blockEpochs.Cancel()
for {
select {
// A new block has arrived, we'll check to see if this leads to
// HTLC expiration.
case newBlock, ok := <-blockEpochs.Epochs:
if !ok {
return nil, errResolverShuttingDown
}
// If the current height is >= expiry, then a timeout
// path spend will be valid to be included in the next
// block, and we can immediately return the resolver.
//
// NOTE: when broadcasting this transaction, btcd will
// check the timelock in `CheckTransactionStandard`,
// which requires `expiry < currentHeight+1`. If the
// check doesn't pass, error `transaction is not
// finalized` will be returned and the broadcast will
// fail.
newHeight := uint32(newBlock.Height)
expiry := h.htlcResolution.Expiry
// Check if the expiry height is about to be reached.
// We offer this HTLC one block earlier to make sure
// when the next block arrives, the sweeper will pick
// up this input and sweep it immediately. The sweeper
// will handle the waiting for the one last block till
// expiry.
if newHeight >= expiry-1 {
h.log.Infof("HTLC about to expire "+
"(height=%v, expiry=%v), transforming "+
"into timeout resolver", newHeight,
h.htlcResolution.Expiry)
return h.htlcTimeoutResolver, nil
}
// The output has been spent! This means the preimage has been
// revealed on-chain.
case commitSpend, ok := <-spendNtfn.Spend:
if !ok {
return nil, errResolverShuttingDown
}
// The only way this output can be spent by the remote
// party is by revealing the preimage. So we'll perform
// our duties to clean up the contract once it has been
// claimed.
return nil, h.claimCleanUp(commitSpend)
case <-h.quit:
return nil, errResolverShuttingDown
}
}
}
// report returns a report on the resolution state of the contract.
func (h *htlcOutgoingContestResolver) report() *ContractReport {
// No locking needed as these values are read-only.
finalAmt := h.htlc.Amt.ToSatoshis()
if h.htlcResolution.SignedTimeoutTx != nil {
finalAmt = btcutil.Amount(
h.htlcResolution.SignedTimeoutTx.TxOut[0].Value,
)
}
return &ContractReport{
Outpoint: h.htlcResolution.ClaimOutpoint,
Type: ReportOutputOutgoingHtlc,
Amount: finalAmt,
MaturityHeight: h.htlcResolution.Expiry,
LimboBalance: finalAmt,
Stage: 1,
}
}
// Stop signals the resolver to cancel any current resolution processes, and
// suspend.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcOutgoingContestResolver) Stop() {
h.log.Debugf("stopping...")
defer h.log.Debugf("stopped")
close(h.quit)
}
// Encode writes an encoded version of the ContractResolver into the passed
// Writer.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcOutgoingContestResolver) Encode(w io.Writer) error {
return h.htlcTimeoutResolver.Encode(w)
}
// SupplementDeadline does nothing for an outgoing contest resolver.
//
// NOTE: Part of the htlcContractResolver interface.
func (h *htlcOutgoingContestResolver) SupplementDeadline(_ fn.Option[int32]) {
}
// newOutgoingContestResolverFromReader attempts to decode an encoded
// ContractResolver from the passed Reader instance, returning an active
// ContractResolver instance.
func newOutgoingContestResolverFromReader(r io.Reader, resCfg ResolverConfig) (
*htlcOutgoingContestResolver, error) {
h := &htlcOutgoingContestResolver{}
timeoutResolver, err := newTimeoutResolverFromReader(r, resCfg)
if err != nil {
return nil, err
}
h.htlcTimeoutResolver = timeoutResolver
return h, nil
}
// A compile time assertion to ensure htlcOutgoingContestResolver meets the
// ContractResolver interface.
var _ htlcContractResolver = (*htlcOutgoingContestResolver)(nil)
package contractcourt
import (
"encoding/binary"
"fmt"
"io"
"sync"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/sweep"
)
// htlcSuccessResolver is a resolver that's capable of sweeping an incoming
// HTLC output on-chain. If this is the remote party's commitment, we'll sweep
// it directly from the commitment output *immediately*. If this is our
// commitment, we'll first broadcast the success transaction, then send it to
// the incubator for sweeping. That's it, no need to send any clean up
// messages.
//
// TODO(roasbeef): don't need to broadcast?
type htlcSuccessResolver struct {
// htlcResolution is the incoming HTLC resolution for this HTLC. It
// contains everything we need to properly resolve this HTLC.
htlcResolution lnwallet.IncomingHtlcResolution
// outputIncubating returns true if we've sent the output to the output
// incubator (utxo nursery). In case the htlcResolution has non-nil
// SignDetails, it means we will let the Sweeper handle broadcasting
// the second-level transaction, and sweeping its output. In this case
// we let this field indicate whether we need to broadcast the
// second-level tx (false) or if it has confirmed and we must sweep the
// second-level output (true).
outputIncubating bool
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
broadcastHeight uint32
// htlc contains information on the htlc that we are resolving on-chain.
htlc channeldb.HTLC
// currentReport stores the current state of the resolver for reporting
// over the rpc interface. This should only be reported in case we have
// a non-nil SignDetails on the htlcResolution, otherwise the nursery
// will produce reports.
currentReport ContractReport
// reportLock prevents concurrent access to the resolver report.
reportLock sync.Mutex
contractResolverKit
htlcLeaseResolver
}
// newSuccessResolver instantiates a new htlc success resolver.
func newSuccessResolver(res lnwallet.IncomingHtlcResolution,
broadcastHeight uint32, htlc channeldb.HTLC,
resCfg ResolverConfig) *htlcSuccessResolver {
h := &htlcSuccessResolver{
contractResolverKit: *newContractResolverKit(resCfg),
htlcResolution: res,
broadcastHeight: broadcastHeight,
htlc: htlc,
}
h.initReport()
h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint()))
return h
}
// outpoint returns the outpoint of the HTLC output we're attempting to sweep.
func (h *htlcSuccessResolver) outpoint() wire.OutPoint {
// The primary key for this resolver will be the outpoint of the HTLC
// on the commitment transaction itself. If this is our commitment,
// then the output can be found within the signed success tx,
// otherwise, it's just the ClaimOutpoint.
if h.htlcResolution.SignedSuccessTx != nil {
return h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint
}
return h.htlcResolution.ClaimOutpoint
}
// ResolverKey returns an identifier which should be globally unique for this
// particular resolver within the chain the original contract resides within.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcSuccessResolver) ResolverKey() []byte {
key := newResolverID(h.outpoint())
return key[:]
}
// Resolve attempts to resolve an unresolved incoming HTLC that we know the
// preimage to. If the HTLC is on the commitment of the remote party, then we'll
// simply sweep it directly. Otherwise, we'll hand this off to the utxo nursery
// to do its duty. There is no need to make a call to the invoice registry
// anymore. Every HTLC has already passed through the incoming contest resolver
// and in there the invoice was already marked as settled.
//
// NOTE: Part of the ContractResolver interface.
//
// TODO(yy): refactor the interface method to return an error only.
func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) {
var err error
switch {
// If we're already resolved, then we can exit early.
case h.IsResolved():
h.log.Errorf("already resolved")
// If this is an output on the remote party's commitment transaction,
// use the direct-spend path to sweep the htlc.
case h.isRemoteCommitOutput():
err = h.resolveRemoteCommitOutput()
// If this is an output on our commitment transaction using post-anchor
// channel type, it will be handled by the sweeper.
case h.isZeroFeeOutput():
err = h.resolveSuccessTx()
// If this is an output on our own commitment using pre-anchor channel
// type, we will publish the success tx and offer the output to the
// nursery.
default:
err = h.resolveLegacySuccessTx()
}
return nil, err
}
// resolveRemoteCommitOutput handles sweeping an HTLC output on the remote
// commitment with the preimage. In this case we can sweep the output directly,
// and don't have to broadcast a second-level transaction.
func (h *htlcSuccessResolver) resolveRemoteCommitOutput() error {
h.log.Info("waiting for direct-preimage spend of the htlc to confirm")
// Wait for the direct-preimage HTLC sweep tx to confirm.
//
// TODO(yy): use the result chan returned from `SweepInput`.
sweepTxDetails, err := waitForSpend(
&h.htlcResolution.ClaimOutpoint,
h.htlcResolution.SweepSignDesc.Output.PkScript,
h.broadcastHeight, h.Notifier, h.quit,
)
if err != nil {
return err
}
// TODO(yy): should also update the `RecoveredBalance` and
// `LimboBalance` like other paths?
// Checkpoint the resolver, and write the outcome to disk.
return h.checkpointClaim(sweepTxDetails.SpenderTxHash)
}
// checkpointClaim checkpoints the success resolver with the reports it needs.
// If this htlc was claimed in two stages, it will write reports for both stages,
// otherwise it will just write for the single htlc claim.
func (h *htlcSuccessResolver) checkpointClaim(spendTx *chainhash.Hash) error {
// Mark the htlc as final settled.
err := h.ChainArbitratorConfig.PutFinalHtlcOutcome(
h.ChannelArbitratorConfig.ShortChanID, h.htlc.HtlcIndex, true,
)
if err != nil {
return err
}
// Send notification.
h.ChainArbitratorConfig.HtlcNotifier.NotifyFinalHtlcEvent(
models.CircuitKey{
ChanID: h.ShortChanID,
HtlcID: h.htlc.HtlcIndex,
},
channeldb.FinalHtlcInfo{
Settled: true,
Offchain: false,
},
)
// Create a resolver report for claiming of the htlc itself.
amt := btcutil.Amount(h.htlcResolution.SweepSignDesc.Output.Value)
reports := []*channeldb.ResolverReport{
{
OutPoint: h.htlcResolution.ClaimOutpoint,
Amount: amt,
ResolverType: channeldb.ResolverTypeIncomingHtlc,
ResolverOutcome: channeldb.ResolverOutcomeClaimed,
SpendTxID: spendTx,
},
}
// If we have a success tx, we append a report to represent our first
// stage claim.
if h.htlcResolution.SignedSuccessTx != nil {
// If the SignedSuccessTx is not nil, we are claiming the htlc
// in two stages, so we need to create a report for the first
// stage transaction as well.
spendTx := h.htlcResolution.SignedSuccessTx
spendTxID := spendTx.TxHash()
report := &channeldb.ResolverReport{
OutPoint: spendTx.TxIn[0].PreviousOutPoint,
Amount: h.htlc.Amt.ToSatoshis(),
ResolverType: channeldb.ResolverTypeIncomingHtlc,
ResolverOutcome: channeldb.ResolverOutcomeFirstStage,
SpendTxID: &spendTxID,
}
reports = append(reports, report)
}
// Finally, we checkpoint the resolver with our report(s).
h.markResolved()
return h.Checkpoint(h, reports...)
}
// Stop signals the resolver to cancel any current resolution processes, and
// suspend.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcSuccessResolver) Stop() {
h.log.Debugf("stopping...")
defer h.log.Debugf("stopped")
close(h.quit)
}
// report returns a report on the resolution state of the contract.
func (h *htlcSuccessResolver) report() *ContractReport {
// If the sign details are nil, the report will be handled by the
// nursery.
if h.htlcResolution.SignDetails == nil {
return nil
}
h.reportLock.Lock()
defer h.reportLock.Unlock()
cpy := h.currentReport
return &cpy
}
func (h *htlcSuccessResolver) initReport() {
// We create the initial report. This will only be reported for
// resolvers not handled by the nursery.
finalAmt := h.htlc.Amt.ToSatoshis()
if h.htlcResolution.SignedSuccessTx != nil {
finalAmt = btcutil.Amount(
h.htlcResolution.SignedSuccessTx.TxOut[0].Value,
)
}
h.currentReport = ContractReport{
Outpoint: h.htlcResolution.ClaimOutpoint,
Type: ReportOutputIncomingHtlc,
Amount: finalAmt,
MaturityHeight: h.htlcResolution.CsvDelay,
LimboBalance: finalAmt,
Stage: 1,
}
}
// Encode writes an encoded version of the ContractResolver into the passed
// Writer.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcSuccessResolver) Encode(w io.Writer) error {
// First we'll encode our inner HTLC resolution.
if err := encodeIncomingResolution(w, &h.htlcResolution); err != nil {
return err
}
// Next, we'll write out the fields that are specific to the contract
// resolver.
if err := binary.Write(w, endian, h.outputIncubating); err != nil {
return err
}
if err := binary.Write(w, endian, h.IsResolved()); err != nil {
return err
}
if err := binary.Write(w, endian, h.broadcastHeight); err != nil {
return err
}
if _, err := w.Write(h.htlc.RHash[:]); err != nil {
return err
}
// We encode the sign details last for backwards compatibility.
err := encodeSignDetails(w, h.htlcResolution.SignDetails)
if err != nil {
return err
}
return nil
}
// newSuccessResolverFromReader attempts to decode an encoded ContractResolver
// from the passed Reader instance, returning an active ContractResolver
// instance.
func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) (
*htlcSuccessResolver, error) {
h := &htlcSuccessResolver{
contractResolverKit: *newContractResolverKit(resCfg),
}
// First we'll decode our inner HTLC resolution.
if err := decodeIncomingResolution(r, &h.htlcResolution); err != nil {
return nil, err
}
// Next, we'll read all the fields that are specific to the contract
// resolver.
if err := binary.Read(r, endian, &h.outputIncubating); err != nil {
return nil, err
}
var resolved bool
if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
if resolved {
h.markResolved()
}
if err := binary.Read(r, endian, &h.broadcastHeight); err != nil {
return nil, err
}
if _, err := io.ReadFull(r, h.htlc.RHash[:]); err != nil {
return nil, err
}
// Sign details is a new field that was added to the htlc resolution,
// so it is serialized last for backwards compatibility. We try to read
// it, but don't error out if there are no bytes left.
signDetails, err := decodeSignDetails(r)
if err == nil {
h.htlcResolution.SignDetails = signDetails
} else if err != io.EOF && err != io.ErrUnexpectedEOF {
return nil, err
}
h.initReport()
h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint()))
return h, nil
}
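// The decode above relies on an append-only encoding: optional fields
// (like the sign details) are serialized last, and the reader tolerates a
// clean EOF so blobs written by older versions still decode. A minimal
// sketch of the pattern, using a hypothetical decodeOptionalField helper:
//
//	val, err := decodeOptionalField(r) // hypothetical helper
//	switch {
//	case err == nil:
//		h.field = val
//	case err == io.EOF, err == io.ErrUnexpectedEOF:
//		// Older encoding: the field is absent, keep the zero value.
//	default:
//		return nil, err
//	}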
// Supplement adds additional information to the resolver that is required
// before Resolve() is called.
//
// NOTE: Part of the htlcContractResolver interface.
func (h *htlcSuccessResolver) Supplement(htlc channeldb.HTLC) {
h.htlc = htlc
}
// HtlcPoint returns the htlc's outpoint on the commitment tx.
//
// NOTE: Part of the htlcContractResolver interface.
func (h *htlcSuccessResolver) HtlcPoint() wire.OutPoint {
return h.htlcResolution.HtlcPoint()
}
// SupplementDeadline does nothing for an incoming htlc resolver.
//
// NOTE: Part of the htlcContractResolver interface.
func (h *htlcSuccessResolver) SupplementDeadline(_ fn.Option[int32]) {
}
// A compile time assertion to ensure htlcSuccessResolver meets the
// ContractResolver interface.
var _ htlcContractResolver = (*htlcSuccessResolver)(nil)
// isRemoteCommitOutput returns a bool to indicate whether the htlc output is
// on the remote commitment.
func (h *htlcSuccessResolver) isRemoteCommitOutput() bool {
// If we don't have a success transaction, then this means that this is
// an output on the remote party's commitment transaction.
return h.htlcResolution.SignedSuccessTx == nil
}
// isZeroFeeOutput returns a boolean indicating whether the htlc output is from
// an anchor-enabled channel, which uses the sighash SINGLE|ANYONECANPAY.
func (h *htlcSuccessResolver) isZeroFeeOutput() bool {
// If we have non-nil SignDetails, this means it has a 2nd level HTLC
// transaction that is signed using sighash SINGLE|ANYONECANPAY (the
// case for anchor type channels). In this case we can re-sign it and
// attach fees at will.
return h.htlcResolution.SignedSuccessTx != nil &&
h.htlcResolution.SignDetails != nil
}
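// To illustrate why non-nil SignDetails means we can "attach fees at
// will": a SINGLE|ANYONECANPAY signature commits only to the signed input
// and its matching output, so wallet inputs and a change output can be
// added to the second-level tx after the fact. A hedged sketch using the
// btcd sighash flags:
//
//	hashType := txscript.SigHashSingle | txscript.SigHashAnyOneCanPay
//	// A tx signed with hashType can later gain extra inputs/outputs
//	// (e.g. for fee bumping) without invalidating the HTLC signature.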
// isTaproot returns true if the resolver is for a taproot output.
func (h *htlcSuccessResolver) isTaproot() bool {
return txscript.IsPayToTaproot(
h.htlcResolution.SweepSignDesc.Output.PkScript,
)
}
// sweepRemoteCommitOutput creates a sweep request to sweep the HTLC output on
// the remote commitment via the direct preimage-spend.
func (h *htlcSuccessResolver) sweepRemoteCommitOutput() error {
// Before we can craft our sweeping transaction, we need to create an
// input which contains all the items required to add this input to a
// sweeping transaction, and generate a witness.
var inp input.Input
if h.isTaproot() {
inp = lnutils.Ptr(input.MakeTaprootHtlcSucceedInput(
&h.htlcResolution.ClaimOutpoint,
&h.htlcResolution.SweepSignDesc,
h.htlcResolution.Preimage[:],
h.broadcastHeight,
h.htlcResolution.CsvDelay,
input.WithResolutionBlob(
h.htlcResolution.ResolutionBlob,
),
))
} else {
inp = lnutils.Ptr(input.MakeHtlcSucceedInput(
&h.htlcResolution.ClaimOutpoint,
&h.htlcResolution.SweepSignDesc,
h.htlcResolution.Preimage[:],
h.broadcastHeight,
h.htlcResolution.CsvDelay,
))
}
// Calculate the budget for this sweep.
budget := calculateBudget(
btcutil.Amount(inp.SignDesc().Output.Value),
h.Budget.DeadlineHTLCRatio,
h.Budget.DeadlineHTLC,
)
deadline := fn.Some(int32(h.htlc.RefundTimeout))
log.Infof("%T(%x): offering direct-preimage HTLC output to sweeper "+
"with deadline=%v, budget=%v", h, h.htlc.RHash[:],
h.htlc.RefundTimeout, budget)
// We'll now offer the direct preimage HTLC to the sweeper.
_, err := h.Sweeper.SweepInput(
inp,
sweep.Params{
Budget: budget,
DeadlineHeight: deadline,
},
)
return err
}
// sweepSuccessTx attempts to sweep the second level success tx.
func (h *htlcSuccessResolver) sweepSuccessTx() error {
var secondLevelInput input.HtlcSecondLevelAnchorInput
if h.isTaproot() {
secondLevelInput = input.MakeHtlcSecondLevelSuccessTaprootInput(
h.htlcResolution.SignedSuccessTx,
h.htlcResolution.SignDetails, h.htlcResolution.Preimage,
h.broadcastHeight, input.WithResolutionBlob(
h.htlcResolution.ResolutionBlob,
),
)
} else {
secondLevelInput = input.MakeHtlcSecondLevelSuccessAnchorInput(
h.htlcResolution.SignedSuccessTx,
h.htlcResolution.SignDetails, h.htlcResolution.Preimage,
h.broadcastHeight,
)
}
// Calculate the budget for this sweep.
value := btcutil.Amount(secondLevelInput.SignDesc().Output.Value)
budget := calculateBudget(
value, h.Budget.DeadlineHTLCRatio, h.Budget.DeadlineHTLC,
)
// The deadline would be the CLTV in this HTLC output. If we are the
// initiator of this force close, with the default
// `IncomingBroadcastDelta`, it means we have 10 blocks left when going
// onchain.
deadline := fn.Some(int32(h.htlc.RefundTimeout))
h.log.Infof("offering second-level HTLC success tx to sweeper with "+
"deadline=%v, budget=%v", h.htlc.RefundTimeout, budget)
// We'll now offer the second-level transaction to the sweeper.
_, err := h.Sweeper.SweepInput(
&secondLevelInput,
sweep.Params{
Budget: budget,
DeadlineHeight: deadline,
},
)
return err
}
// sweepSuccessTxOutput attempts to sweep the output of the second level
// success tx.
func (h *htlcSuccessResolver) sweepSuccessTxOutput() error {
h.log.Debugf("sweeping output %v from 2nd-level HTLC success tx",
h.htlcResolution.ClaimOutpoint)
// This should be non-blocking as we will only attempt to sweep the
// output when the second level tx has already been confirmed. In other
// words, waitForSpend will return immediately.
commitSpend, err := waitForSpend(
&h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint,
h.htlcResolution.SignDetails.SignDesc.Output.PkScript,
h.broadcastHeight, h.Notifier, h.quit,
)
if err != nil {
return err
}
// The HTLC success tx has a CSV lock that we must wait for, and if
// this is a lease enforced channel and we're the initiator, we may need
// to wait for longer.
waitHeight := h.deriveWaitHeight(h.htlcResolution.CsvDelay, commitSpend)
// Now that the sweeper has broadcasted the second-level transaction,
// it has confirmed, and we have checkpointed our state, we'll sweep
// the second level output. We report that the resolver has moved to the
// next stage.
h.reportLock.Lock()
h.currentReport.Stage = 2
h.currentReport.MaturityHeight = waitHeight
h.reportLock.Unlock()
if h.hasCLTV() {
log.Infof("%T(%x): waiting for CSV and CLTV lock to expire at "+
"height %v", h, h.htlc.RHash[:], waitHeight)
} else {
log.Infof("%T(%x): waiting for CSV lock to expire at height %v",
h, h.htlc.RHash[:], waitHeight)
}
// We'll use this input index to determine the second-level output
// index on the transaction, as the signatures require the indexes to
// be the same. We don't look for the second-level output script
// directly, as there might be more than one HTLC output to the same
// pkScript.
op := &wire.OutPoint{
Hash: *commitSpend.SpenderTxHash,
Index: commitSpend.SpenderInputIndex,
}
// Let the sweeper sweep the second-level output now that the
// CSV/CLTV locks have expired.
var witType input.StandardWitnessType
if h.isTaproot() {
witType = input.TaprootHtlcAcceptedSuccessSecondLevel
} else {
witType = input.HtlcAcceptedSuccessSecondLevel
}
inp := h.makeSweepInput(
op, witType,
input.LeaseHtlcAcceptedSuccessSecondLevel,
&h.htlcResolution.SweepSignDesc,
h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight),
h.htlc.RHash, h.htlcResolution.ResolutionBlob,
)
// Calculate the budget for this sweep.
budget := calculateBudget(
btcutil.Amount(inp.SignDesc().Output.Value),
h.Budget.NoDeadlineHTLCRatio,
h.Budget.NoDeadlineHTLC,
)
log.Infof("%T(%x): offering second-level success tx output to sweeper "+
"with no deadline and budget=%v at height=%v", h,
h.htlc.RHash[:], budget, waitHeight)
// TODO(yy): use the result chan returned from SweepInput.
_, err = h.Sweeper.SweepInput(
inp,
sweep.Params{
Budget: budget,
// For second level success tx, there's no rush to get
// it confirmed, so we use a nil deadline.
DeadlineHeight: fn.None[int32](),
},
)
return err
}
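// Hedged arithmetic for the maturity height above (deriveWaitHeight is
// defined elsewhere in this package): the CSV delay counts from the
// height at which the second-level tx confirmed, so with csv_delay=144
// and a spend confirming at height 800_000 the output matures around
// height 800_144; a lease enforced channel may push waitHeight further
// out to the lease expiry.
//
//	// Illustrative only, assuming a plain relative-lock calculation:
//	waitHeight := uint32(commitSpend.SpendingHeight) + csvDelay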
// resolveLegacySuccessTx handles an HTLC output from a pre-anchor type channel
// by broadcasting the second-level success transaction.
func (h *htlcSuccessResolver) resolveLegacySuccessTx() error {
// We'll publish the second-level transaction directly and offer the
// resolution to the nursery to handle.
h.log.Infof("broadcasting legacy second-level success tx: %v",
h.htlcResolution.SignedSuccessTx.TxHash())
// We'll now broadcast the second layer transaction so we can kick off
// the claiming process.
//
// TODO(yy): offer it to the sweeper instead.
label := labels.MakeLabel(
labels.LabelTypeChannelClose, &h.ShortChanID,
)
err := h.PublishTx(h.htlcResolution.SignedSuccessTx, label)
if err != nil {
return err
}
// Fast-forward to resolve the output from the success tx if it has
// already been sent to the UtxoNursery.
if h.outputIncubating {
return h.resolveSuccessTxOutput(h.htlcResolution.ClaimOutpoint)
}
h.log.Infof("incubating incoming htlc output")
// Send the output to the incubator.
err = h.IncubateOutputs(
h.ChanPoint, fn.None[lnwallet.OutgoingHtlcResolution](),
fn.Some(h.htlcResolution),
h.broadcastHeight, fn.Some(int32(h.htlc.RefundTimeout)),
)
if err != nil {
return err
}
// Mark the output as incubating and checkpoint it.
h.outputIncubating = true
if err := h.Checkpoint(h); err != nil {
return err
}
// Move to resolve the output.
return h.resolveSuccessTxOutput(h.htlcResolution.ClaimOutpoint)
}
// resolveSuccessTx waits for the sweeping tx of the second-level success tx to
// confirm and offers the output from the success tx to the sweeper.
func (h *htlcSuccessResolver) resolveSuccessTx() error {
h.log.Infof("waiting for 2nd-level HTLC success transaction to confirm")
// Create aliases to make the code more readable.
outpoint := h.htlcResolution.SignedSuccessTx.TxIn[0].PreviousOutPoint
pkScript := h.htlcResolution.SignDetails.SignDesc.Output.PkScript
// Wait for the second level transaction to confirm.
commitSpend, err := waitForSpend(
&outpoint, pkScript, h.broadcastHeight, h.Notifier, h.quit,
)
if err != nil {
return err
}
// We'll use this input index to determine the second-level output
// index on the transaction, as the signatures require the indexes to
// be the same. We don't look for the second-level output script
// directly, as there might be more than one HTLC output to the same
// pkScript.
op := wire.OutPoint{
Hash: *commitSpend.SpenderTxHash,
Index: commitSpend.SpenderInputIndex,
}
// If the 2nd-stage sweeping has already been started, we can
// fast-forward to start the resolving process for the stage two
// output.
if h.outputIncubating {
return h.resolveSuccessTxOutput(op)
}
// Now that the second-level transaction has confirmed, we checkpoint
// the state so we'll go to the next stage in case of restarts.
h.outputIncubating = true
if err := h.Checkpoint(h); err != nil {
log.Errorf("unable to Checkpoint: %v", err)
return err
}
h.log.Infof("2nd-level HTLC success tx=%v confirmed",
commitSpend.SpenderTxHash)
// Send the sweep request for the output from the success tx.
if err := h.sweepSuccessTxOutput(); err != nil {
return err
}
return h.resolveSuccessTxOutput(op)
}
// resolveSuccessTxOutput waits for the spend of the output from the 2nd-level
// success tx.
func (h *htlcSuccessResolver) resolveSuccessTxOutput(op wire.OutPoint) error {
// To wrap this up, we'll wait until the second-level transaction has
// been spent, then fully resolve the contract.
log.Infof("%T(%x): waiting for second-level HTLC output to be spent "+
"after csv_delay=%v", h, h.htlc.RHash[:],
h.htlcResolution.CsvDelay)
spend, err := waitForSpend(
&op, h.htlcResolution.SweepSignDesc.Output.PkScript,
h.broadcastHeight, h.Notifier, h.quit,
)
if err != nil {
return err
}
h.reportLock.Lock()
h.currentReport.RecoveredBalance = h.currentReport.LimboBalance
h.currentReport.LimboBalance = 0
h.reportLock.Unlock()
return h.checkpointClaim(spend.SpenderTxHash)
}
// Launch creates an input based on the details of the incoming htlc resolution
// and offers it to the sweeper.
func (h *htlcSuccessResolver) Launch() error {
if h.isLaunched() {
h.log.Tracef("already launched")
return nil
}
h.log.Debugf("launching resolver...")
h.markLaunched()
switch {
// If we're already resolved, then we can exit early.
case h.IsResolved():
h.log.Errorf("already resolved")
return nil
// If this is an output on the remote party's commitment transaction,
// use the direct-spend path.
case h.isRemoteCommitOutput():
return h.sweepRemoteCommitOutput()
// If this is an anchor type channel, we now sweep either the
// second-level success tx or the output from the second-level success
// tx.
case h.isZeroFeeOutput():
// If the second-level success tx has already been swept, we
// can go ahead and sweep its output.
if h.outputIncubating {
return h.sweepSuccessTxOutput()
}
// Otherwise, sweep the second level tx.
return h.sweepSuccessTx()
// If this is a legacy channel type, the output is handled by the
// nursery via Resolve, so we do nothing here.
//
// TODO(yy): handle the legacy output by offering it to the sweeper.
default:
return nil
}
}
package contractcourt
import (
"encoding/binary"
"fmt"
"io"
"sync"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/sweep"
)
// htlcTimeoutResolver is a ContractResolver that's capable of resolving an
// outgoing HTLC. The HTLC may be on our commitment transaction, or on the
// commitment transaction of the remote party. An output on our commitment
// transaction is considered fully resolved once the second-level transaction
// has been confirmed (and reached a sufficient depth). An output on the
// commitment transaction of the remote party is resolved once we detect a
// spend of the direct HTLC output using the timeout clause.
type htlcTimeoutResolver struct {
// htlcResolution contains all the information required to properly
// resolve this outgoing HTLC.
htlcResolution lnwallet.OutgoingHtlcResolution
// outputIncubating returns true if we've sent the output to the output
// incubator (utxo nursery).
outputIncubating bool
// broadcastHeight is the height that the original contract was
// broadcast to the main-chain at. We'll use this value to bound any
// historical queries to the chain for spends/confirmations.
//
// TODO(roasbeef): wrap above into definite resolution embedding?
broadcastHeight uint32
// htlc contains information on the htlc that we are resolving on-chain.
htlc channeldb.HTLC
// currentReport stores the current state of the resolver for reporting
// over the rpc interface. This should only be reported in case we have
// a non-nil SignDetails on the htlcResolution, otherwise the nursery
// will produce reports.
currentReport ContractReport
// reportLock prevents concurrent access to the resolver report.
reportLock sync.Mutex
contractResolverKit
htlcLeaseResolver
// incomingHTLCExpiryHeight is the absolute block height at which the
// incoming HTLC will expire. This is used as the deadline height as
// the outgoing HTLC must be swept before its incoming HTLC expires.
incomingHTLCExpiryHeight fn.Option[int32]
}
// newTimeoutResolver instantiates a new timeout htlc resolver.
func newTimeoutResolver(res lnwallet.OutgoingHtlcResolution,
broadcastHeight uint32, htlc channeldb.HTLC,
resCfg ResolverConfig) *htlcTimeoutResolver {
h := &htlcTimeoutResolver{
contractResolverKit: *newContractResolverKit(resCfg),
htlcResolution: res,
broadcastHeight: broadcastHeight,
htlc: htlc,
}
h.initReport()
h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint()))
return h
}
// isTaproot returns true if the htlc output is a taproot output.
func (h *htlcTimeoutResolver) isTaproot() bool {
return txscript.IsPayToTaproot(
h.htlcResolution.SweepSignDesc.Output.PkScript,
)
}
// outpoint returns the outpoint of the HTLC output we're attempting to sweep.
func (h *htlcTimeoutResolver) outpoint() wire.OutPoint {
// The primary key for this resolver will be the outpoint of the HTLC
// on the commitment transaction itself. If this is our commitment,
// then the output can be found within the signed timeout tx,
// otherwise, it's just the ClaimOutpoint.
if h.htlcResolution.SignedTimeoutTx != nil {
return h.htlcResolution.SignedTimeoutTx.TxIn[0].PreviousOutPoint
}
return h.htlcResolution.ClaimOutpoint
}
// ResolverKey returns an identifier which should be globally unique for this
// particular resolver within the chain the original contract resides within.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcTimeoutResolver) ResolverKey() []byte {
key := newResolverID(h.outpoint())
return key[:]
}
const (
// expectedRemoteWitnessSuccessSize is the expected size of the witness
// on the remote commitment transaction for an outgoing HTLC that is
// swept on-chain by them with the pre-image.
expectedRemoteWitnessSuccessSize = 5
// expectedLocalWitnessSuccessSize is the expected size of the witness
// on the local commitment transaction for an outgoing HTLC that is
// swept on-chain by them with the pre-image.
expectedLocalWitnessSuccessSize = 3
// remotePreimageIndex is the index within the witness on the remote
// commitment transaction that will hold the pre-image if they go to
// sweep it on chain.
remotePreimageIndex = 3
// localPreimageIndex is the index within the witness on the local
// commitment transaction for an outgoing HTLC that will hold the
// pre-image if the remote party sweeps it.
localPreimageIndex = 1
// remoteTaprootWitnessSuccessSize is the expected size of the witness
// on the remote commitment for taproot channels. The spend path will
// look like
// - <sender sig> <receiver sig> <preimage> <success_script>
// <control_block>
remoteTaprootWitnessSuccessSize = 5
// localTaprootWitnessSuccessSize is the expected size of the witness
// on the local commitment for taproot channels. The spend path will
// look like
// - <receiver sig> <preimage> <success_script> <control_block>
localTaprootWitnessSuccessSize = 4
// taprootRemotePreimageIndex is the index within the witness on the
// taproot remote commitment spend that'll hold the pre-image if the
// remote party sweeps it.
taprootRemotePreimageIndex = 2
)
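// For orientation, here is how these constants line up with a concrete
// (assumed) witness: a non-taproot remote-commitment success spend
// carries five elements, with the 32-byte preimage at index 3.
//
//	witness := wire.TxWitness{
//		nil,           // index 0: <0> CHECKMULTISIG dummy
//		senderSig,     // index 1
//		receiverSig,   // index 2
//		preimage,      // index 3: remotePreimageIndex
//		witnessScript, // index 4
//	}
//	// len(witness) == expectedRemoteWitnessSuccessSize (5)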
// claimCleanUp is a helper method that's called once the HTLC output is spent
// by the remote party. It'll extract the preimage, add it to the global cache,
// and finally send the appropriate clean up message.
func (h *htlcTimeoutResolver) claimCleanUp(
commitSpend *chainntnfs.SpendDetail) error {
// Depending on if this is our commitment or not, then we'll be looking
// for a different witness pattern.
spenderIndex := commitSpend.SpenderInputIndex
spendingInput := commitSpend.SpendingTx.TxIn[spenderIndex]
log.Infof("%T(%v): extracting preimage! remote party spent "+
"HTLC with tx=%v", h, h.htlcResolution.ClaimOutpoint,
spew.Sdump(commitSpend.SpendingTx))
// If this is the remote party's commitment, then we'll be looking for
// them to spend using the second-level success transaction.
var preimageBytes []byte
switch {
// For taproot channels, if the remote party has swept the HTLC, then
// the witness stack will look like:
//
// - <sender sig> <receiver sig> <preimage> <success_script>
// <control_block>
case h.isTaproot() && h.htlcResolution.SignedTimeoutTx == nil:
//nolint:ll
preimageBytes = spendingInput.Witness[taprootRemotePreimageIndex]
// The witness stack when the remote party sweeps the output on a
// regular channel to them looks like:
//
// - <0> <sender sig> <recvr sig> <preimage> <witness script>
case !h.isTaproot() && h.htlcResolution.SignedTimeoutTx == nil:
preimageBytes = spendingInput.Witness[remotePreimageIndex]
// If this is a taproot channel, and there's only a single witness
// element, then we're actually on the losing side of a breach
// attempt...
case h.isTaproot() && len(spendingInput.Witness) == 1:
return fmt.Errorf("breach attempt failed")
// Otherwise, they'll be spending directly from our commitment output.
// In which case the witness stack looks like:
//
// - <sig> <preimage> <witness script>
//
// For taproot channels, this looks like:
// - <receiver sig> <preimage> <success_script> <control_block>
//
// So we can target the same index.
default:
preimageBytes = spendingInput.Witness[localPreimageIndex]
}
preimage, err := lntypes.MakePreimage(preimageBytes)
if err != nil {
return fmt.Errorf("unable to create pre-image from witness: %w",
err)
}
log.Infof("%T(%v): extracting preimage=%v from on-chain "+
"spend!", h, h.htlcResolution.ClaimOutpoint, preimage)
// With the preimage obtained, we can now add it to the global cache.
if err := h.PreimageDB.AddPreimages(preimage); err != nil {
log.Errorf("%T(%v): unable to add witness to cache",
h, h.htlcResolution.ClaimOutpoint)
}
var pre [32]byte
copy(pre[:], preimage[:])
// Finally, we'll send the clean up message, mark ourselves as
// resolved, then exit.
if err := h.DeliverResolutionMsg(ResolutionMsg{
SourceChan: h.ShortChanID,
HtlcIndex: h.htlc.HtlcIndex,
PreImage: &pre,
}); err != nil {
return err
}
h.markResolved()
// Checkpoint our resolver with a report which reflects the preimage
// claim by the remote party.
amt := btcutil.Amount(h.htlcResolution.SweepSignDesc.Output.Value)
report := &channeldb.ResolverReport{
OutPoint: h.htlcResolution.ClaimOutpoint,
Amount: amt,
ResolverType: channeldb.ResolverTypeOutgoingHtlc,
ResolverOutcome: channeldb.ResolverOutcomeClaimed,
SpendTxID: commitSpend.SpenderTxHash,
}
return h.Checkpoint(h, report)
}
// chainDetailsToWatch returns the output and script which we use to watch for
// spends from the direct HTLC output on the commitment transaction.
func (h *htlcTimeoutResolver) chainDetailsToWatch() (*wire.OutPoint, []byte, error) {
// If there's no timeout transaction, it means we are spending from a
// remote commit, then the claim output is the output directly on the
// commitment transaction, so we'll just use that.
if h.htlcResolution.SignedTimeoutTx == nil {
outPointToWatch := h.htlcResolution.ClaimOutpoint
scriptToWatch := h.htlcResolution.SweepSignDesc.Output.PkScript
return &outPointToWatch, scriptToWatch, nil
}
// If SignedTimeoutTx is not nil, this is the local party's commitment,
// and we'll need to watch the output that our timeout transaction
// points to. We can directly grab the outpoint, then also extract the
// witness script (the last element of the witness stack) to
// re-construct the pkScript we need to watch.
//
//nolint:ll
outPointToWatch := h.htlcResolution.SignedTimeoutTx.TxIn[0].PreviousOutPoint
witness := h.htlcResolution.SignedTimeoutTx.TxIn[0].Witness
var (
scriptToWatch []byte
err error
)
switch {
// For taproot channels, the final witness element is the control
// block, and the one before it is the witness script. We can use both of
// these together to reconstruct the taproot output key, then map that
// into a v1 witness program.
case h.isTaproot():
// First, we'll parse the control block into something we can
// use.
ctrlBlockBytes := witness[len(witness)-1]
ctrlBlock, err := txscript.ParseControlBlock(ctrlBlockBytes)
if err != nil {
return nil, nil, err
}
// With the control block, we'll grab the witness script, then
// use that to derive the tapscript root.
witnessScript := witness[len(witness)-2]
tapscriptRoot := ctrlBlock.RootHash(witnessScript)
// Once we have the root, then we can derive the output key
// from the internal key, then turn that into a witness
// program.
outputKey := txscript.ComputeTaprootOutputKey(
ctrlBlock.InternalKey, tapscriptRoot,
)
scriptToWatch, err = txscript.PayToTaprootScript(outputKey)
if err != nil {
return nil, nil, err
}
// For regular channels, the witness script is the last element on the
// stack. We can then use this to re-derive the output that we're
// watching on chain.
default:
scriptToWatch, err = input.WitnessScriptHash(
witness[len(witness)-1],
)
}
if err != nil {
return nil, nil, err
}
return &outPointToWatch, scriptToWatch, nil
}
// isPreimageSpend returns true if the passed spend on the specified commitment
// is a success spend that reveals the pre-image.
func isPreimageSpend(isTaproot bool, spend *chainntnfs.SpendDetail,
localCommit bool) bool {
// Based on the spending input index and transaction, obtain the
// witness that tells us what type of spend this is.
spenderIndex := spend.SpenderInputIndex
spendingInput := spend.SpendingTx.TxIn[spenderIndex]
spendingWitness := spendingInput.Witness
switch {
// If this is a taproot remote commitment, then we can detect the type
// of spend via the leaf revealed in the control block and the witness
// itself.
//
// The keyspend (revocation path) is just a single signature, while the
// timeout and success paths are more distinct.
//
// The success path will look like:
//
// - <sender sig> <receiver sig> <preimage> <success_script>
// <control_block>
case isTaproot && !localCommit:
return checkSizeAndIndex(
spendingWitness, remoteTaprootWitnessSuccessSize,
taprootRemotePreimageIndex,
)
// Otherwise, if this is our local commitment transaction and they're
// sweeping it, the spend is made directly from the HTLC output,
// skipping the second level.
//
// In this case, there are two main tapscript paths, with the success
// case looking like:
//
// - <receiver sig> <preimage> <success_script> <control_block>
case isTaproot && localCommit:
return checkSizeAndIndex(
spendingWitness, localTaprootWitnessSuccessSize,
localPreimageIndex,
)
// If this is the non-taproot, remote commitment then the only possible
// spends for outgoing HTLCs are:
//
// RECVR: <0> <sender sig> <recvr sig> <preimage> (2nd level success spend)
// REVOK: <sig> <key>
// SENDR: <sig> 0
//
// In this case, if 5 witness elements are present (factoring the
// witness script), and the 3rd element is the size of the pre-image,
// then this is a remote spend. If not, then we swept it ourselves, or
// revoked their output.
case !isTaproot && !localCommit:
return checkSizeAndIndex(
spendingWitness, expectedRemoteWitnessSuccessSize,
remotePreimageIndex,
)
// Otherwise, for our non-taproot commitment, the only possible spends
// for an outgoing HTLC are:
//
// SENDR: <0> <sendr sig> <recvr sig> <0> (2nd level timeout)
// RECVR: <recvr sig> <preimage>
// REVOK: <revoke sig> <revoke key>
//
// So the only success case has the pre-image as the 2nd (index 1)
// element in the witness.
case !isTaproot:
fallthrough
default:
return checkSizeAndIndex(
spendingWitness, expectedLocalWitnessSuccessSize,
localPreimageIndex,
)
}
}
// checkSizeAndIndex checks that the witness is of the expected size and that
// the witness element at the specified index is of the expected size.
func checkSizeAndIndex(witness wire.TxWitness, size, index int) bool {
if len(witness) != size {
return false
}
return len(witness[index]) == lntypes.HashSize
}
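// A small usage sketch (fabricated values): classifying a non-taproot
// remote-commitment spend. Only a witness of the expected size whose
// preimage slot is exactly 32 bytes counts as a success spend.
//
//	w := wire.TxWitness{
//		nil, senderSig, recvrSig,
//		make([]byte, lntypes.HashSize), // candidate preimage slot
//		witnessScript,
//	}
//	ok := checkSizeAndIndex(
//		w, expectedRemoteWitnessSuccessSize, remotePreimageIndex,
//	)
//	// ok == true: the spender revealed a 32-byte preimage.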
// Resolve kicks off full resolution of an outgoing HTLC output. If it's our
// commitment, it isn't resolved until we see the second level HTLC txn
// confirmed. If it's the remote party's commitment, we don't resolve until we
// see a direct sweep via the timeout clause.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) {
// If we're already resolved, then we can exit early.
if h.IsResolved() {
h.log.Errorf("already resolved")
return nil, nil
}
// If this is an output on the remote party's commitment transaction,
// use the direct-spend path to sweep the htlc.
if h.isRemoteCommitOutput() {
return nil, h.resolveRemoteCommitOutput()
}
// If this is a zero-fee HTLC, we now handle the spend from our
// commitment transaction.
if h.isZeroFeeOutput() {
return nil, h.resolveTimeoutTx()
}
// If this is an output on our own commitment using pre-anchor channel
// type, we will let the utxo nursery handle it.
return nil, h.resolveSecondLevelTxLegacy()
}
// sweepTimeoutTx sends a second level timeout transaction to the sweeper.
// This transaction uses the SINGLE|ANYONECANPAY flag.
func (h *htlcTimeoutResolver) sweepTimeoutTx() error {
var inp input.Input
if h.isTaproot() {
inp = lnutils.Ptr(input.MakeHtlcSecondLevelTimeoutTaprootInput(
h.htlcResolution.SignedTimeoutTx,
h.htlcResolution.SignDetails,
h.broadcastHeight,
input.WithResolutionBlob(
h.htlcResolution.ResolutionBlob,
),
))
} else {
inp = lnutils.Ptr(input.MakeHtlcSecondLevelTimeoutAnchorInput(
h.htlcResolution.SignedTimeoutTx,
h.htlcResolution.SignDetails,
h.broadcastHeight,
))
}
// Calculate the budget.
budget := calculateBudget(
btcutil.Amount(inp.SignDesc().Output.Value),
h.Budget.DeadlineHTLCRatio, h.Budget.DeadlineHTLC,
)
h.log.Infof("offering 2nd-level HTLC timeout tx to sweeper "+
"with deadline=%v, budget=%v", h.incomingHTLCExpiryHeight,
budget)
// For an outgoing HTLC, it must be swept before the RefundTimeout of
// its incoming HTLC is reached.
_, err := h.Sweeper.SweepInput(
inp,
sweep.Params{
Budget: budget,
DeadlineHeight: h.incomingHTLCExpiryHeight,
},
)
if err != nil {
return err
}
return nil
}
// resolveSecondLevelTxLegacy sends a second level timeout transaction to the
// utxo nursery. This transaction uses the legacy SIGHASH_ALL flag.
func (h *htlcTimeoutResolver) resolveSecondLevelTxLegacy() error {
h.log.Debug("incubating htlc output")
// The utxo nursery will take care of broadcasting the second-level
// timeout tx and sweeping its output once it confirms.
err := h.IncubateOutputs(
h.ChanPoint, fn.Some(h.htlcResolution),
fn.None[lnwallet.IncomingHtlcResolution](),
h.broadcastHeight, h.incomingHTLCExpiryHeight,
)
if err != nil {
return err
}
return h.resolveTimeoutTx()
}
// sweepDirectHtlcOutput sends the direct spend of the HTLC output to the
// sweeper. This is used when the remote party goes on chain, and we're able to
// sweep an HTLC we offered after a timeout. Only the CLTV encumbered outputs
// are resolved via this path.
func (h *htlcTimeoutResolver) sweepDirectHtlcOutput() error {
var htlcWitnessType input.StandardWitnessType
if h.isTaproot() {
htlcWitnessType = input.TaprootHtlcOfferedRemoteTimeout
} else {
htlcWitnessType = input.HtlcOfferedRemoteTimeout
}
sweepInput := input.NewCsvInputWithCltv(
&h.htlcResolution.ClaimOutpoint, htlcWitnessType,
&h.htlcResolution.SweepSignDesc, h.broadcastHeight,
h.htlcResolution.CsvDelay, h.htlcResolution.Expiry,
input.WithResolutionBlob(h.htlcResolution.ResolutionBlob),
)
// Calculate the budget.
budget := calculateBudget(
btcutil.Amount(sweepInput.SignDesc().Output.Value),
h.Budget.DeadlineHTLCRatio, h.Budget.DeadlineHTLC,
)
log.Infof("%T(%x): offering offered remote timeout HTLC output to "+
"sweeper with deadline %v and budget=%v at height=%v",
h, h.htlc.RHash[:], h.incomingHTLCExpiryHeight, budget,
h.broadcastHeight)
_, err := h.Sweeper.SweepInput(
sweepInput,
sweep.Params{
Budget: budget,
// This is an outgoing HTLC, so we want to make sure
// that we sweep it before the incoming HTLC expires.
DeadlineHeight: h.incomingHTLCExpiryHeight,
},
)
if err != nil {
return err
}
return nil
}
// watchHtlcSpend watches for a spend of the HTLC output. For neutrino backend,
// it will check blocks for the confirmed spend. For btcd and bitcoind, it will
// check both the mempool and the blocks.
func (h *htlcTimeoutResolver) watchHtlcSpend() (*chainntnfs.SpendDetail,
error) {
// TODO(yy): outpointToWatch is always h.HtlcOutpoint(), can refactor
// to remove the redundancy.
outpointToWatch, scriptToWatch, err := h.chainDetailsToWatch()
if err != nil {
return nil, err
}
// If there's no mempool configured, which is the case for SPV node
// such as neutrino, then we will watch for confirmed spend only.
if h.Mempool == nil {
return h.waitForConfirmedSpend(outpointToWatch, scriptToWatch)
}
// Watch for a spend of the HTLC output in both the mempool and blocks.
return h.waitForMempoolOrBlockSpend(*outpointToWatch, scriptToWatch)
}
// waitForConfirmedSpend waits for the HTLC output to be spent and confirmed in
// a block, returns the spend details.
func (h *htlcTimeoutResolver) waitForConfirmedSpend(op *wire.OutPoint,
pkScript []byte) (*chainntnfs.SpendDetail, error) {
// We'll block here until either we exit, or the HTLC output on the
// commitment transaction has been spent.
spend, err := waitForSpend(
op, pkScript, h.broadcastHeight, h.Notifier, h.quit,
)
if err != nil {
return nil, err
}
return spend, nil
}
// Stop signals the resolver to cancel any current resolution processes, and
// suspend.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcTimeoutResolver) Stop() {
h.log.Debugf("stopping...")
defer h.log.Debugf("stopped")
close(h.quit)
}
// report returns a report on the resolution state of the contract.
func (h *htlcTimeoutResolver) report() *ContractReport {
// If we have a SignedTimeoutTx but no SignDetails, this is a local
// commitment for a non-anchor channel, which was handled by the utxo
// nursery.
if h.htlcResolution.SignDetails == nil &&
h.htlcResolution.SignedTimeoutTx != nil {
return nil
}
h.reportLock.Lock()
defer h.reportLock.Unlock()
cpy := h.currentReport
return &cpy
}
func (h *htlcTimeoutResolver) initReport() {
// We create the initial report. This will only be reported for
// resolvers not handled by the nursery.
finalAmt := h.htlc.Amt.ToSatoshis()
if h.htlcResolution.SignedTimeoutTx != nil {
finalAmt = btcutil.Amount(
h.htlcResolution.SignedTimeoutTx.TxOut[0].Value,
)
}
// If there's no timeout transaction, then we're already effectively in
// level two.
stage := uint32(1)
if h.htlcResolution.SignedTimeoutTx == nil {
stage = 2
}
h.currentReport = ContractReport{
Outpoint: h.htlcResolution.ClaimOutpoint,
Type: ReportOutputOutgoingHtlc,
Amount: finalAmt,
MaturityHeight: h.htlcResolution.Expiry,
LimboBalance: finalAmt,
Stage: stage,
}
}
// Encode writes an encoded version of the ContractResolver into the passed
// Writer.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcTimeoutResolver) Encode(w io.Writer) error {
// First, we'll write out the relevant fields of the
// OutgoingHtlcResolution to the writer.
if err := encodeOutgoingResolution(w, &h.htlcResolution); err != nil {
return err
}
// With that portion written, we can now write out the fields specific
// to the resolver itself.
if err := binary.Write(w, endian, h.outputIncubating); err != nil {
return err
}
if err := binary.Write(w, endian, h.IsResolved()); err != nil {
return err
}
if err := binary.Write(w, endian, h.broadcastHeight); err != nil {
return err
}
if err := binary.Write(w, endian, h.htlc.HtlcIndex); err != nil {
return err
}
// We encode the sign details last for backwards compatibility.
err := encodeSignDetails(w, h.htlcResolution.SignDetails)
if err != nil {
return err
}
return nil
}
// newTimeoutResolverFromReader attempts to decode an encoded ContractResolver
// from the passed Reader instance, returning an active ContractResolver
// instance.
func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) (
*htlcTimeoutResolver, error) {
h := &htlcTimeoutResolver{
contractResolverKit: *newContractResolverKit(resCfg),
}
// First, we'll read out all the mandatory fields of the
// OutgoingHtlcResolution that we store.
if err := decodeOutgoingResolution(r, &h.htlcResolution); err != nil {
return nil, err
}
// With those fields read, we can now read back the fields that are
// specific to the resolver itself.
if err := binary.Read(r, endian, &h.outputIncubating); err != nil {
return nil, err
}
var resolved bool
if err := binary.Read(r, endian, &resolved); err != nil {
return nil, err
}
if resolved {
h.markResolved()
}
if err := binary.Read(r, endian, &h.broadcastHeight); err != nil {
return nil, err
}
if err := binary.Read(r, endian, &h.htlc.HtlcIndex); err != nil {
return nil, err
}
// Sign details is a new field that was added to the htlc resolution,
// so it is serialized last for backwards compatibility. We try to read
// it, but don't error out if there are no bytes left.
signDetails, err := decodeSignDetails(r)
if err == nil {
h.htlcResolution.SignDetails = signDetails
} else if err != io.EOF && err != io.ErrUnexpectedEOF {
return nil, err
}
h.initReport()
h.initLogger(fmt.Sprintf("%T(%v)", h, h.outpoint()))
return h, nil
}
// Supplement adds additional information to the resolver that is required
// before Resolve() is called.
//
// NOTE: Part of the htlcContractResolver interface.
func (h *htlcTimeoutResolver) Supplement(htlc channeldb.HTLC) {
h.htlc = htlc
}
// HtlcPoint returns the htlc's outpoint on the commitment tx.
//
// NOTE: Part of the htlcContractResolver interface.
func (h *htlcTimeoutResolver) HtlcPoint() wire.OutPoint {
return h.htlcResolution.HtlcPoint()
}
// SupplementDeadline sets the incomingHTLCExpiryHeight for this outgoing htlc
// resolver.
//
// NOTE: Part of the htlcContractResolver interface.
func (h *htlcTimeoutResolver) SupplementDeadline(d fn.Option[int32]) {
h.incomingHTLCExpiryHeight = d
}
// A compile time assertion to ensure htlcTimeoutResolver meets the
// ContractResolver interface.
var _ htlcContractResolver = (*htlcTimeoutResolver)(nil)
// spendResult is used to hold the result of a spend event from either a
// mempool spend or a block spend.
type spendResult struct {
// spend contains the details of the spend.
spend *chainntnfs.SpendDetail
// err is the error that occurred during the spend notification.
err error
}
// waitForMempoolOrBlockSpend waits for the htlc output to be spent by a
// transaction that's found either in the mempool or in a block.
func (h *htlcTimeoutResolver) waitForMempoolOrBlockSpend(op wire.OutPoint,
pkScript []byte) (*chainntnfs.SpendDetail, error) {
log.Infof("%T(%v): waiting for spent of HTLC output %v to be found "+
"in mempool or block", h, h.htlcResolution.ClaimOutpoint, op)
// Subscribe for the block spend (confirmed).
blockSpent, err := h.Notifier.RegisterSpendNtfn(
&op, pkScript, h.broadcastHeight,
)
if err != nil {
return nil, fmt.Errorf("register spend: %w", err)
}
// Subscribe for the mempool spend (unconfirmed).
mempoolSpent, err := h.Mempool.SubscribeMempoolSpent(op)
if err != nil {
return nil, fmt.Errorf("register mempool spend: %w", err)
}
// Create a result chan that will be used to receive the spending
// events.
result := make(chan *spendResult, 2)
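// NOTE: a buffer of two lets the forwarding goroutine deliver both a
// mempool event and a later block event without blocking, even after we
// return early having read only the first result.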
// Create a goroutine that will wait for either a mempool spend or a
// block spend.
//
// NOTE: no need to use waitgroup here as when the resolver exits, the
// goroutine will return on the quit channel.
go h.consumeSpendEvents(result, blockSpent.Spend, mempoolSpent.Spend)
// Wait for the spend event to be received.
select {
case event := <-result:
// Cancel the mempool subscription as we don't need it anymore.
h.Mempool.CancelMempoolSpendEvent(mempoolSpent)
return event.spend, event.err
case <-h.quit:
return nil, errResolverShuttingDown
}
}
// consumeSpendEvents consumes the spend events from the block and mempool
// subscriptions. It exits when a spend event is received from the block, or
// the resolver itself quits. When a spend event is received from the mempool,
// however, it won't exit but continues to wait for a spend event from the
// block subscription.
//
// NOTE: there could be a case where we found the preimage in the mempool,
// which will be added to our preimage beacon and settle the incoming link,
// meanwhile the timeout sweep tx confirms. This outgoing HTLC is "free" money
// and is not swept here.
//
// TODO(yy): sweep the outgoing htlc if it's confirmed.
func (h *htlcTimeoutResolver) consumeSpendEvents(resultChan chan *spendResult,
blockSpent, mempoolSpent <-chan *chainntnfs.SpendDetail) {
op := h.HtlcPoint()
// Create a result chan to hold the results.
result := &spendResult{}
// Wait for a spend event to arrive.
for {
select {
// If a spend event is received from the block, this outgoing
// htlc is spent either by the remote via the preimage or by us
// via the timeout. We can exit the loop and `claimCleanUp`
// will feed the preimage to the beacon if found. This treats
// the block as the final judge and the preimage spent won't
// appear in the mempool afterwards.
//
// NOTE: if a reorg happens, the preimage spend can appear in
// the mempool again. Though a rare case, we should handle it
// in a dedicated reorg system.
case spendDetail, ok := <-blockSpent:
if !ok {
result.err = fmt.Errorf("block spent err: %w",
errResolverShuttingDown)
} else {
log.Debugf("Found confirmed spend of HTLC "+
"output %s in tx=%s", op,
spendDetail.SpenderTxHash)
result.spend = spendDetail
// Once confirmed, persist the state on disk if
// we haven't seen the output's spending tx in
// mempool before.
}
// Send the result and exit the loop.
resultChan <- result
return
// If a spend event is received from the mempool, this can be
// either the 2nd stage timeout tx or a preimage spend from the
// remote. We will further check whether the spend reveals the
// preimage and add it to the preimage beacon to settle the
// incoming link.
//
// NOTE: we won't exit the loop here so we can continue to
// watch for the block spend to checkpoint the resolution.
case spendDetail, ok := <-mempoolSpent:
if !ok {
result.err = fmt.Errorf("mempool spent err: %w",
errResolverShuttingDown)
// This is an internal error so we exit.
resultChan <- result
return
}
log.Debugf("Found mempool spend of HTLC output %s "+
"in tx=%s", op, spendDetail.SpenderTxHash)
// Check whether the spend reveals the preimage; if not,
// continue the loop.
hasPreimage := isPreimageSpend(
h.isTaproot(), spendDetail,
!h.isRemoteCommitOutput(),
)
if !hasPreimage {
log.Debugf("HTLC output %s spent doesn't "+
"reveal preimage", op)
continue
}
// Found the preimage spend, send the result and
// continue the loop.
result.spend = spendDetail
resultChan <- result
continue
// If the resolver exits, we exit the goroutine.
case <-h.quit:
result.err = errResolverShuttingDown
resultChan <- result
return
}
}
}
// isRemoteCommitOutput returns a bool to indicate whether the htlc output is
// on the remote commitment.
func (h *htlcTimeoutResolver) isRemoteCommitOutput() bool {
// If we don't have a timeout transaction, then this means that this is
// an output on the remote party's commitment transaction.
return h.htlcResolution.SignedTimeoutTx == nil
}
// isZeroFeeOutput returns a boolean indicating whether the htlc output is from
// an anchor-enabled channel, which uses the sighash SINGLE|ANYONECANPAY.
func (h *htlcTimeoutResolver) isZeroFeeOutput() bool {
// If we have non-nil SignDetails, this means it has a 2nd level HTLC
// transaction that is signed using sighash SINGLE|ANYONECANPAY (the
// case for anchor type channels). In this case we can re-sign it and
// attach fees at will.
return h.htlcResolution.SignedTimeoutTx != nil &&
h.htlcResolution.SignDetails != nil
}
// waitHtlcSpendAndCheckPreimage waits for the htlc output to be spent and
// checks whether the spending reveals the preimage. If the preimage is found,
// it will be added to the preimage beacon to settle the incoming link, and a
// nil spend details will be returned. Otherwise, the spend details will be
// returned, indicating this is a non-preimage spend.
func (h *htlcTimeoutResolver) waitHtlcSpendAndCheckPreimage() (
*chainntnfs.SpendDetail, error) {
// Wait for the htlc output to be spent, which can happen via one of the
// following paths:
// 1. The remote party spends the htlc output using the preimage.
// 2. The local party spends the htlc timeout tx from the local
// commitment.
// 3. The local party spends the htlc output directly from the remote
// commitment.
spend, err := h.watchHtlcSpend()
if err != nil {
return nil, err
}
// If the spend reveals the pre-image, then we'll enter the clean up
// workflow to pass the preimage back to the incoming link, add it to
// the witness cache, and exit.
if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) {
return nil, h.claimCleanUp(spend)
}
return spend, nil
}
// sweepTimeoutTxOutput attempts to sweep the output of the second level
// timeout tx.
func (h *htlcTimeoutResolver) sweepTimeoutTxOutput() error {
h.log.Debugf("sweeping output %v from 2nd-level HTLC timeout tx",
h.htlcResolution.ClaimOutpoint)
// This should be non-blocking as we will only attempt to sweep the
// output when the second level tx has already been confirmed. In other
// words, waitHtlcSpendAndCheckPreimage will return immediately.
commitSpend, err := h.waitHtlcSpendAndCheckPreimage()
if err != nil {
return err
}
// Exit early if the spend is nil, as this means it's a remote spend
// using the preimage path, which is handled in claimCleanUp.
if commitSpend == nil {
h.log.Infof("preimage spend detected, skipping 2nd-level " +
"HTLC output sweep")
return nil
}
waitHeight := h.deriveWaitHeight(h.htlcResolution.CsvDelay, commitSpend)
// Now that the sweeper has broadcasted the second-level transaction,
// it has confirmed, and we have checkpointed our state, we'll sweep
// the second level output. We report that the resolver has moved to the
// next stage.
h.reportLock.Lock()
h.currentReport.Stage = 2
h.currentReport.MaturityHeight = waitHeight
h.reportLock.Unlock()
if h.hasCLTV() {
h.log.Infof("waiting for CSV and CLTV lock to expire at "+
"height %v", waitHeight)
} else {
h.log.Infof("waiting for CSV lock to expire at height %v",
waitHeight)
}
// We'll use this input index to determine the second-level output
// index on the transaction, as the signature requires the indexes to
// be the same. We don't look for the second-level output script
// directly, as there might be more than one HTLC output to the same
// pkScript.
op := &wire.OutPoint{
Hash: *commitSpend.SpenderTxHash,
Index: commitSpend.SpenderInputIndex,
}
var witType input.StandardWitnessType
if h.isTaproot() {
witType = input.TaprootHtlcOfferedTimeoutSecondLevel
} else {
witType = input.HtlcOfferedTimeoutSecondLevel
}
// Let the sweeper sweep the second-level output now that the CSV/CLTV
// locks have expired.
inp := h.makeSweepInput(
op, witType,
input.LeaseHtlcOfferedTimeoutSecondLevel,
&h.htlcResolution.SweepSignDesc,
h.htlcResolution.CsvDelay, uint32(commitSpend.SpendingHeight),
h.htlc.RHash, h.htlcResolution.ResolutionBlob,
)
// Calculate the budget for this sweep.
budget := calculateBudget(
btcutil.Amount(inp.SignDesc().Output.Value),
h.Budget.NoDeadlineHTLCRatio,
h.Budget.NoDeadlineHTLC,
)
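// As a sketch of the budget semantics assumed above: the configured ratio
// scales the output value unless an absolute budget override is set. This is
// an illustrative reimplementation, not the package's calculateBudget:
//
//	func calculateBudgetSketch(value btcutil.Amount, ratio float64,
//		override btcutil.Amount) btcutil.Amount {
//
//		// An explicit budget takes precedence over the ratio.
//		if override != 0 {
//			return override
//		}
//		return value.MulF64(ratio)
//	}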
h.log.Infof("offering output from 2nd-level timeout tx to sweeper "+
"with no deadline and budget=%v", budget)
// TODO(yy): use the result chan returned from SweepInput to get the
// confirmation status of this sweeping tx so we don't need to make
// another subscription via `RegisterSpendNtfn` for this outpoint here
// in the resolver.
_, err = h.Sweeper.SweepInput(
inp,
sweep.Params{
Budget: budget,
// For the output from the second level
// timeout tx, there's no rush to get it
// confirmed, so we use a nil deadline.
DeadlineHeight: fn.None[int32](),
},
)
return err
}
// checkpointStageOne creates a checkpoint for the first stage of the htlc
// timeout transaction. This is used to ensure that the resolver can resume
// watching for the second stage spend in case of a restart.
func (h *htlcTimeoutResolver) checkpointStageOne(
spendTxid chainhash.Hash) error {
h.log.Debugf("checkpoint stage one spend of HTLC output %v, spent "+
"in tx %v", h.outpoint(), spendTxid)
// Now that the second-level transaction has confirmed, we checkpoint
// the state so we'll go to the next stage in case of restarts.
h.outputIncubating = true
// Create stage-one report.
report := &channeldb.ResolverReport{
OutPoint: h.outpoint(),
Amount: h.htlc.Amt.ToSatoshis(),
ResolverType: channeldb.ResolverTypeOutgoingHtlc,
ResolverOutcome: channeldb.ResolverOutcomeFirstStage,
SpendTxID: &spendTxid,
}
// At this point, the second-level transaction is sufficiently
// confirmed. We can now send back our clean up message, failing the
// HTLC on the incoming link.
failureMsg := &lnwire.FailPermanentChannelFailure{}
err := h.DeliverResolutionMsg(ResolutionMsg{
SourceChan: h.ShortChanID,
HtlcIndex: h.htlc.HtlcIndex,
Failure: failureMsg,
})
if err != nil {
return err
}
return h.Checkpoint(h, report)
}
// checkpointClaim checkpoints the timeout resolver with the reports it needs.
func (h *htlcTimeoutResolver) checkpointClaim(
spendDetail *chainntnfs.SpendDetail) error {
h.log.Infof("resolving htlc with incoming fail msg, output=%v "+
"confirmed in tx=%v", spendDetail.SpentOutPoint,
spendDetail.SpenderTxHash)
// Create a resolver report for the claiming of the HTLC.
amt := btcutil.Amount(h.htlcResolution.SweepSignDesc.Output.Value)
report := &channeldb.ResolverReport{
OutPoint: *spendDetail.SpentOutPoint,
Amount: amt,
ResolverType: channeldb.ResolverTypeOutgoingHtlc,
ResolverOutcome: channeldb.ResolverOutcomeTimeout,
SpendTxID: spendDetail.SpenderTxHash,
}
// Finally, we checkpoint the resolver with our report(s).
h.markResolved()
return h.Checkpoint(h, report)
}
// resolveRemoteCommitOutput handles sweeping an HTLC output on the remote
// commitment via the timeout path. In this case we can sweep the output
// directly, and don't have to broadcast a second-level transaction.
func (h *htlcTimeoutResolver) resolveRemoteCommitOutput() error {
h.log.Debug("waiting for direct-timeout spend of the htlc to confirm")
// Wait for the direct-timeout HTLC sweep tx to confirm.
spend, err := h.watchHtlcSpend()
if err != nil {
return err
}
// If the spend reveals the preimage, then we'll enter the clean up
// workflow to pass the preimage back to the incoming link, add it to
// the witness cache, and exit.
if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) {
return h.claimCleanUp(spend)
}
// Send the clean up msg to fail the incoming HTLC.
failureMsg := &lnwire.FailPermanentChannelFailure{}
err = h.DeliverResolutionMsg(ResolutionMsg{
SourceChan: h.ShortChanID,
HtlcIndex: h.htlc.HtlcIndex,
Failure: failureMsg,
})
if err != nil {
return err
}
// TODO(yy): should also update the `RecoveredBalance` and
// `LimboBalance` like other paths?
// Checkpoint the resolver, and write the outcome to disk.
return h.checkpointClaim(spend)
}
// resolveTimeoutTx waits for the second-level timeout tx to confirm, offers
// the output from the timeout tx to the sweeper, and then waits for that
// output to be swept.
func (h *htlcTimeoutResolver) resolveTimeoutTx() error {
h.log.Debug("waiting for first-stage 2nd-level HTLC timeout tx to " +
"confirm")
// Wait for the second level transaction to confirm.
spend, err := h.watchHtlcSpend()
if err != nil {
return err
}
// If the spend reveals the preimage, then we'll enter the clean up
// workflow to pass the preimage back to the incoming link, add it to
// the witness cache, and exit.
if isPreimageSpend(h.isTaproot(), spend, !h.isRemoteCommitOutput()) {
return h.claimCleanUp(spend)
}
op := h.htlcResolution.ClaimOutpoint
spenderTxid := *spend.SpenderTxHash
// If the timeout tx is a re-signed tx, we will need to find the actual
// spent outpoint from the spending tx.
if h.isZeroFeeOutput() {
op = wire.OutPoint{
Hash: spenderTxid,
Index: spend.SpenderInputIndex,
}
}
// If the 2nd-stage sweeping has already been started, we can
// fast-forward to start the resolving process for the stage two
// output.
if h.outputIncubating {
return h.resolveTimeoutTxOutput(op)
}
h.log.Infof("2nd-level HTLC timeout tx=%v confirmed", spenderTxid)
// Start the process to sweep the output from the timeout tx.
if h.isZeroFeeOutput() {
err = h.sweepTimeoutTxOutput()
if err != nil {
return err
}
}
// Create a checkpoint since the timeout tx is confirmed and the sweep
// request has been made.
if err := h.checkpointStageOne(spenderTxid); err != nil {
return err
}
// Start the resolving process for the stage two output.
return h.resolveTimeoutTxOutput(op)
}
// resolveTimeoutTxOutput waits for the spend of the output from the 2nd-level
// timeout tx.
func (h *htlcTimeoutResolver) resolveTimeoutTxOutput(op wire.OutPoint) error {
h.log.Debugf("waiting for second-stage 2nd-level timeout tx output %v "+
"to be spent after csv_delay=%v", op, h.htlcResolution.CsvDelay)
spend, err := waitForSpend(
&op, h.htlcResolution.SweepSignDesc.Output.PkScript,
h.broadcastHeight, h.Notifier, h.quit,
)
if err != nil {
return err
}
h.reportLock.Lock()
h.currentReport.RecoveredBalance = h.currentReport.LimboBalance
h.currentReport.LimboBalance = 0
h.reportLock.Unlock()
return h.checkpointClaim(spend)
}
// Launch creates an input based on the details of the outgoing htlc resolution
// and offers it to the sweeper.
func (h *htlcTimeoutResolver) Launch() error {
if h.isLaunched() {
h.log.Tracef("already launched")
return nil
}
h.log.Debugf("launching resolver...")
h.markLaunched()
switch {
// If we're already resolved, then we can exit early.
case h.IsResolved():
h.log.Errorf("already resolved")
return nil
// If this is an output on the remote party's commitment transaction,
// use the direct timeout spend path.
//
// NOTE: When outputIncubating is true, it means that the output has
// already been offered to the utxo nursery; starting in 0.18.4, we
// stopped marking this flag for direct timeout spends (#9062). In that
// case, we will do nothing and let the utxo nursery handle it.
case h.isRemoteCommitOutput() && !h.outputIncubating:
return h.sweepDirectHtlcOutput()
// If this is an anchor type channel, we now sweep either the
// second-level timeout tx or the output from the second-level timeout
// tx.
case h.isZeroFeeOutput():
// If the second-level timeout tx has already been swept, we
// can go ahead and sweep its output.
if h.outputIncubating {
return h.sweepTimeoutTxOutput()
}
// Otherwise, sweep the second level tx.
return h.sweepTimeoutTx()
// If this is an output on our own commitment using a pre-anchor channel
// type, we will let the utxo nursery handle it via Resolve.
//
// TODO(yy): handle the legacy output by offering it to the sweeper.
default:
return nil
}
}
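// To summarize the dispatch in Launch above as a reference sketch:
//
//	remote commit output, not incubating -> sweepDirectHtlcOutput
//	anchor (zero-fee) output, incubating -> sweepTimeoutTxOutput
//	anchor (zero-fee) output, otherwise  -> sweepTimeoutTx
//	pre-anchor local commit output       -> handled by the utxo nursery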
package contractcourt
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
var (
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the
// caller requests it.
log btclog.Logger
// brarLog is the logger used by the breach arbiter.
brarLog btclog.Logger
// utxnLog is the logger used by the utxo nursery.
utxnLog btclog.Logger
)
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("CNCT", nil))
UseBreachLogger(build.NewSubLogger("BRAR", nil))
UseNurseryLogger(build.NewSubLogger("UTXN", nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
// UseBreachLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseBreachLogger(logger btclog.Logger) {
brarLog = logger
}
// UseNurseryLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseNurseryLogger(logger btclog.Logger) {
utxnLog = logger
}
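// As a usage sketch, a host application can install real sub-loggers instead
// of the nil generators used in init above (genLogger is a hypothetical
// func(string) btclog.Logger supplied by the application):
//
//	UseLogger(build.NewSubLogger("CNCT", genLogger))
//	UseBreachLogger(build.NewSubLogger("BRAR", genLogger))
//	UseNurseryLogger(build.NewSubLogger("UTXN", genLogger))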
package contractcourt
import (
"bytes"
"errors"
"fmt"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/kvdb"
)
// Overview of Nursery Store Storage Hierarchy
//
// CHAIN SEGMENTATION
//
// The root directory of a nursery store is bucketed by the chain hash and
// the 'utxn' prefix. This allows multiple utxo nurseries for distinct chains
// to simultaneously use the same channeldb.DB instance. This is critical for
// providing replay protection and for isolating chain-specific data in a
// multichain setting.
//
// utxn<chain-hash>/
// |
// | CHANNEL INDEX
// |
// | The channel index contains a directory for each channel that has a
// | non-zero number of outputs being tracked by the nursery store.
// | Inside each channel directory are files containing serialized spendable
// | outputs that are awaiting some state transition. The name of each file
// | contains the outpoint of the spendable output in the file, and is
// | prefixed with a 4-byte state prefix, indicating whether the spendable
// | output is a crib, preschool, kindergarten, or graduated output. The
// | nursery store supports the ability to enumerate all outputs for a
// | particular channel, which is useful in constructing nursery reports.
// |
// ├── channel-index-key/
// │ ├── <chan-point-1>/ <- CHANNEL BUCKET
// | | ├── <state-prefix><outpoint-1>: <spendable-output-1>
// | | └── <state-prefix><outpoint-2>: <spendable-output-2>
// │ ├── <chan-point-2>/
// | | └── <state-prefix><outpoint-3>: <spendable-output-3>
// │ └── <chan-point-3>/
// | ├── <state-prefix><outpoint-4>: <spendable-output-4>
// | └── <state-prefix><outpoint-5>: <spendable-output-5>
// |
// | HEIGHT INDEX
// |
// | The height index contains a directory for each height at which the
// | nursery still has scheduled actions. If an output is a crib or
// | kindergarten output, it will have an associated entry in the height
// | index. Inside a particular height directory, the structure is similar
// | to that of the channel index, containing multiple channel directories,
// | each of which contains subdirectories named with a prefixed outpoint
// | belonging to the channel. Enumerating these combinations yields a
// | relative file path:
// | e.g. <chan-point-3>/<prefix><outpoint-2>/
// | that can be queried in the channel index to retrieve the serialized
// | output.
// |
// └── height-index-key/
// ├── <height-1>/ <- HEIGHT BUCKET
// | ├── <chan-point-3>/ <- HEIGHT-CHANNEL BUCKET
// | | ├── <state-prefix><outpoint-4>: "" <- PREFIXED OUTPOINT
// | | └── <state-prefix><outpoint-5>: ""
// | ├── <chan-point-2>/
// | | └── <state-prefix><outpoint-3>: ""
// └── <height-2>/
// └── <chan-point-1>/
// └── <state-prefix><outpoint-1>: ""
// └── <state-prefix><outpoint-2>: ""
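// In short, the height index stores only prefixed-outpoint keys; to load an
// actual output, the same <state-prefix><outpoint> key is looked up in the
// matching channel bucket of the channel index (see forEachHeightPrefix
// below for the concrete traversal).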
// TODO(joostjager): Add database migration to clean up now unused last
// graduated height and finalized txes. This also prevents people downgrading
// and surprising the downgraded nursery with missing data.
// NurseryStorer abstracts the persistent storage layer for the utxo nursery.
// Concretely, it stores commitment and htlc outputs until any time-bounded
// constraints have fully matured. The store exposes methods for enumerating its
// contents, and persisting state transitions detected by the utxo nursery.
type NurseryStorer interface {
// Incubate registers a set of CSV delayed outputs (incoming HTLC's on
// our commitment transaction, or a commitment output), and a slice of
// outgoing htlc outputs to be swept back into the user's wallet. The
// event is persisted to disk, such that the nursery can resume the
// incubation process after a potential crash.
Incubate([]kidOutput, []babyOutput) error
// CribToKinder atomically moves a babyOutput in the crib bucket to the
// kindergarten bucket. Baby outputs are outgoing HTLC's which require
// us to go to the second-layer to claim. The now mature kidOutput
// contained in the babyOutput will be stored as it waits out the
// kidOutput's CSV delay.
CribToKinder(*babyOutput) error
// PreschoolToKinder atomically moves a kidOutput from the preschool
// bucket to the kindergarten bucket. This transition should be executed
// after receiving confirmation of the preschool output. Incoming HTLC's
// that we need to go to the second layer to claim, as well as our
// commitment outputs, fall into this class.
//
// An additional parameter specifies the last graduated height. This is
// used in case of late registration. It schedules the output for sweep
// at the next epoch even though it has already expired earlier.
PreschoolToKinder(kid *kidOutput, lastGradHeight uint32) error
// GraduateKinder atomically moves an output at the provided height into
// the graduated status. This involves removing the kindergarten entries
// from both the height and channel indexes. The height bucket will be
// opportunistically pruned from the height index as outputs are
// removed.
GraduateKinder(height uint32, output *kidOutput) error
// FetchPreschools returns a list of all outputs currently stored in
// the preschool bucket.
FetchPreschools() ([]kidOutput, error)
// FetchClass returns a list of kindergarten and crib outputs whose
// timelocks expire at the given height.
FetchClass(height uint32) ([]kidOutput, []babyOutput, error)
// HeightsBelowOrEqual returns all non-empty heights in the height
// index that exist at or below the provided upper bound.
HeightsBelowOrEqual(height uint32) ([]uint32, error)
// ForChanOutputs iterates over all outputs being incubated for a
// particular channel point. This method accepts a callback that allows
// the caller to process each key-value pair. The key will be a prefixed
// outpoint, and the value will be the serialized bytes for an output,
// whose type should be inferred from the key's prefix.
ForChanOutputs(*wire.OutPoint, func([]byte, []byte) error, func()) error
// ListChannels returns all channels the nursery is currently tracking.
ListChannels() ([]wire.OutPoint, error)
// IsMatureChannel determines whether or not all of the outputs in a
// particular channel bucket have been marked as graduated.
IsMatureChannel(*wire.OutPoint) (bool, error)
// RemoveChannel erases all entries from the channel bucket for
// the provided channel point. This method should only be called if
// IsMatureChannel indicates the channel is ready for removal.
RemoveChannel(*wire.OutPoint) error
}
var (
// utxnChainPrefix is used to prefix a particular chain hash and create
// the root-level, chain-segmented bucket for each nursery store.
utxnChainPrefix = []byte("utxn")
// channelIndexKey is a static key used to lookup the bucket containing
// all of the nursery's active channels.
channelIndexKey = []byte("channel-index")
// heightIndexKey is a static key used to retrieve a directory
// containing all heights for which the nursery will need to take
// action.
heightIndexKey = []byte("height-index")
)
// Defines the state prefixes that will be used to persistently track an
// output's progress through the nursery.
// NOTE: Each state prefix MUST be exactly 4 bytes in length, the nursery logic
// depends on the ability to create keys for a different state by overwriting
// an existing state prefix.
var (
// cribPrefix is the state prefix given to htlc outputs waiting for
// their first-stage, absolute locktime to elapse.
cribPrefix = []byte("crib")
// psclPrefix is the state prefix given to commitment outputs awaiting
// the confirmation of the commitment transaction, as this solidifies
// the absolute height at which they can be spent.
psclPrefix = []byte("pscl")
// kndrPrefix is the state prefix given to all CSV delayed outputs,
// either from the commitment transaction, or a stage-one htlc
// transaction, whose maturity height has solidified. Outputs marked in
// this state are in their final stage of incubation within the nursery,
// and will be swept into the wallet after waiting out the relative
// timelock.
kndrPrefix = []byte("kndr")
// gradPrefix is the state prefix given to all outputs that have been
// completely incubated. Once all outputs have been marked as graduated,
// this serves as a persistent marker that the nursery should mark the
// channel fully closed in the channeldb.
gradPrefix = []byte("grad")
)
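// As a reference sketch, outputs move through the prefixes above via the
// transitions implemented below:
//
//	crib ─CribToKinder──────> kndr
//	pscl ─PreschoolToKinder─> kndr
//	kndr ─GraduateKinder────> grad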
// prefixChainKey creates the root level keys for the nursery store. The keys
// are comprised of a nursery-specific prefix and the intended chain hash that
// this nursery store will be used for. This allows multiple nursery stores to
// isolate their state when operating on multiple chains or forks.
func prefixChainKey(sysPrefix []byte, hash *chainhash.Hash) ([]byte, error) {
// Create a buffer to which we will write the system prefix, e.g.
// "utxn", followed by the provided chain hash.
var pfxChainBuffer bytes.Buffer
if _, err := pfxChainBuffer.Write(sysPrefix); err != nil {
return nil, err
}
if _, err := pfxChainBuffer.Write(hash[:]); err != nil {
return nil, err
}
return pfxChainBuffer.Bytes(), nil
}
// prefixOutputKey creates a serialized key that prefixes the serialized
// outpoint with the provided state prefix. The returned bytes will be of the
// form <prefix><outpoint>.
func prefixOutputKey(statePrefix []byte,
outpoint wire.OutPoint) ([]byte, error) {
// Create a buffer to which we will first write the state prefix,
// followed by the outpoint.
var pfxOutputBuffer bytes.Buffer
if _, err := pfxOutputBuffer.Write(statePrefix); err != nil {
return nil, err
}
err := graphdb.WriteOutpoint(&pfxOutputBuffer, &outpoint)
if err != nil {
return nil, err
}
return pfxOutputBuffer.Bytes(), nil
}
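// For example, a crib output's key would be built as in this sketch (op is a
// hypothetical wire.OutPoint):
//
//	key, err := prefixOutputKey(cribPrefix, op)
//	// key = "crib" || <serialized outpoint>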
// NurseryStore is a concrete instantiation of a NurseryStore that is backed by
// a channeldb.DB instance.
type NurseryStore struct {
chainHash chainhash.Hash
db *channeldb.DB
pfxChainKey []byte
}
// NewNurseryStore accepts a chain hash and a channeldb.DB instance, returning
// an instance of NurseryStore whose database is properly segmented for the
// given chain.
func NewNurseryStore(chainHash *chainhash.Hash,
db *channeldb.DB) (*NurseryStore, error) {
// Prefix the provided chain hash with "utxn" to create the key for the
// nursery store's root bucket, ensuring each one has proper chain
// segmentation.
pfxChainKey, err := prefixChainKey(utxnChainPrefix, chainHash)
if err != nil {
return nil, err
}
return &NurseryStore{
chainHash: *chainHash,
db: db,
pfxChainKey: pfxChainKey,
}, nil
}
// Incubate persists the beginning of the incubation process for the
// CSV-delayed outputs (commitment and incoming HTLC's) and a list of
// outgoing two-stage htlc outputs.
func (ns *NurseryStore) Incubate(kids []kidOutput, babies []babyOutput) error {
return kvdb.Update(ns.db, func(tx kvdb.RwTx) error {
// If we have any kid outputs to incubate, then we'll attempt
// to add each of them to the nursery store. Any duplicate
// outputs will be ignored.
for _, kid := range kids {
if err := ns.enterPreschool(tx, &kid); err != nil {
return err
}
}
// Next, we'll add all htlc outputs to the crib bucket.
// Similarly, we'll ignore any outputs that have already been
// inserted.
for _, baby := range babies {
if err := ns.enterCrib(tx, &baby); err != nil {
return err
}
}
return nil
}, func() {})
}
// CribToKinder atomically moves a babyOutput in the crib bucket to the
// kindergarten bucket. The now mature kidOutput contained in the babyOutput
// will be stored as it waits out the kidOutput's CSV delay.
func (ns *NurseryStore) CribToKinder(bby *babyOutput) error {
return kvdb.Update(ns.db, func(tx kvdb.RwTx) error {
// First, retrieve or create the channel bucket corresponding to
// the baby output's origin channel point.
chanPoint := bby.OriginChanPoint()
chanBucket, err := ns.createChannelBucket(tx, chanPoint)
if err != nil {
return err
}
// The babyOutput should currently be stored in the crib bucket.
// So, we create a key that prefixes the babyOutput's outpoint
// with the crib prefix, allowing us to reference it in the
// store.
pfxOutputKey, err := prefixOutputKey(cribPrefix, bby.OutPoint())
if err != nil {
return err
}
// Since the babyOutput is being moved to the kindergarten
// bucket, we remove the entry from the channel bucket under the
// crib-prefixed outpoint key.
if err := chanBucket.Delete(pfxOutputKey); err != nil {
return err
}
// Remove the crib output's entry in the height index.
err = ns.removeOutputFromHeight(tx, bby.expiry, chanPoint,
pfxOutputKey)
if err != nil {
return err
}
// Since we are moving this output from the crib bucket to the
// kindergarten bucket, we overwrite the existing prefix of this
// key with the kindergarten prefix.
copy(pfxOutputKey, kndrPrefix)
// Now, serialize babyOutput's encapsulated kidOutput such that
// it can be written to the channel bucket under the new
// kindergarten-prefixed key.
var kidBuffer bytes.Buffer
if err := bby.kidOutput.Encode(&kidBuffer); err != nil {
return err
}
kidBytes := kidBuffer.Bytes()
// Persist the serialized kidOutput under the
// kindergarten-prefixed outpoint key.
if err := chanBucket.Put(pfxOutputKey, kidBytes); err != nil {
return err
}
// Now, compute the height at which this kidOutput's CSV delay
// will expire. This is done by adding the required delay to
// the block height at which the output was confirmed.
maturityHeight := bby.ConfHeight() + bby.BlocksToMaturity()
// Retrieve or create a height-channel bucket corresponding to
// the kidOutput's maturity height.
hghtChanBucketCsv, err := ns.createHeightChanBucket(tx,
maturityHeight, chanPoint)
if err != nil {
return err
}
utxnLog.Tracef("Transitioning (crib -> kndr) output for "+
"chan_point=%v at height_index=%v", chanPoint,
maturityHeight)
// Register the kindergarten output's prefixed output key in the
// height-channel bucket corresponding to its maturity height.
// This informs the utxo nursery that it should attempt to spend
// this output when the blockchain reaches the maturity height.
return hghtChanBucketCsv.Put(pfxOutputKey, []byte{})
}, func() {})
}
// PreschoolToKinder atomically moves a kidOutput from the preschool bucket to
// the kindergarten bucket. This transition should be executed after receiving
// confirmation of the preschool output's commitment transaction.
func (ns *NurseryStore) PreschoolToKinder(kid *kidOutput,
lastGradHeight uint32) error {
return kvdb.Update(ns.db, func(tx kvdb.RwTx) error {
// Create or retrieve the channel bucket corresponding to the
// kid output's origin channel point.
chanPoint := kid.OriginChanPoint()
chanBucket, err := ns.createChannelBucket(tx, chanPoint)
if err != nil {
return err
}
// First, we will attempt to remove the existing serialized
// output from the channel bucket, where the kid's outpoint will
// be prefixed by a preschool prefix.
// Generate the key of the existing serialized kid output by
// prefixing its outpoint with the preschool prefix...
pfxOutputKey, err := prefixOutputKey(psclPrefix, kid.OutPoint())
if err != nil {
return err
}
// And remove the old serialized output from the database.
if err := chanBucket.Delete(pfxOutputKey); err != nil {
return err
}
// Next, we will write the provided kid outpoint to the channel
// bucket, using a key prefixed by the kindergarten prefix.
// Convert the preschool prefix key into a kindergarten key for
// the same outpoint.
copy(pfxOutputKey, kndrPrefix)
// Reserialize the kid here to capture any differences in the
// new and old kid output, such as the confirmation height.
var kidBuffer bytes.Buffer
if err := kid.Encode(&kidBuffer); err != nil {
return err
}
kidBytes := kidBuffer.Bytes()
// And store the kid output in its channel bucket using the
// kindergarten prefixed key.
if err := chanBucket.Put(pfxOutputKey, kidBytes); err != nil {
return err
}
// If this output has an absolute time lock, then we'll set the
// maturity height directly.
var maturityHeight uint32
if kid.BlocksToMaturity() == 0 {
maturityHeight = kid.absoluteMaturity
} else {
// Otherwise, since the CSV delay on the kid output has
// now begun ticking, we must insert a record in the
// height index to remind us to revisit this output
// once it has fully matured.
//
// Compute the maturity height, by adding the output's
// CSV delay to its confirmation height.
maturityHeight = kid.ConfHeight() + kid.BlocksToMaturity()
}
if maturityHeight <= lastGradHeight {
utxnLog.Debugf("Late Registration for kid output=%v "+
"detected: class_height=%v, "+
"last_graduated_height=%v", kid.OutPoint(),
maturityHeight, lastGradHeight)
maturityHeight = lastGradHeight + 1
}
utxnLog.Infof("Transitioning (pscl -> kndr) output for "+
"chan_point=%v at height_index=%v", chanPoint,
maturityHeight)
// Create or retrieve the height-channel bucket for this
// channel. This method will first create a height bucket for
// the given maturity height if none exists.
hghtChanBucket, err := ns.createHeightChanBucket(tx,
maturityHeight, chanPoint)
if err != nil {
return err
}
// Finally, we touch a key in the height-channel created above.
// The key is named using a kindergarten prefixed key, signaling
// that this CSV delayed output will be ready to broadcast at
// the maturity height, after a brief period of incubation.
return hghtChanBucket.Put(pfxOutputKey, []byte{})
}, func() {})
}
// GraduateKinder atomically moves an output at the provided height into the
// graduated status. This involves removing the kindergarten entries from both
// the height and channel indexes. The height bucket will be opportunistically
// pruned from the height index as outputs are removed.
func (ns *NurseryStore) GraduateKinder(height uint32, kid *kidOutput) error {
return kvdb.Update(ns.db, func(tx kvdb.RwTx) error {
hghtBucket := ns.getHeightBucket(tx, height)
if hghtBucket == nil {
// Nothing to delete, bucket has already been removed.
return nil
}
// For the kindergarten output, delete its entry from the
// height and channel index, and create a new grad output in the
// channel index.
outpoint := kid.OutPoint()
chanPoint := kid.OriginChanPoint()
// Construct the key under which the output is currently stored
// in the height and channel indexes.
pfxOutputKey, err := prefixOutputKey(kndrPrefix, outpoint)
if err != nil {
return err
}
// Remove the kindergarten output's entry in the height index.
err = ns.removeOutputFromHeight(tx, height, chanPoint,
pfxOutputKey)
if err != nil {
return err
}
chanBucket := ns.getChannelBucketWrite(tx, chanPoint)
if chanBucket == nil {
return ErrContractNotFound
}
// Remove the previous output with the kindergarten prefix.
err = chanBucket.Delete(pfxOutputKey)
if err != nil {
return err
}
// Convert kindergarten key to graduate key.
copy(pfxOutputKey, gradPrefix)
var gradBuffer bytes.Buffer
if err := kid.Encode(&gradBuffer); err != nil {
return err
}
// Insert the serialized output into the channel bucket using the
// graduate-prefixed key.
return chanBucket.Put(pfxOutputKey, gradBuffer.Bytes())
}, func() {})
}
// FetchClass returns a list of the kindergarten and crib outputs whose
// timelocks expire at the provided block height.
func (ns *NurseryStore) FetchClass(
height uint32) ([]kidOutput, []babyOutput, error) { // nolint:revive
// Construct list of all crib and kindergarten outputs that need to be
// processed at the provided block height.
var kids []kidOutput
var babies []babyOutput
if err := kvdb.View(ns.db, func(tx kvdb.RTx) error {
// Append each crib output to our list of babyOutputs.
if err := ns.forEachHeightPrefix(tx, cribPrefix, height,
func(buf []byte) error {
// We will attempt to deserialize all outputs
// stored with the crib prefix into babyOutputs,
// since this is the expected type that would
// have been serialized previously.
var baby babyOutput
babyReader := bytes.NewReader(buf)
if err := baby.Decode(babyReader); err != nil {
return err
}
babies = append(babies, baby)
return nil
},
); err != nil {
return err
}
// Append each kindergarten output to our list of kidOutputs.
return ns.forEachHeightPrefix(tx, kndrPrefix, height,
func(buf []byte) error {
// We will attempt to deserialize all outputs
// stored with the kindergarten prefix into
// kidOutputs, since this is the expected type
// that would have been serialized previously.
var kid kidOutput
kidReader := bytes.NewReader(buf)
if err := kid.Decode(kidReader); err != nil {
return err
}
kids = append(kids, kid)
return nil
})
}, func() {
kids = nil
babies = nil
}); err != nil {
return nil, nil, err
}
return kids, babies, nil
}
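// A caller-side sketch of how a class is typically consumed on each new
// block (height and the error handling are illustrative only):
//
//	kids, babies, err := ns.FetchClass(height)
//	if err != nil {
//		return err
//	}
//	// broadcast the babies' timeout txns and sweep the matured kids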
// FetchPreschools returns a list of all outputs currently stored in the
// preschool bucket.
func (ns *NurseryStore) FetchPreschools() ([]kidOutput, error) { // nolint:revive
var kids []kidOutput
if err := kvdb.View(ns.db, func(tx kvdb.RTx) error {
// Retrieve the existing chain bucket for this nursery store.
chainBucket := tx.ReadBucket(ns.pfxChainKey)
if chainBucket == nil {
return nil
}
// Load the existing channel index from the chain bucket.
chanIndex := chainBucket.NestedReadBucket(channelIndexKey)
if chanIndex == nil {
return nil
}
// Construct a list of all channels in the channel index that
// are currently being tracked by the nursery store.
var activeChannels [][]byte
if err := chanIndex.ForEach(func(chanBytes, _ []byte) error {
activeChannels = append(activeChannels, chanBytes)
return nil
}); err != nil {
return err
}
// Iterate over all of the accumulated channels, and do a prefix
// scan inside of each channel bucket. Each output found that
// has a preschool prefix will be deserialized into a kidOutput,
// and added to our list of preschool outputs to return to the
// caller.
for _, chanBytes := range activeChannels {
// Retrieve the channel bucket associated with this
// channel.
chanBucket := chanIndex.NestedReadBucket(chanBytes)
if chanBucket == nil {
continue
}
// All of the outputs of interest will start with the
// "pscl" prefix. So, we will perform a prefix scan of
// the channel bucket to efficiently enumerate all the
// desired outputs.
c := chanBucket.ReadCursor()
for k, v := c.Seek(psclPrefix); bytes.HasPrefix(
k, psclPrefix); k, v = c.Next() {
// Deserialize each output as a kidOutput, since
// this should have been the type that was
// serialized when it was written to disk.
var psclOutput kidOutput
psclReader := bytes.NewReader(v)
err := psclOutput.Decode(psclReader)
if err != nil {
return err
}
// Add the deserialized output to our list of
// preschool outputs.
kids = append(kids, psclOutput)
}
}
return nil
}, func() {
kids = nil
}); err != nil {
return nil, err
}
return kids, nil
}
// HeightsBelowOrEqual returns a slice of all non-empty heights in the height
// index at or below the provided upper bound.
func (ns *NurseryStore) HeightsBelowOrEqual(height uint32) ([]uint32, error) {
var activeHeights []uint32
err := kvdb.View(ns.db, func(tx kvdb.RTx) error {
// Ensure that the chain bucket for this nursery store exists.
chainBucket := tx.ReadBucket(ns.pfxChainKey)
if chainBucket == nil {
return nil
}
// Ensure that the height index has been properly initialized for this
// chain.
hghtIndex := chainBucket.NestedReadBucket(heightIndexKey)
if hghtIndex == nil {
return nil
}
// Serialize the provided height, as this will form the upper bound of
// the cursor scan below.
var lower, upper [4]byte
byteOrder.PutUint32(upper[:], height)
c := hghtIndex.ReadCursor()
for k, _ := c.Seek(lower[:]); bytes.Compare(k, upper[:]) <= 0 &&
len(k) == 4; k, _ = c.Next() {
activeHeights = append(activeHeights, byteOrder.Uint32(k))
}
return nil
}, func() {
activeHeights = nil
})
if err != nil {
return nil, err
}
return activeHeights, nil
}
// ForChanOutputs iterates over all outputs being incubated for a particular
// channel point. This method accepts a callback that allows the caller to
// process each key-value pair. The key will be a prefixed outpoint, and the
// value will be the serialized bytes for an output, whose type should be
// inferred from the key's prefix.
// NOTE: The callback should not modify the provided byte slices and is
// preferably non-blocking.
func (ns *NurseryStore) ForChanOutputs(chanPoint *wire.OutPoint,
callback func([]byte, []byte) error, reset func()) error {
return kvdb.View(ns.db, func(tx kvdb.RTx) error {
return ns.forChanOutputs(tx, chanPoint, callback)
}, reset)
}
// ListChannels returns all channels the nursery is currently tracking.
func (ns *NurseryStore) ListChannels() ([]wire.OutPoint, error) {
var activeChannels []wire.OutPoint
if err := kvdb.View(ns.db, func(tx kvdb.RTx) error {
// Retrieve the existing chain bucket for this nursery store.
chainBucket := tx.ReadBucket(ns.pfxChainKey)
if chainBucket == nil {
return nil
}
// Retrieve the existing channel index.
chanIndex := chainBucket.NestedReadBucket(channelIndexKey)
if chanIndex == nil {
return nil
}
return chanIndex.ForEach(func(chanBytes, _ []byte) error {
var chanPoint wire.OutPoint
err := graphdb.ReadOutpoint(
bytes.NewReader(chanBytes), &chanPoint,
)
if err != nil {
return err
}
activeChannels = append(activeChannels, chanPoint)
return nil
})
}, func() {
activeChannels = nil
}); err != nil {
return nil, err
}
return activeChannels, nil
}
// IsMatureChannel determines whether or not all of the outputs in a
// particular channel bucket have been marked as graduated.
func (ns *NurseryStore) IsMatureChannel(chanPoint *wire.OutPoint) (bool, error) {
err := kvdb.View(ns.db, func(tx kvdb.RTx) error {
// Iterate over the contents of the channel bucket, computing
// both total number of outputs, and those that have the grad
// prefix.
return ns.forChanOutputs(tx, chanPoint,
func(pfxKey, _ []byte) error {
if !bytes.HasPrefix(pfxKey, gradPrefix) {
return ErrImmatureChannel
}
return nil
})
}, func() {})
if err != nil && err != ErrImmatureChannel {
return false, err
}
return err == nil, nil
}
// ErrImmatureChannel signals a channel cannot be removed because not all of its
// outputs have graduated.
var ErrImmatureChannel = errors.New("cannot remove immature channel, " +
"still has ungraduated outputs")
// RemoveChannel erases all entries from the channel bucket for the
// provided channel point.
// NOTE: The channel's entries in the height index are assumed to be removed.
func (ns *NurseryStore) RemoveChannel(chanPoint *wire.OutPoint) error {
return kvdb.Update(ns.db, func(tx kvdb.RwTx) error {
// Retrieve the existing chain bucket for this nursery store.
chainBucket := tx.ReadWriteBucket(ns.pfxChainKey)
if chainBucket == nil {
return nil
}
// Retrieve the channel index stored in the chain bucket.
chanIndex := chainBucket.NestedReadWriteBucket(channelIndexKey)
if chanIndex == nil {
return nil
}
// Serialize the provided channel point, such that we can delete
// the mature channel bucket.
var chanBuffer bytes.Buffer
err := graphdb.WriteOutpoint(&chanBuffer, chanPoint)
if err != nil {
return err
}
chanBytes := chanBuffer.Bytes()
err = ns.forChanOutputs(tx, chanPoint, func(k, v []byte) error {
if !bytes.HasPrefix(k, gradPrefix) {
return ErrImmatureChannel
}
// Construct a kindergarten prefixed key, since this
// would have been the preceding state for a grad
// output.
kndrKey := make([]byte, len(k))
copy(kndrKey, k)
copy(kndrKey[:4], kndrPrefix)
// Decode each to retrieve the output's maturity height.
var kid kidOutput
if err := kid.Decode(bytes.NewReader(v)); err != nil {
return err
}
maturityHeight := kid.ConfHeight() + kid.BlocksToMaturity()
hghtBucket := ns.getHeightBucketWrite(tx, maturityHeight)
if hghtBucket == nil {
return nil
}
return removeBucketIfExists(hghtBucket, chanBytes)
})
if err != nil {
return err
}
return removeBucketIfExists(chanIndex, chanBytes)
}, func() {})
}
// Helper Methods
// enterCrib accepts a new htlc output that the nursery will incubate through
// its two-stage process of sweeping funds back to the user's wallet. These
// outputs are persisted in the nursery store in the crib state, and will be
// revisited after the first-stage output's CLTV has expired.
func (ns *NurseryStore) enterCrib(tx kvdb.RwTx, baby *babyOutput) error {
// First, retrieve or create the channel bucket corresponding to the
// baby output's origin channel point.
chanPoint := baby.OriginChanPoint()
chanBucket, err := ns.createChannelBucket(tx, chanPoint)
if err != nil {
return err
}
// Since we are inserting this output into the crib bucket, we create a
// key that prefixes the baby output's outpoint with the crib prefix.
pfxOutputKey, err := prefixOutputKey(cribPrefix, baby.OutPoint())
if err != nil {
return err
}
// We'll first check that we don't already have an entry for this
// output. If we do, then we can exit early.
if rawBytes := chanBucket.Get(pfxOutputKey); rawBytes != nil {
return nil
}
// Next, retrieve or create the height-channel bucket located in the
// height bucket corresponding to the baby output's CLTV expiry height.
// TODO: Handle late registration.
hghtChanBucket, err := ns.createHeightChanBucket(tx,
baby.expiry, chanPoint)
if err != nil {
return err
}
// Serialize the baby output so that it can be written to the
// underlying key-value store.
var babyBuffer bytes.Buffer
if err := baby.Encode(&babyBuffer); err != nil {
return err
}
babyBytes := babyBuffer.Bytes()
// Now, insert the serialized output into its channel bucket under the
// prefixed key created above.
if err := chanBucket.Put(pfxOutputKey, babyBytes); err != nil {
return err
}
// Finally, create a corresponding entry in the height-channel bucket
// for this crib output. The existence of this entry indicates that
// the serialized output can be retrieved from the channel bucket using
// the same prefix key.
return hghtChanBucket.Put(pfxOutputKey, []byte{})
}
// enterPreschool accepts a new commitment output that the nursery will incubate
// through a single stage before sweeping. Outputs are stored in the preschool
// bucket until the commitment transaction has been confirmed, at which point
// they will be moved to the kindergarten bucket.
func (ns *NurseryStore) enterPreschool(tx kvdb.RwTx, kid *kidOutput) error {
// First, retrieve or create the channel bucket corresponding to the
// kid output's origin channel point.
chanPoint := kid.OriginChanPoint()
chanBucket, err := ns.createChannelBucket(tx, chanPoint)
if err != nil {
return err
}
// Since the kidOutput is being inserted into the preschool bucket, we
// create a key that prefixes its outpoint with the preschool prefix.
pfxOutputKey, err := prefixOutputKey(psclPrefix, kid.OutPoint())
if err != nil {
return err
}
// We'll first check if an entry for this key is already stored. If so,
// then we'll ignore this request, and return a nil error.
if rawBytes := chanBucket.Get(pfxOutputKey); rawBytes != nil {
return nil
}
// Serialize the kidOutput and insert it into the channel bucket.
var kidBuffer bytes.Buffer
if err := kid.Encode(&kidBuffer); err != nil {
return err
}
return chanBucket.Put(pfxOutputKey, kidBuffer.Bytes())
}
// createChannelBucket creates or retrieves a channel bucket for the provided
// channel point.
func (ns *NurseryStore) createChannelBucket(tx kvdb.RwTx,
chanPoint *wire.OutPoint) (kvdb.RwBucket, error) {
// Ensure that the chain bucket for this nursery store exists.
chainBucket, err := tx.CreateTopLevelBucket(ns.pfxChainKey)
if err != nil {
return nil, err
}
// Ensure that the channel index has been properly initialized for this
// chain.
chanIndex, err := chainBucket.CreateBucketIfNotExists(channelIndexKey)
if err != nil {
return nil, err
}
// Serialize the provided channel point, as this provides the name of
// the channel bucket of interest.
var chanBuffer bytes.Buffer
if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil {
return nil, err
}
// Finally, create or retrieve the channel bucket using the serialized
// key.
return chanIndex.CreateBucketIfNotExists(chanBuffer.Bytes())
}
// getChannelBucket retrieves an existing channel bucket from the nursery store,
// using the given channel point. If the bucket does not exist, or any bucket
// along its path does not exist, a nil value is returned.
func (ns *NurseryStore) getChannelBucket(tx kvdb.RTx,
chanPoint *wire.OutPoint) kvdb.RBucket {
// Retrieve the existing chain bucket for this nursery store.
chainBucket := tx.ReadBucket(ns.pfxChainKey)
if chainBucket == nil {
return nil
}
// Retrieve the existing channel index.
chanIndex := chainBucket.NestedReadBucket(channelIndexKey)
if chanIndex == nil {
return nil
}
// Serialize the provided channel point and return the bucket matching
// the serialized key.
var chanBuffer bytes.Buffer
if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil {
return nil
}
return chanIndex.NestedReadBucket(chanBuffer.Bytes())
}
// getChannelBucketWrite retrieves an existing channel bucket from the
// nursery store, using the given channel point. If the bucket does not
// exist, or any bucket along its path does not exist, a nil value is
// returned.
func (ns *NurseryStore) getChannelBucketWrite(tx kvdb.RwTx,
chanPoint *wire.OutPoint) kvdb.RwBucket {
// Retrieve the existing chain bucket for this nursery store.
chainBucket := tx.ReadWriteBucket(ns.pfxChainKey)
if chainBucket == nil {
return nil
}
// Retrieve the existing channel index.
chanIndex := chainBucket.NestedReadWriteBucket(channelIndexKey)
if chanIndex == nil {
return nil
}
// Serialize the provided channel point and return the bucket matching
// the serialized key.
var chanBuffer bytes.Buffer
if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil {
return nil
}
return chanIndex.NestedReadWriteBucket(chanBuffer.Bytes())
}
// createHeightBucket creates or retrieves an existing bucket from the height
// index, corresponding to the provided height.
func (ns *NurseryStore) createHeightBucket(tx kvdb.RwTx,
height uint32) (kvdb.RwBucket, error) {
// Ensure that the chain bucket for this nursery store exists.
chainBucket, err := tx.CreateTopLevelBucket(ns.pfxChainKey)
if err != nil {
return nil, err
}
// Ensure that the height index has been properly initialized for this
// chain.
hghtIndex, err := chainBucket.CreateBucketIfNotExists(heightIndexKey)
if err != nil {
return nil, err
}
// Serialize the provided height, as this will form the name of the
// bucket.
var heightBytes [4]byte
byteOrder.PutUint32(heightBytes[:], height)
// Finally, create or retrieve the bucket in question.
return hghtIndex.CreateBucketIfNotExists(heightBytes[:])
}
// getHeightBucketPath retrieves an existing height bucket from the nursery
// store, using the provided block height. If the bucket does not exist, or any
// bucket along its path does not exist, a nil value is returned.
func (ns *NurseryStore) getHeightBucketPath(tx kvdb.RTx,
height uint32) (kvdb.RBucket, kvdb.RBucket) {
// Retrieve the existing chain bucket for this nursery store.
chainBucket := tx.ReadBucket(ns.pfxChainKey)
if chainBucket == nil {
return nil, nil
}
// Retrieve the existing height index.
hghtIndex := chainBucket.NestedReadBucket(heightIndexKey)
if hghtIndex == nil {
return nil, nil
}
// Serialize the provided block height and return the bucket matching
// the serialized key.
var heightBytes [4]byte
byteOrder.PutUint32(heightBytes[:], height)
return chainBucket, hghtIndex.NestedReadBucket(heightBytes[:])
}
// getHeightBucketPathWrite retrieves an existing height bucket from the
// nursery store, using the provided block height. If the bucket does not
// exist, or any bucket along its path does not exist, a nil value is
// returned.
func (ns *NurseryStore) getHeightBucketPathWrite(tx kvdb.RwTx,
height uint32) (kvdb.RwBucket, kvdb.RwBucket) {
// Retrieve the existing chain bucket for this nursery store.
chainBucket := tx.ReadWriteBucket(ns.pfxChainKey)
if chainBucket == nil {
return nil, nil
}
// Retrieve the existing height index.
hghtIndex := chainBucket.NestedReadWriteBucket(heightIndexKey)
if hghtIndex == nil {
return nil, nil
}
// Serialize the provided block height and return the bucket matching
// the serialized key.
var heightBytes [4]byte
byteOrder.PutUint32(heightBytes[:], height)
return hghtIndex, hghtIndex.NestedReadWriteBucket(
heightBytes[:],
)
}
// getHeightBucket retrieves an existing height bucket from the nursery store,
// using the provided block height. If the bucket does not exist, or any bucket
// along its path does not exist, a nil value is returned.
func (ns *NurseryStore) getHeightBucket(tx kvdb.RTx,
height uint32) kvdb.RBucket {
_, hghtBucket := ns.getHeightBucketPath(tx, height)
return hghtBucket
}
// getHeightBucketWrite retrieves an existing height bucket from the
// nursery store, using the provided block height. If the bucket does not
// exist, or any bucket along its path does not exist, a nil value is
// returned.
func (ns *NurseryStore) getHeightBucketWrite(tx kvdb.RwTx,
height uint32) kvdb.RwBucket {
_, hghtBucket := ns.getHeightBucketPathWrite(tx, height)
return hghtBucket
}
// createHeightChanBucket creates or retrieves an existing height-channel bucket
// for the provided block height and channel point. This method will attempt to
// instantiate all buckets along the path if required.
func (ns *NurseryStore) createHeightChanBucket(tx kvdb.RwTx,
height uint32, chanPoint *wire.OutPoint) (kvdb.RwBucket, error) {
// Ensure that the height bucket for this nursery store exists.
hghtBucket, err := ns.createHeightBucket(tx, height)
if err != nil {
return nil, err
}
// Serialize the provided channel point, as this generates the name of
// the subdirectory corresponding to the channel of interest.
var chanBuffer bytes.Buffer
if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil {
return nil, err
}
chanBytes := chanBuffer.Bytes()
// Finally, create or retrieve an existing height-channel bucket for
// this channel point.
return hghtBucket.CreateBucketIfNotExists(chanBytes)
}
// getHeightChanBucketWrite retrieves an existing height-channel bucket from the
// nursery store, using the provided block height and channel point. If the
// bucket does not exist, or any bucket along its path does not exist, a nil
// value is returned.
func (ns *NurseryStore) getHeightChanBucketWrite(tx kvdb.RwTx,
height uint32, chanPoint *wire.OutPoint) kvdb.RwBucket {
// Retrieve the existing height bucket from this nursery store.
hghtBucket := ns.getHeightBucketWrite(tx, height)
if hghtBucket == nil {
return nil
}
// Serialize the provided channel point, which generates the key for
// looking up the proper height-channel bucket inside the height bucket.
var chanBuffer bytes.Buffer
if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil {
return nil
}
chanBytes := chanBuffer.Bytes()
// Finally, return the height bucket specified by the serialized channel
// point.
return hghtBucket.NestedReadWriteBucket(chanBytes)
}
// forEachHeightPrefix enumerates all outputs at the given height whose state
// prefix matches that which is provided. This is used as a subroutine to help
// enumerate crib and kindergarten outputs at a particular height. The callback
// is invoked with serialized bytes retrieved for each output of interest,
// allowing the caller to deserialize them into the appropriate type.
func (ns *NurseryStore) forEachHeightPrefix(tx kvdb.RTx, prefix []byte,
height uint32, callback func([]byte) error) error {
// Start by retrieving the height bucket corresponding to the provided
// block height.
chainBucket, hghtBucket := ns.getHeightBucketPath(tx, height)
if hghtBucket == nil {
return nil
}
// Using the height bucket as a starting point, we will traverse its
// entire two-tier directory structure, and filter for outputs that have
// the provided prefix. The first layer of the height bucket contains
// buckets identified by a channel point, thus we first create a list of
// channels contained in this height bucket.
var channelsAtHeight [][]byte
if err := hghtBucket.ForEach(func(chanBytes, v []byte) error {
if v == nil {
channelsAtHeight = append(channelsAtHeight, chanBytes)
}
return nil
}); err != nil {
return err
}
// Additionally, grab the channel index, which we will use to query for
// the channel bucket of each of the channels in the list we assembled
// above.
chanIndex := chainBucket.NestedReadBucket(channelIndexKey)
if chanIndex == nil {
return errors.New("unable to retrieve channel index")
}
// Now, we are ready to enumerate all outputs with the desired prefix at
// this block height. We do so by iterating over our list of channels at
// this height, filtering for outputs in each height-channel bucket that
// begin with the given prefix, and then retrieving the serialized
// outputs from the appropriate channel bucket.
for _, chanBytes := range channelsAtHeight {
// Retrieve the height-channel bucket for this channel, which
// holds a sub-bucket for all outputs maturing at this height.
hghtChanBucket := hghtBucket.NestedReadBucket(chanBytes)
if hghtChanBucket == nil {
return fmt.Errorf("unable to retrieve height-channel "+
"bucket at height %d for %x", height, chanBytes)
}
// Load the appropriate channel bucket from the channel index,
// this will allow us to retrieve the individual serialized
// outputs.
chanBucket := chanIndex.NestedReadBucket(chanBytes)
if chanBucket == nil {
return fmt.Errorf("unable to retrieve channel "+
"bucket: '%x'", chanBytes)
}
// Since all of the outputs of interest will start with the same
// prefix, we will perform a prefix scan of the keys contained in the
// height-channel bucket, efficiently enumerating the desired outputs.
c := hghtChanBucket.ReadCursor()
for k, _ := c.Seek(prefix); bytes.HasPrefix(
k, prefix); k, _ = c.Next() {
// Use the prefixed output key emitted from our scan to
// load the serialized output from the appropriate
// channel bucket.
outputBytes := chanBucket.Get(k)
if outputBytes == nil {
return errors.New("unable to retrieve output")
}
// Present the serialized bytes to our callback
// function, which is responsible for deserializing the
// bytes into the appropriate type.
if err := callback(outputBytes); err != nil {
return err
}
}
}
return nil
}
// forChanOutputs passes each of the outputs contained in a channel bucket
// to the provided callback. The callback accepts a key-value pair of byte
// slices corresponding to the prefixed-output key and the serialized
// output, respectively.
func (ns *NurseryStore) forChanOutputs(tx kvdb.RTx, chanPoint *wire.OutPoint,
callback func([]byte, []byte) error) error {
chanBucket := ns.getChannelBucket(tx, chanPoint)
if chanBucket == nil {
return ErrContractNotFound
}
return chanBucket.ForEach(callback)
}
// errBucketNotEmpty signals that an attempt to prune a particular
// bucket failed because it still has active outputs.
var errBucketNotEmpty = errors.New("bucket is not empty, cannot be pruned")
// removeOutputFromHeight will delete the given output from the specified
// height-channel bucket, and attempt to prune the upstream directories if they
// are empty.
func (ns *NurseryStore) removeOutputFromHeight(tx kvdb.RwTx, height uint32,
chanPoint *wire.OutPoint, pfxKey []byte) error {
// Retrieve the height-channel bucket and delete the prefixed output.
hghtChanBucket := ns.getHeightChanBucketWrite(tx, height, chanPoint)
if hghtChanBucket == nil {
// Height-channel bucket already removed.
return nil
}
// Try to delete the prefixed output from the target height-channel
// bucket.
if err := hghtChanBucket.Delete(pfxKey); err != nil {
return err
}
// Retrieve the height bucket that contains the height-channel bucket.
hghtBucket := ns.getHeightBucketWrite(tx, height)
if hghtBucket == nil {
return errors.New("height bucket not found")
}
var chanBuffer bytes.Buffer
if err := graphdb.WriteOutpoint(&chanBuffer, chanPoint); err != nil {
return err
}
// Try to remove the height-channel bucket if this was the last
// output in the bucket.
err := removeBucketIfEmpty(hghtBucket, chanBuffer.Bytes())
if err != nil && err != errBucketNotEmpty {
return err
} else if err == errBucketNotEmpty {
return nil
}
// Attempt to prune the height bucket matching the kid output's
// confirmation height in case that was the last height-chan bucket.
pruned, err := ns.pruneHeight(tx, height)
if err != nil && err != errBucketNotEmpty {
return err
} else if err == nil && pruned {
utxnLog.Infof("Height bucket %d pruned", height)
}
return nil
}
// pruneHeight removes the height bucket at the provided height if and only if
// all active outputs at this height have been removed from their respective
// height-channel buckets. The returned boolean value indicates whether or not
// this invocation successfully pruned the height bucket.
func (ns *NurseryStore) pruneHeight(tx kvdb.RwTx, height uint32) (bool, error) {
// Fetch the existing height index and height bucket.
hghtIndex, hghtBucket := ns.getHeightBucketPathWrite(tx, height)
if hghtBucket == nil {
return false, nil
}
// Iterate over all channels stored at this block height. We will
// attempt to remove each one if they are empty, keeping track of the
// number of height-channel buckets that still have active outputs.
if err := hghtBucket.ForEach(func(chanBytes, v []byte) error {
// Skip the finalized txn key if it still exists from a previous
// db version.
if v != nil {
return nil
}
// Attempt to fetch each height-channel bucket from the height
// bucket located above.
hghtChanBucket := hghtBucket.NestedReadWriteBucket(chanBytes)
if hghtChanBucket == nil {
return errors.New("unable to find height-channel bucket")
}
return isBucketEmpty(hghtChanBucket)
}); err != nil {
return false, err
}
// Serialize the provided block height, such that it can be used as
// the key to delete the desired height bucket.
var heightBytes [4]byte
byteOrder.PutUint32(heightBytes[:], height)
// All of the height-channel buckets are empty or have been previously
// removed, so proceed by removing the height bucket altogether.
if err := removeBucketIfExists(hghtIndex, heightBytes[:]); err != nil {
return false, err
}
return true, nil
}
// removeBucketIfEmpty attempts to delete a bucket specified by name from the
// provided parent bucket.
func removeBucketIfEmpty(parent kvdb.RwBucket, bktName []byte) error {
// Attempt to fetch the named bucket from its parent.
bkt := parent.NestedReadWriteBucket(bktName)
if bkt == nil {
// No bucket was found, already removed?
return nil
}
// The bucket exists, fail if it still has children.
if err := isBucketEmpty(bkt); err != nil {
return err
}
return parent.DeleteNestedBucket(bktName)
}
// removeBucketIfExists safely deletes the named bucket by first checking
// that it exists in the parent bucket.
func removeBucketIfExists(parent kvdb.RwBucket, bktName []byte) error {
// Attempt to fetch the named bucket from its parent.
bkt := parent.NestedReadWriteBucket(bktName)
if bkt == nil {
// No bucket was found, already removed?
return nil
}
return parent.DeleteNestedBucket(bktName)
}
// isBucketEmpty returns errBucketNotEmpty if the bucket has a non-zero number
// of children.
func isBucketEmpty(parent kvdb.RBucket) error {
return parent.ForEach(func(_, _ []byte) error {
return errBucketNotEmpty
})
}
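// Note that isBucketEmpty exploits ForEach short-circuiting: the closure
// returns errBucketNotEmpty for the very first key it encounters, which
// aborts iteration immediately, so the sentinel error doubles as the
// "non-empty" signal. For an empty bucket the closure is never invoked and
// ForEach returns nil.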
// Compile-time constraint to ensure NurseryStore implements NurseryStorer.
var _ NurseryStorer = (*NurseryStore)(nil)
package contractcourt
import (
"bytes"
"io"
"github.com/lightningnetwork/lnd/tlv"
)
const (
commitCtrlBlockType tlv.Type = 0
revokeCtrlBlockType tlv.Type = 1
outgoingHtlcCtrlBlockType tlv.Type = 2
incomingHtlcCtrlBlockType tlv.Type = 3
secondLevelCtrlBlockType tlv.Type = 4
anchorTapTweakType tlv.Type = 0
htlcTweakCtrlBlockType tlv.Type = 1
secondLevelHtlcTweakCtrlBlockType tlv.Type = 2
)
// taprootBriefcase is a supplemental storage struct that contains all the
// information we need to sweep taproot outputs.
type taprootBriefcase struct {
// CtrlBlocks is the set of control blocks for the taproot outputs.
CtrlBlocks tlv.RecordT[tlv.TlvType0, ctrlBlocks]
// TapTweaks is the set of taproot tweaks for the taproot outputs that
// are to be spent via a keyspend path. This includes anchors, and any
// revocation paths.
TapTweaks tlv.RecordT[tlv.TlvType1, tapTweaks]
// SettledCommitBlob is an optional record that contains an opaque blob
// that may be used to properly sweep commitment outputs on a force
// close transaction.
SettledCommitBlob tlv.OptionalRecordT[tlv.TlvType2, tlv.Blob]
// BreachCommitBlob is an optional record that contains an opaque blob
// used to sweep a remote party's breached output.
BreachedCommitBlob tlv.OptionalRecordT[tlv.TlvType3, tlv.Blob]
// HtlcBlobs is an optional record that contains the opaque blobs for
// the set of active HTLCs on the commitment transaction.
HtlcBlobs tlv.OptionalRecordT[tlv.TlvType4, htlcAuxBlobs]
}
// TODO(roasbeef): morph into new tlv record
// newTaprootBriefcase returns a new instance of the taproot specific briefcase
// variant.
func newTaprootBriefcase() *taprootBriefcase {
return &taprootBriefcase{
CtrlBlocks: tlv.NewRecordT[tlv.TlvType0](newCtrlBlocks()),
TapTweaks: tlv.NewRecordT[tlv.TlvType1](newTapTweaks()),
}
}
// EncodeRecords returns a slice of TLV records that should be encoded.
func (t *taprootBriefcase) EncodeRecords() []tlv.Record {
records := []tlv.Record{
t.CtrlBlocks.Record(),
t.TapTweaks.Record(),
}
t.SettledCommitBlob.WhenSome(
func(r tlv.RecordT[tlv.TlvType2, tlv.Blob]) {
records = append(records, r.Record())
},
)
t.BreachedCommitBlob.WhenSome(
func(r tlv.RecordT[tlv.TlvType3, tlv.Blob]) {
records = append(records, r.Record())
},
)
t.HtlcBlobs.WhenSome(func(r tlv.RecordT[tlv.TlvType4, htlcAuxBlobs]) {
records = append(records, r.Record())
})
return records
}
// DecodeRecords returns a slice of TLV records that should be decoded.
func (t *taprootBriefcase) DecodeRecords() []tlv.Record {
return []tlv.Record{
t.CtrlBlocks.Record(),
t.TapTweaks.Record(),
}
}
// Encode encodes the taproot briefcase into the target writer.
func (t *taprootBriefcase) Encode(w io.Writer) error {
stream, err := tlv.NewStream(t.EncodeRecords()...)
if err != nil {
return err
}
return stream.Encode(w)
}
// Decode decodes the given reader into the target struct.
func (t *taprootBriefcase) Decode(r io.Reader) error {
settledCommitBlob := t.SettledCommitBlob.Zero()
breachedCommitBlob := t.BreachedCommitBlob.Zero()
htlcBlobs := t.HtlcBlobs.Zero()
records := append(
t.DecodeRecords(), settledCommitBlob.Record(),
breachedCommitBlob.Record(), htlcBlobs.Record(),
)
stream, err := tlv.NewStream(records...)
if err != nil {
return err
}
typeMap, err := stream.DecodeWithParsedTypes(r)
if err != nil {
return err
}
if val, ok := typeMap[t.SettledCommitBlob.TlvType()]; ok && val == nil {
t.SettledCommitBlob = tlv.SomeRecordT(settledCommitBlob)
}
if v, ok := typeMap[t.BreachedCommitBlob.TlvType()]; ok && v == nil {
t.BreachedCommitBlob = tlv.SomeRecordT(breachedCommitBlob)
}
if v, ok := typeMap[t.HtlcBlobs.TlvType()]; ok && v == nil {
t.HtlcBlobs = tlv.SomeRecordT(htlcBlobs)
}
return nil
}
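// As a minimal illustrative sketch (the helper below is hypothetical and not
// part of the original code), a taprootBriefcase round-trips through
// Encode/Decode as follows, including one of the optional records:
func exampleTaprootBriefcaseRoundTrip() error {
	// Populate a required record and one optional record.
	briefcase := newTaprootBriefcase()
	briefcase.CtrlBlocks.Val.CommitSweepCtrlBlock = []byte{0x01, 0x02}
	briefcase.SettledCommitBlob = tlv.SomeRecordT(
		tlv.NewPrimitiveRecord[tlv.TlvType2](tlv.Blob{0x03}),
	)

	// Serialize the briefcase to an in-memory buffer.
	var b bytes.Buffer
	if err := briefcase.Encode(&b); err != nil {
		return err
	}

	// Decoding into a fresh briefcase should restore both the required
	// and the optional records.
	decoded := newTaprootBriefcase()
	return decoded.Decode(&b)
}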
// resolverCtrlBlocks is a map of resolver IDs to their corresponding control
// block.
type resolverCtrlBlocks map[resolverID][]byte
// newResolverCtrlBlocks returns a new instance of the resolverCtrlBlocks.
func newResolverCtrlBlocks() resolverCtrlBlocks {
return make(resolverCtrlBlocks)
}
// recordSize returns the size of the record in bytes.
func (r *resolverCtrlBlocks) recordSize() uint64 {
// Each record will be serialized as: <num_total_records> || <record>,
// where <record> is serialized as: <resolver_key> || <length> ||
// <ctrl_block>.
numBlocks := uint64(len(*r))
baseSize := tlv.VarIntSize(numBlocks)
recordSize := baseSize
for _, ctrlBlock := range *r {
recordSize += resolverIDLen
recordSize += tlv.VarIntSize(uint64(len(ctrlBlock)))
recordSize += uint64(len(ctrlBlock))
}
return recordSize
}
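// As a worked example: for a map holding a single 33-byte control block, the
// record serializes to VarIntSize(1) = 1 byte for the count, plus
// resolverIDLen bytes for the key, plus VarIntSize(33) = 1 byte for the
// length prefix, plus the 33 control block bytes themselves, i.e.
// resolverIDLen + 35 bytes in total.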
// Encode encodes the control blocks into the target writer.
func (r *resolverCtrlBlocks) Encode(w io.Writer) error {
numBlocks := uint64(len(*r))
var buf [8]byte
if err := tlv.WriteVarInt(w, numBlocks, &buf); err != nil {
return err
}
for id, ctrlBlock := range *r {
ctrlBlock := ctrlBlock
if _, err := w.Write(id[:]); err != nil {
return err
}
if err := varBytesEncoder(w, &ctrlBlock, &buf); err != nil {
return err
}
}
return nil
}
// Decode decodes the given reader into the target struct.
func (r *resolverCtrlBlocks) Decode(reader io.Reader) error {
var buf [8]byte
numBlocks, err := tlv.ReadVarInt(reader, &buf)
if err != nil {
return err
}
for i := uint64(0); i < numBlocks; i++ {
var id resolverID
if _, err := io.ReadFull(reader, id[:]); err != nil {
return err
}
var ctrlBlock []byte
err := varBytesDecoder(reader, &ctrlBlock, &buf, 0)
if err != nil {
return err
}
(*r)[id] = ctrlBlock
}
return nil
}
// resolverCtrlBlocksEncoder is a custom TLV encoder for the
// resolverCtrlBlocks.
func resolverCtrlBlocksEncoder(w io.Writer, val any, _ *[8]byte) error {
if typ, ok := val.(*resolverCtrlBlocks); ok {
return (*typ).Encode(w)
}
return tlv.NewTypeForEncodingErr(val, "resolverCtrlBlocks")
}
// resolverCtrlBlocksDecoder is a custom TLV decoder for the resolverCtrlBlocks.
func resolverCtrlBlocksDecoder(r io.Reader, val any, _ *[8]byte,
l uint64) error {
if typ, ok := val.(*resolverCtrlBlocks); ok {
blockReader := io.LimitReader(r, int64(l))
resolverBlocks := newResolverCtrlBlocks()
err := resolverBlocks.Decode(blockReader)
if err != nil {
return err
}
*typ = resolverBlocks
return nil
}
return tlv.NewTypeForDecodingErr(val, "resolverCtrlBlocks", l, l)
}
// ctrlBlocks is the set of control blocks we need to sweep all the outputs for
// a taproot/musig2 channel.
type ctrlBlocks struct {
// CommitSweepCtrlBlock is the serialized control block needed to sweep
// our commitment output.
CommitSweepCtrlBlock []byte
// RevokeSweepCtrlBlock is the serialized control block that's used to
// sweep the revoked output of a breaching party.
RevokeSweepCtrlBlock []byte
// OutgoingHtlcCtrlBlocks is the set of serialized control blocks for
// all outgoing HTLCs. This is the set of HTLCs that we offered to the
// remote party. Depending on which commitment transaction was
// broadcast, we'll either sweep here and be done, or also need to go
// to the second level.
OutgoingHtlcCtrlBlocks resolverCtrlBlocks
// IncomingHtlcCtrlBlocks is the set of serialized control blocks for
// all incoming HTLCs.
IncomingHtlcCtrlBlocks resolverCtrlBlocks
// SecondLevelCtrlBlocks is the set of serialized control blocks needed
// to sweep the second level HTLCs on our commitment transaction.
SecondLevelCtrlBlocks resolverCtrlBlocks
}
// newCtrlBlocks returns a new instance of the ctrlBlocks struct.
func newCtrlBlocks() ctrlBlocks {
return ctrlBlocks{
OutgoingHtlcCtrlBlocks: newResolverCtrlBlocks(),
IncomingHtlcCtrlBlocks: newResolverCtrlBlocks(),
SecondLevelCtrlBlocks: newResolverCtrlBlocks(),
}
}
// varBytesEncoder is a custom TLV encoder for a variable length byte slice.
func varBytesEncoder(w io.Writer, val any, buf *[8]byte) error {
if t, ok := val.(*[]byte); ok {
if err := tlv.WriteVarInt(w, uint64(len(*t)), buf); err != nil {
return err
}
return tlv.EVarBytes(w, t, buf)
}
return tlv.NewTypeForEncodingErr(val, "[]byte")
}
// varBytesDecoder is a custom TLV decoder for a variable length byte slice.
func varBytesDecoder(r io.Reader, val any, buf *[8]byte, l uint64) error {
if typ, ok := val.(*[]byte); ok {
bytesLen, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
var bytes []byte
if err := tlv.DVarBytes(r, &bytes, buf, bytesLen); err != nil {
return err
}
*typ = bytes
return nil
}
return tlv.NewTypeForDecodingErr(val, "[]byte", l, l)
}
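// The framing produced by varBytesEncoder is simply a varint length prefix
// followed by the raw bytes. The sketch below (a hypothetical helper, not
// part of the original code) round-trips a payload through both helpers:
func exampleVarBytesRoundTrip() ([]byte, error) {
	var (
		b       bytes.Buffer
		buf     [8]byte
		payload = []byte{0xde, 0xad, 0xbe, 0xef}
	)

	// Write the varint length prefix followed by the payload bytes.
	if err := varBytesEncoder(&b, &payload, &buf); err != nil {
		return nil, err
	}

	// Read the length prefix back, then the payload itself. The final
	// argument is ignored by the decoder other than for error reporting.
	var decoded []byte
	err := varBytesDecoder(&b, &decoded, &buf, uint64(b.Len()))
	return decoded, err
}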
// ctrlBlockEncoder is a custom TLV encoder for the ctrlBlocks struct.
func ctrlBlockEncoder(w io.Writer, val any, _ *[8]byte) error {
if t, ok := val.(*ctrlBlocks); ok {
return (*t).Encode(w)
}
return tlv.NewTypeForEncodingErr(val, "ctrlBlocks")
}
// ctrlBlockDecoder is a custom TLV decoder for the ctrlBlocks struct.
func ctrlBlockDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error {
if typ, ok := val.(*ctrlBlocks); ok {
ctrlReader := io.LimitReader(r, int64(l))
var ctrlBlocks ctrlBlocks
err := ctrlBlocks.Decode(ctrlReader)
if err != nil {
return err
}
*typ = ctrlBlocks
return nil
}
return tlv.NewTypeForDecodingErr(val, "ctrlBlocks", l, l)
}
// EncodeRecords returns the set of TLV records that encode the control block
// for the commitment transaction.
func (c *ctrlBlocks) EncodeRecords() []tlv.Record {
var records []tlv.Record
if len(c.CommitSweepCtrlBlock) > 0 {
records = append(records, tlv.MakePrimitiveRecord(
commitCtrlBlockType, &c.CommitSweepCtrlBlock,
))
}
if len(c.RevokeSweepCtrlBlock) > 0 {
records = append(records, tlv.MakePrimitiveRecord(
revokeCtrlBlockType, &c.RevokeSweepCtrlBlock,
))
}
if c.OutgoingHtlcCtrlBlocks != nil {
records = append(records, tlv.MakeDynamicRecord(
outgoingHtlcCtrlBlockType, &c.OutgoingHtlcCtrlBlocks,
c.OutgoingHtlcCtrlBlocks.recordSize,
resolverCtrlBlocksEncoder, resolverCtrlBlocksDecoder,
))
}
if c.IncomingHtlcCtrlBlocks != nil {
records = append(records, tlv.MakeDynamicRecord(
incomingHtlcCtrlBlockType, &c.IncomingHtlcCtrlBlocks,
c.IncomingHtlcCtrlBlocks.recordSize,
resolverCtrlBlocksEncoder, resolverCtrlBlocksDecoder,
))
}
if c.SecondLevelCtrlBlocks != nil {
records = append(records, tlv.MakeDynamicRecord(
secondLevelCtrlBlockType, &c.SecondLevelCtrlBlocks,
c.SecondLevelCtrlBlocks.recordSize,
resolverCtrlBlocksEncoder, resolverCtrlBlocksDecoder,
))
}
return records
}
// DecodeRecords returns the set of TLV records that decode the control block.
func (c *ctrlBlocks) DecodeRecords() []tlv.Record {
return []tlv.Record{
tlv.MakePrimitiveRecord(
commitCtrlBlockType, &c.CommitSweepCtrlBlock,
),
tlv.MakePrimitiveRecord(
revokeCtrlBlockType, &c.RevokeSweepCtrlBlock,
),
tlv.MakeDynamicRecord(
outgoingHtlcCtrlBlockType, &c.OutgoingHtlcCtrlBlocks,
c.OutgoingHtlcCtrlBlocks.recordSize,
resolverCtrlBlocksEncoder, resolverCtrlBlocksDecoder,
),
tlv.MakeDynamicRecord(
incomingHtlcCtrlBlockType, &c.IncomingHtlcCtrlBlocks,
c.IncomingHtlcCtrlBlocks.recordSize,
resolverCtrlBlocksEncoder, resolverCtrlBlocksDecoder,
),
tlv.MakeDynamicRecord(
secondLevelCtrlBlockType, &c.SecondLevelCtrlBlocks,
c.SecondLevelCtrlBlocks.recordSize,
resolverCtrlBlocksEncoder, resolverCtrlBlocksDecoder,
),
}
}
// Record returns a TLV record that can be used to encode/decode the control
// blocks from a given TLV stream.
func (c *ctrlBlocks) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := ctrlBlockEncoder(&b, c, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, c, recordSize, ctrlBlockEncoder, ctrlBlockDecoder,
)
}
// Encode encodes the set of control blocks.
func (c *ctrlBlocks) Encode(w io.Writer) error {
stream, err := tlv.NewStream(c.EncodeRecords()...)
if err != nil {
return err
}
return stream.Encode(w)
}
// Decode decodes the set of control blocks.
func (c *ctrlBlocks) Decode(r io.Reader) error {
stream, err := tlv.NewStream(c.DecodeRecords()...)
if err != nil {
return err
}
return stream.Decode(r)
}
// htlcTapTweaks maps an outpoint (the same format as the resolver ID) to the
// tap tweak needed to sweep a breached HTLC output. This is used for both the
// first and second level HTLC outputs.
type htlcTapTweaks map[resolverID][32]byte
// newHtlcTapTweaks returns a new instance of the htlcTapTweaks struct.
func newHtlcTapTweaks() htlcTapTweaks {
return make(htlcTapTweaks)
}
// recordSize returns the size of the record in bytes.
func (h *htlcTapTweaks) recordSize() uint64 {
// Each record will be serialized as: <num_tweaks> || <tweak>, where
// <tweak> is serialized as: <resolver_key> || <tweak>.
numTweaks := uint64(len(*h))
baseSize := tlv.VarIntSize(numTweaks)
recordSize := baseSize
for range *h {
// Each tweak is a fixed 32 bytes, so we just tally that and the
// size of the resolver ID.
recordSize += resolverIDLen
recordSize += 32
}
return recordSize
}
// Encode encodes the tap tweaks into the target writer.
func (h *htlcTapTweaks) Encode(w io.Writer) error {
numTweaks := uint64(len(*h))
var buf [8]byte
if err := tlv.WriteVarInt(w, numTweaks, &buf); err != nil {
return err
}
for id, tweak := range *h {
tweak := tweak
if _, err := w.Write(id[:]); err != nil {
return err
}
if _, err := w.Write(tweak[:]); err != nil {
return err
}
}
return nil
}
// htlcTapTweaksEncoder is a custom TLV encoder for the htlcTapTweaks struct.
func htlcTapTweaksEncoder(w io.Writer, val any, _ *[8]byte) error {
if t, ok := val.(*htlcTapTweaks); ok {
return (*t).Encode(w)
}
return tlv.NewTypeForEncodingErr(val, "htlcTapTweaks")
}
// htlcTapTweaksDecoder is a custom TLV decoder for the htlcTapTweaks struct.
func htlcTapTweaksDecoder(r io.Reader, val any, _ *[8]byte,
l uint64) error {
if typ, ok := val.(*htlcTapTweaks); ok {
tweakReader := io.LimitReader(r, int64(l))
htlcTweaks := newHtlcTapTweaks()
err := htlcTweaks.Decode(tweakReader)
if err != nil {
return err
}
*typ = htlcTweaks
return nil
}
return tlv.NewTypeForDecodingErr(val, "htlcTapTweaks", l, l)
}
// Decode decodes the tap tweaks into the target struct.
func (h *htlcTapTweaks) Decode(reader io.Reader) error {
var buf [8]byte
numTweaks, err := tlv.ReadVarInt(reader, &buf)
if err != nil {
return err
}
for i := uint64(0); i < numTweaks; i++ {
var id resolverID
if _, err := io.ReadFull(reader, id[:]); err != nil {
return err
}
var tweak [32]byte
if _, err := io.ReadFull(reader, tweak[:]); err != nil {
return err
}
(*h)[id] = tweak
}
return nil
}
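// An illustrative sketch (the helper name is hypothetical, not part of the
// original code) of round-tripping htlcTapTweaks; each entry is a fixed
// 32-byte tweak keyed by its resolver ID:
func exampleHtlcTapTweaksRoundTrip() error {
	tweaks := newHtlcTapTweaks()

	// Insert a single (zero-valued) tweak keyed by a zero resolver ID.
	var (
		id    resolverID
		tweak [32]byte
	)
	tweaks[id] = tweak

	var b bytes.Buffer
	if err := tweaks.Encode(&b); err != nil {
		return err
	}

	decoded := newHtlcTapTweaks()
	return decoded.Decode(&b)
}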
// tapTweaks stores the set of tap tweaks needed to perform keyspends for the
// commitment outputs.
type tapTweaks struct {
// AnchorTweak is the tweak used to derive the key used to spend the
// anchor output.
AnchorTweak []byte
// BreachedHtlcTweaks stores the set of tweaks needed to sweep the
// revoked first level output of an HTLC.
BreachedHtlcTweaks htlcTapTweaks
// BreachedSecondLevelHtlcTweaks stores the set of tweaks needed to
// sweep the revoked *second* level output of an HTLC.
BreachedSecondLevelHltcTweaks htlcTapTweaks
}
// newTapTweaks returns a new tapTweaks struct.
func newTapTweaks() tapTweaks {
return tapTweaks{
BreachedHtlcTweaks: make(htlcTapTweaks),
BreachedSecondLevelHltcTweaks: make(htlcTapTweaks),
}
}
// tapTweaksEncoder is a custom TLV encoder for the tapTweaks struct.
func tapTweaksEncoder(w io.Writer, val any, _ *[8]byte) error {
if t, ok := val.(*tapTweaks); ok {
return (*t).Encode(w)
}
return tlv.NewTypeForEncodingErr(val, "tapTweaks")
}
// tapTweaksDecoder is a custom TLV decoder for the tapTweaks struct.
func tapTweaksDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error {
if typ, ok := val.(*tapTweaks); ok {
tweakReader := io.LimitReader(r, int64(l))
var tapTweaks tapTweaks
err := tapTweaks.Decode(tweakReader)
if err != nil {
return err
}
*typ = tapTweaks
return nil
}
return tlv.NewTypeForDecodingErr(val, "tapTweaks", l, l)
}
// EncodeRecords returns the set of TLV records that encode the tweaks.
func (t *tapTweaks) EncodeRecords() []tlv.Record {
var records []tlv.Record
if len(t.AnchorTweak) > 0 {
records = append(records, tlv.MakePrimitiveRecord(
anchorTapTweakType, &t.AnchorTweak,
))
}
if len(t.BreachedHtlcTweaks) > 0 {
records = append(records, tlv.MakeDynamicRecord(
htlcTweakCtrlBlockType, &t.BreachedHtlcTweaks,
t.BreachedHtlcTweaks.recordSize,
htlcTapTweaksEncoder, htlcTapTweaksDecoder,
))
}
if len(t.BreachedSecondLevelHltcTweaks) > 0 {
records = append(records, tlv.MakeDynamicRecord(
secondLevelHtlcTweakCtrlBlockType,
&t.BreachedSecondLevelHltcTweaks,
t.BreachedSecondLevelHltcTweaks.recordSize,
htlcTapTweaksEncoder, htlcTapTweaksDecoder,
))
}
return records
}
// DecodeRecords returns the set of TLV records that decode the tweaks.
func (t *tapTweaks) DecodeRecords() []tlv.Record {
return []tlv.Record{
tlv.MakePrimitiveRecord(anchorTapTweakType, &t.AnchorTweak),
tlv.MakeDynamicRecord(
htlcTweakCtrlBlockType, &t.BreachedHtlcTweaks,
t.BreachedHtlcTweaks.recordSize,
htlcTapTweaksEncoder, htlcTapTweaksDecoder,
),
tlv.MakeDynamicRecord(
secondLevelHtlcTweakCtrlBlockType,
&t.BreachedSecondLevelHltcTweaks,
t.BreachedSecondLevelHltcTweaks.recordSize,
htlcTapTweaksEncoder, htlcTapTweaksDecoder,
),
}
}
// Record returns a TLV record that can be used to encode/decode the tap
// tweaks.
func (t *tapTweaks) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := tapTweaksEncoder(&b, t, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, t, recordSize, tapTweaksEncoder, tapTweaksDecoder,
)
}
// Encode encodes the set of tap tweaks.
func (t *tapTweaks) Encode(w io.Writer) error {
stream, err := tlv.NewStream(t.EncodeRecords()...)
if err != nil {
return err
}
return stream.Encode(w)
}
// Decode decodes the set of tap tweaks.
func (t *tapTweaks) Decode(r io.Reader) error {
stream, err := tlv.NewStream(t.DecodeRecords()...)
if err != nil {
return err
}
return stream.Decode(r)
}
// htlcAuxBlobs is a map of resolver IDs to their corresponding HTLC blobs.
// This is used to store the resolution blobs for HTLCs that are not yet
// resolved.
type htlcAuxBlobs map[resolverID]tlv.Blob
// newAuxHtlcBlobs returns a new instance of the htlcAuxBlobs struct.
func newAuxHtlcBlobs() htlcAuxBlobs {
return make(htlcAuxBlobs)
}
// Encode encodes the set of HTLC blobs into the target writer.
func (h *htlcAuxBlobs) Encode(w io.Writer) error {
var buf [8]byte
numBlobs := uint64(len(*h))
if err := tlv.WriteVarInt(w, numBlobs, &buf); err != nil {
return err
}
for id, blob := range *h {
if _, err := w.Write(id[:]); err != nil {
return err
}
if err := varBytesEncoder(w, &blob, &buf); err != nil {
return err
}
}
return nil
}
// Decode decodes the set of HTLC blobs from the target reader.
func (h *htlcAuxBlobs) Decode(r io.Reader) error {
var buf [8]byte
numBlobs, err := tlv.ReadVarInt(r, &buf)
if err != nil {
return err
}
for i := uint64(0); i < numBlobs; i++ {
var id resolverID
if _, err := io.ReadFull(r, id[:]); err != nil {
return err
}
var blob tlv.Blob
if err := varBytesDecoder(r, &blob, &buf, 0); err != nil {
return err
}
(*h)[id] = blob
}
return nil
}
// htlcAuxBlobsEncoder is a custom TLV encoder for the htlcAuxBlobs struct.
func htlcAuxBlobsEncoder(w io.Writer, val any, _ *[8]byte) error {
if t, ok := val.(*htlcAuxBlobs); ok {
return (*t).Encode(w)
}
return tlv.NewTypeForEncodingErr(val, "htlcAuxBlobs")
}
// htlcAuxBlobsDecoder is a custom TLV decoder for the htlcAuxBlobs struct.
func htlcAuxBlobsDecoder(r io.Reader, val any, _ *[8]byte,
l uint64) error {
if typ, ok := val.(*htlcAuxBlobs); ok {
blobReader := io.LimitReader(r, int64(l))
htlcBlobs := newAuxHtlcBlobs()
err := htlcBlobs.Decode(blobReader)
if err != nil {
return err
}
*typ = htlcBlobs
return nil
}
return tlv.NewTypeForDecodingErr(val, "htlcAuxBlobs", l, l)
}
// Record returns a tlv.Record for the htlcAuxBlobs struct.
func (h *htlcAuxBlobs) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := htlcAuxBlobsEncoder(&b, h, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, h, recordSize, htlcAuxBlobsEncoder, htlcAuxBlobsDecoder,
)
}
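// A short illustrative sketch (hypothetical helper, not part of the original
// code) of storing and round-tripping an HTLC resolution blob:
func exampleHtlcAuxBlobsRoundTrip() error {
	blobs := newAuxHtlcBlobs()

	var id resolverID
	blobs[id] = tlv.Blob{0x01, 0x02}

	var b bytes.Buffer
	if err := blobs.Encode(&b); err != nil {
		return err
	}

	decoded := newAuxHtlcBlobs()
	return decoded.Decode(&b)
}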
package contractcourt
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/sweep"
"github.com/lightningnetwork/lnd/tlv"
)
// SUMMARY OF OUTPUT STATES
//
// - CRIB
// - SerializedType: babyOutput
// - OriginalOutputType: HTLC
// - Awaiting: First-stage HTLC CLTV expiry
// - HeightIndexEntry: Absolute block height of CLTV expiry.
// - NextState: KNDR
// - PSCL
// - SerializedType: kidOutput
// - OriginalOutputType: Commitment
// - Awaiting: Confirmation of commitment txn
// - HeightIndexEntry: None.
// - NextState: KNDR
// - KNDR
// - SerializedType: kidOutput
// - OriginalOutputType: Commitment or HTLC
// - Awaiting: Commitment CSV expiry or second-stage HTLC CSV expiry.
// - HeightIndexEntry: Input confirmation height + relative CSV delay
// - NextState: GRAD
// - GRAD:
// - SerializedType: kidOutput
// - OriginalOutputType: Commitment or HTLC
// - Awaiting: All other outputs in channel to become GRAD.
// - NextState: Mark channel fully closed in channeldb and remove.
//
// DESCRIPTION OF OUTPUT STATES
//
// TODO(roasbeef): update comment with both new output types
//
// - CRIB (babyOutput) outputs are two-stage htlc outputs that are initially
// locked using a CLTV delay, followed by a CSV delay. The first stage of a
// crib output requires broadcasting a presigned htlc timeout txn generated
// by the wallet after an absolute expiry height. Since the timeout txns are
// predetermined, they cannot be batched after-the-fact, meaning that all
// CRIB outputs are broadcast and confirmed independently. After the first
// stage is complete, a CRIB output is moved to the KNDR state, which will
// finish sweeping the CSV-delayed second-layer output.
//
// - PSCL (kidOutput) outputs are commitment outputs locked under a CSV delay.
// These outputs are stored temporarily in this state until the commitment
// transaction confirms, as this solidifies an absolute height that the
// relative time lock will expire. Once this maturity height is determined,
// the PSCL output is moved into KNDR.
//
// - KNDR (kidOutput) outputs are CSV delayed outputs for which the maturity
// height has been fully determined. This results from having received
// confirmation of the UTXO we are trying to spend, contained in either the
// commitment txn or htlc timeout txn. Once the maturity height is reached,
// the utxo nursery will sweep all KNDR outputs scheduled for that height
// using a single txn.
//
// - GRAD (kidOutput) outputs are KNDR outputs that have successfully been
// swept into the user's wallet. A channel is considered mature once all of
// its outputs, including two-stage htlcs, have entered the GRAD state,
// indicating that it is safe to mark the channel as fully closed.
//
//
// OUTPUT STATE TRANSITIONS IN UTXO NURSERY
//
// ┌────────────────┐ ┌──────────────┐
// │ Commit Outputs │ │ HTLC Outputs │
// └────────────────┘ └──────────────┘
// │ │
// │ │
// │ │ UTXO NURSERY
// ┌───────────┼────────────────┬───────────┼───────────────────────────────┐
// │ │ │ │
// │ │ │ │ │
// │ │ │ CLTV-Delayed │
// │ │ │ V babyOutputs │
// │ │ ┌──────┐ │
// │ │ │ │ CRIB │ │
// │ │ └──────┘ │
// │ │ │ │ │
// │ │ │ │
// │ │ │ | │
// │ │ V Wait CLTV │
// │ │ │ [ ] + │
// │ │ | Publish Txn │
// │ │ │ │ │
// │ │ │ │
// │ │ │ V ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐ │
// │ │ ( ) waitForTimeoutConf │
// │ │ │ | └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘ │
// │ │ │ │
// │ │ │ │ │
// │ │ │ │
// │ V │ │ │
// │ ┌──────┐ │ │
// │ │ PSCL │ └ ── ── ─┼ ── ── ── ── ── ── ── ─┤
// │ └──────┘ │ │
// │ │ │ │
// │ │ │ │
// │ V ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐ │ CSV-Delayed │
// │ ( ) waitForCommitConf │ kidOutputs │
// │ | └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘ │ │
// │ │ │ │
// │ │ │ │
// │ │ V │
// │ │ ┌──────┐ │
// │ └─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─▶│ KNDR │ │
// │ └──────┘ │
// │ │ │
// │ │ │
// │ | │
// │ V Wait CSV │
// │ [ ] + │
// │ | Publish Txn │
// │ │ │
// │ │ │
// │ V ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ │
// │ ( ) waitForSweepConf │
// │ | └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ │
// │ │ │
// │ │ │
// │ V │
// │ ┌──────┐ │
// │ │ GRAD │ │
// │ └──────┘ │
// │ │ │
// │ │ │
// │ │ │
// └────────────────────────────────────────┼───────────────────────────────┘
// │
// │
// │
// │
// V
// ┌────────────────┐
// │ Wallet Outputs │
// └────────────────┘
var byteOrder = binary.BigEndian
const (
// kgtnOutputConfTarget is the default confirmation target we'll use for
// sweeps of CSV delayed outputs.
kgtnOutputConfTarget = 6
)
var (
// ErrContractNotFound is returned when the nursery is unable to
// retrieve information about a queried contract.
ErrContractNotFound = fmt.Errorf("unable to locate contract")
)
// NurseryConfig abstracts the required subsystems used by the utxo nursery. An
// instance of NurseryConfig is passed to newUtxoNursery during instantiation.
type NurseryConfig struct {
// ChainIO is used by the utxo nursery to determine the current block
// height, which drives the incubation of the nursery's outputs.
ChainIO lnwallet.BlockChainIO
// ConfDepth is the number of blocks the nursery store waits before
// determining outputs in the chain as confirmed.
ConfDepth uint32
// FetchClosedChannels provides access to a user's channels, such that
// they can be marked fully closed after incubation has concluded.
FetchClosedChannels func(pendingOnly bool) (
[]*channeldb.ChannelCloseSummary, error)
// FetchClosedChannel provides access to the close summary to extract a
// height hint from.
FetchClosedChannel func(chanID *wire.OutPoint) (
*channeldb.ChannelCloseSummary, error)
// Notifier provides the utxo nursery the ability to subscribe to
// transaction confirmation events, which advance outputs through their
// persistence state transitions.
Notifier chainntnfs.ChainNotifier
// PublishTransaction facilitates the process of broadcasting a signed
// transaction to the appropriate network.
PublishTransaction func(*wire.MsgTx, string) error
// Store provides access to and modification of the persistent state
// maintained about the utxo nursery's incubating outputs.
Store NurseryStorer
// Sweep sweeps an input back to the wallet.
SweepInput func(input.Input, sweep.Params) (chan sweep.Result, error)
// Budget is the configured budget for the nursery.
Budget *BudgetConfig
}
// UtxoNursery is a system dedicated to incubating time-locked outputs created
// by the broadcast of a commitment transaction either by us, or the remote
// peer. The nursery accepts outputs and "incubates" them until they've reached
// maturity, then sweeps the outputs into the source wallet. An output is
// considered mature after the relative time-lock within the pkScript has
// passed. As outputs reach their maturity age, they're swept in batches into
// the source wallet, returning the outputs so they can be used within future
// channels, or regular Bitcoin transactions.
type UtxoNursery struct {
started uint32 // To be used atomically.
stopped uint32 // To be used atomically.
cfg *NurseryConfig
mu sync.Mutex
bestHeight uint32
quit chan struct{}
wg sync.WaitGroup
}
// NewUtxoNursery creates a new instance of the UtxoNursery from a
// ChainNotifier and LightningWallet instance.
func NewUtxoNursery(cfg *NurseryConfig) *UtxoNursery {
return &UtxoNursery{
cfg: cfg,
quit: make(chan struct{}),
}
}
// Start launches all goroutines the UtxoNursery needs to properly carry out
// its duties.
func (u *UtxoNursery) Start() error {
if !atomic.CompareAndSwapUint32(&u.started, 0, 1) {
return nil
}
utxnLog.Info("UTXO nursery starting")
// Retrieve the currently best known block. This is needed to have the
// state machine catch up with the blocks we missed when we were down.
bestHash, bestHeight, err := u.cfg.ChainIO.GetBestBlock()
if err != nil {
return err
}
// Set best known height to schedule late registrations properly.
atomic.StoreUint32(&u.bestHeight, uint32(bestHeight))
// 1. Flush all fully-graduated channels from the pipeline.
// Load any pending close channels, which represents the super set of
// all channels that may still be incubating.
pendingCloseChans, err := u.cfg.FetchClosedChannels(true)
if err != nil {
return err
}
// Ensure that all mature channels have been marked as fully closed in
// the channeldb.
for _, pendingClose := range pendingCloseChans {
err := u.closeAndRemoveIfMature(&pendingClose.ChanPoint)
if err != nil {
return err
}
}
// TODO(conner): check if any fully closed channels can be removed from
// utxn.
// 2. Restart spend ntfns for any preschool outputs, which are waiting
// for the force closed commitment txn to confirm, or any second-layer
// HTLC success transactions.
//
// NOTE: The next two steps *may* spawn go routines, thus from this
// point forward, we must close the nursery's quit channel if we detect
// any failures during startup to ensure they terminate.
if err := u.reloadPreschool(); err != nil {
close(u.quit)
return err
}
// 3. Replay all crib and kindergarten outputs up to the current best
// height.
if err := u.reloadClasses(uint32(bestHeight)); err != nil {
close(u.quit)
return err
}
// Start watching for new blocks, as this will drive the nursery store's
// state machine.
newBlockChan, err := u.cfg.Notifier.RegisterBlockEpochNtfn(&chainntnfs.BlockEpoch{
Height: bestHeight,
Hash: bestHash,
})
if err != nil {
close(u.quit)
return err
}
u.wg.Add(1)
go u.incubator(newBlockChan)
return nil
}
// Stop gracefully shuts down any lingering goroutines launched during normal
// operation of the UtxoNursery.
func (u *UtxoNursery) Stop() error {
if !atomic.CompareAndSwapUint32(&u.stopped, 0, 1) {
return nil
}
utxnLog.Infof("UTXO nursery shutting down...")
defer utxnLog.Debug("UTXO nursery shutdown complete")
close(u.quit)
u.wg.Wait()
return nil
}
// IncubateOutputs sends a request to the UtxoNursery to incubate a set of
// outputs from an existing commitment transaction. Outputs need to incubate if
// they're CLTV absolute time locked, or if they're CSV relative time locked.
// Once all outputs reach maturity, they'll be swept back into the wallet.
func (u *UtxoNursery) IncubateOutputs(chanPoint wire.OutPoint,
outgoingHtlc fn.Option[lnwallet.OutgoingHtlcResolution],
incomingHtlc fn.Option[lnwallet.IncomingHtlcResolution],
broadcastHeight uint32, deadlineHeight fn.Option[int32]) error {
// Add to wait group because nursery might shut down during execution of
// this function. Otherwise it could happen that nursery thinks it is
// shut down, but in this function new goroutines were started and stay
// around.
u.wg.Add(1)
defer u.wg.Done()
// Check quit channel for the case where the waitgroup wait was finished
// right before this function's add call was made.
select {
case <-u.quit:
return fmt.Errorf("nursery shutting down")
default:
}
var (
// Kid outputs can be swept after an initial confirmation
// followed by a maturity period. Baby outputs are two-stage and
// will need to wait for an absolute timeout to reach a
// confirmation, then require a relative confirmation delay.
kidOutputs = make([]kidOutput, 0)
babyOutputs = make([]babyOutput, 0)
)
// 1. Build all the spendable outputs that we will try to incubate.
// TODO(roasbeef): query and see if we already have, if so don't add?
// For each incoming HTLC, we'll register a kid output marked as a
// second-layer HTLC output. We effectively skip the baby stage (as the
// timelock is zero), and enter the kid stage.
incomingHtlc.WhenSome(func(htlcRes lnwallet.IncomingHtlcResolution) {
// Based on the input pk script of the sign descriptor, we can
// determine if this is a taproot output or not. This'll
// determine the witness type we try to set below.
isTaproot := txscript.IsPayToTaproot(
htlcRes.SweepSignDesc.Output.PkScript,
)
var witType input.StandardWitnessType
if isTaproot {
witType = input.TaprootHtlcAcceptedSuccessSecondLevel
} else {
witType = input.HtlcAcceptedSuccessSecondLevel
}
htlcOutput := makeKidOutput(
&htlcRes.ClaimOutpoint, &chanPoint, htlcRes.CsvDelay,
witType, &htlcRes.SweepSignDesc, 0, deadlineHeight,
)
if htlcOutput.Amount() > 0 {
kidOutputs = append(kidOutputs, htlcOutput)
}
})
// For each outgoing HTLC, we'll create a baby output. If this is our
// commitment transaction, then we'll broadcast a second-layer
// transaction to transition to a kid output. Otherwise, we'll directly
// spend once the CLTV delay is up.
outgoingHtlc.WhenSome(func(htlcRes lnwallet.OutgoingHtlcResolution) {
// If this HTLC is on our commitment transaction, then it'll be
// a baby output as we need to go to the second level to sweep
// it.
if htlcRes.SignedTimeoutTx != nil {
htlcOutput := makeBabyOutput(
&chanPoint, &htlcRes, deadlineHeight,
)
if htlcOutput.Amount() > 0 {
babyOutputs = append(babyOutputs, htlcOutput)
}
return
}
// Based on the input pk script of the sign descriptor, we can
// determine if this is a taproot output or not. This'll
// determine the witness type we try to set below.
isTaproot := txscript.IsPayToTaproot(
htlcRes.SweepSignDesc.Output.PkScript,
)
var witType input.StandardWitnessType
if isTaproot {
witType = input.TaprootHtlcOfferedRemoteTimeout
} else {
witType = input.HtlcOfferedRemoteTimeout
}
// Otherwise, this is actually a kid output as we can sweep it
// once the commitment transaction confirms, and the absolute
// CLTV lock has expired. We set the CSV delay to what the
// resolution encodes, since the sequence number must be set
// accordingly.
htlcOutput := makeKidOutput(
&htlcRes.ClaimOutpoint, &chanPoint, htlcRes.CsvDelay,
witType, &htlcRes.SweepSignDesc, htlcRes.Expiry,
deadlineHeight,
)
kidOutputs = append(kidOutputs, htlcOutput)
})
// TODO(roasbeef): if want to handle outgoing on remote commit
// * need ability to cancel in the case that we learn of pre-image or
// remote party pulls
numHtlcs := len(babyOutputs) + len(kidOutputs)
utxnLog.Infof("Incubating Channel(%s) num-htlcs=%d",
chanPoint, numHtlcs)
u.mu.Lock()
defer u.mu.Unlock()
// 2. Persist the outputs we intended to sweep in the nursery store
if err := u.cfg.Store.Incubate(kidOutputs, babyOutputs); err != nil {
utxnLog.Errorf("unable to begin incubation of Channel(%s): %v",
chanPoint, err)
return err
}
// As an intermediate step, we'll now check to see if any of the baby
// outputs has actually _already_ expired. This may be the case if
// blocks were mined while we processed this message.
_, bestHeight, err := u.cfg.ChainIO.GetBestBlock()
if err != nil {
return err
}
// We'll examine all the baby outputs just inserted into the database,
// if the output has already expired, then we'll *immediately* sweep
// it. This may happen if the caller raced a block to call this method.
for i, babyOutput := range babyOutputs {
if uint32(bestHeight) >= babyOutput.expiry {
err = u.sweepCribOutput(
babyOutput.expiry, &babyOutputs[i],
)
if err != nil {
return err
}
}
}
// 3. If we are incubating any preschool outputs, register for a
// confirmation notification that will transition it to the
// kindergarten bucket.
if len(kidOutputs) != 0 {
for i := range kidOutputs {
err := u.registerPreschoolConf(
&kidOutputs[i], broadcastHeight,
)
if err != nil {
return err
}
}
}
return nil
}
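// A minimal illustrative call of IncubateOutputs (the helper is hypothetical
// and not part of the original code). Here both HTLC resolutions are None;
// in practice the channel arbitrator populates at least one of them:
func exampleIncubate(u *UtxoNursery, chanPoint wire.OutPoint,
	broadcastHeight uint32) error {

	return u.IncubateOutputs(
		chanPoint,
		fn.None[lnwallet.OutgoingHtlcResolution](),
		fn.None[lnwallet.IncomingHtlcResolution](),
		broadcastHeight,
		fn.None[int32](),
	)
}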
// NurseryReport attempts to return a nursery report stored for the target
// outpoint. A nursery report details the maturity/sweeping progress for a
// contract that was previously force closed. If a report entry for the target
// chanPoint is unable to be constructed, then an error will be returned.
func (u *UtxoNursery) NurseryReport(
chanPoint *wire.OutPoint) (*ContractMaturityReport, error) {
u.mu.Lock()
defer u.mu.Unlock()
utxnLog.Debugf("NurseryReport: building nursery report for channel %v",
chanPoint)
var report *ContractMaturityReport
if err := u.cfg.Store.ForChanOutputs(chanPoint, func(k, v []byte) error {
switch {
case bytes.HasPrefix(k, cribPrefix):
// Crib outputs are the only kind currently stored as
// baby outputs.
var baby babyOutput
err := baby.Decode(bytes.NewReader(v))
if err != nil {
return err
}
// Each crib output represents a stage one htlc, and
// will contribute towards the limbo balance.
report.AddLimboStage1TimeoutHtlc(&baby)
case bytes.HasPrefix(k, psclPrefix),
bytes.HasPrefix(k, kndrPrefix),
bytes.HasPrefix(k, gradPrefix):
// All other states can be deserialized as kid outputs.
var kid kidOutput
err := kid.Decode(bytes.NewReader(v))
if err != nil {
return err
}
// Now, use the state prefixes to determine how this
// output should be represented in the nursery
// report. An output's funds are always in limbo until
// reaching the graduate state.
switch {
case bytes.HasPrefix(k, psclPrefix):
// Preschool outputs are awaiting the
// confirmation of the commitment transaction.
switch kid.WitnessType() {
//nolint:ll
case input.TaprootHtlcAcceptedSuccessSecondLevel:
fallthrough
case input.HtlcAcceptedSuccessSecondLevel:
// An HTLC output on our commitment
// transaction where the second-layer
// transaction hasn't yet confirmed.
report.AddLimboStage1SuccessHtlc(&kid)
case input.HtlcOfferedRemoteTimeout,
input.TaprootHtlcOfferedRemoteTimeout:
// This is an HTLC output on the
// commitment transaction of the remote
// party. We are waiting for the CLTV
// timelock to expire.
report.AddLimboDirectHtlc(&kid)
}
case bytes.HasPrefix(k, kndrPrefix):
// Kindergarten outputs may originate from
// either the commitment transaction or an htlc.
// We can distinguish them via their witness
// types.
switch kid.WitnessType() {
case input.HtlcOfferedRemoteTimeout,
input.TaprootHtlcOfferedRemoteTimeout:
// This is an HTLC output on the
// commitment transaction of the remote
// party. The CLTV timelock has
// expired, and we only need to sweep
// it.
report.AddLimboDirectHtlc(&kid)
//nolint:ll
case input.TaprootHtlcAcceptedSuccessSecondLevel:
fallthrough
case input.TaprootHtlcOfferedTimeoutSecondLevel:
fallthrough
case input.HtlcAcceptedSuccessSecondLevel:
fallthrough
case input.HtlcOfferedTimeoutSecondLevel:
// The htlc timeout or success
// transaction has confirmed, and the
// CSV delay has begun ticking.
report.AddLimboStage2Htlc(&kid)
}
case bytes.HasPrefix(k, gradPrefix):
// Graduate outputs are those whose funds have
// been swept back into the wallet. Each output
// will contribute towards the recovered
// balance.
switch kid.WitnessType() {
//nolint:ll
case input.TaprootHtlcAcceptedSuccessSecondLevel:
fallthrough
case input.TaprootHtlcOfferedTimeoutSecondLevel:
fallthrough
case input.HtlcAcceptedSuccessSecondLevel:
fallthrough
case input.HtlcOfferedTimeoutSecondLevel:
fallthrough
case input.TaprootHtlcOfferedRemoteTimeout:
fallthrough
case input.HtlcOfferedRemoteTimeout:
// This htlc output successfully
// resides in a p2wkh output belonging
// to the user.
report.AddRecoveredHtlc(&kid)
}
}
default:
}
return nil
}, func() {
report = &ContractMaturityReport{}
}); err != nil {
return nil, err
}
return report, nil
}
// reloadPreschool re-initializes the chain notifier with all of the outputs
// that had been saved to the "preschool" database bucket prior to shutdown.
func (u *UtxoNursery) reloadPreschool() error {
psclOutputs, err := u.cfg.Store.FetchPreschools()
if err != nil {
return err
}
// For each of the preschool outputs stored in the nursery store, load
// its close summary from disk so that we can get an accurate height
// hint from which to start our range for spend notifications.
for i := range psclOutputs {
kid := &psclOutputs[i]
chanPoint := kid.OriginChanPoint()
// Load the close summary for this output's channel point.
closeSummary, err := u.cfg.FetchClosedChannel(chanPoint)
if err == channeldb.ErrClosedChannelNotFound {
// This should never happen since the close summary
// should only be removed after the channel has been
// swept completely.
utxnLog.Warnf("Close summary not found for "+
"chan_point=%v, can't determine height hint"+
"to sweep commit txn", chanPoint)
continue
} else if err != nil {
return err
}
// Use the close height from the channel summary as our height
// hint to drive our spend notifications, with our confirmation
// depth as a buffer for reorgs.
heightHint := closeSummary.CloseHeight - u.cfg.ConfDepth
err = u.registerPreschoolConf(kid, heightHint)
if err != nil {
return err
}
}
return nil
}
// reloadClasses reinitializes any height-dependent state transitions for which
// the utxonursery has not received confirmation, and replays the graduation of
// all kindergarten and crib outputs for all heights up to the current block.
// This allows the nursery to reinitialize all state to continue sweeping
// outputs, even in the event that we missed blocks while offline. reloadClasses
// is called during the startup of the UTXO Nursery.
func (u *UtxoNursery) reloadClasses(bestHeight uint32) error {
// Load all active heights up to and including the current block.
activeHeights, err := u.cfg.Store.HeightsBelowOrEqual(bestHeight)
if err != nil {
return err
}
// Return early if nothing to sweep.
if len(activeHeights) == 0 {
return nil
}
utxnLog.Infof("(Re)-sweeping %d heights below height=%d",
len(activeHeights), bestHeight)
// Attempt to re-register notifications for any outputs still at these
// heights.
for _, classHeight := range activeHeights {
utxnLog.Debugf("Attempting to sweep outputs at height=%v",
classHeight)
if err = u.graduateClass(classHeight); err != nil {
utxnLog.Errorf("Failed to sweep outputs at "+
"height=%v: %v", classHeight, err)
return err
}
}
utxnLog.Infof("UTXO Nursery is now fully synced")
return nil
}
// incubator is tasked with driving all state transitions that are dependent on
// the current height of the blockchain. As new blocks arrive, the incubator
// will attempt to spend outputs at the latest height. The asynchronous
// confirmation of these spends will either 1) move a crib output into the
// kindergarten bucket or 2) move a kindergarten output into the graduated
// bucket.
func (u *UtxoNursery) incubator(newBlockChan *chainntnfs.BlockEpochEvent) {
defer u.wg.Done()
defer newBlockChan.Cancel()
for {
select {
case epoch, ok := <-newBlockChan.Epochs:
// If the epoch channel has been closed, then the
// ChainNotifier is exiting which means the daemon is
// as well. Therefore, we exit early also in order to
// ensure the daemon shuts down gracefully, yet
// swiftly.
if !ok {
return
}
// TODO(roasbeef): if the BlockChainIO is rescanning
// will give stale data
// A new block has just been connected to the main
// chain, which means we might be able to graduate crib
// or kindergarten outputs at this height. This involves
// broadcasting any presigned htlc timeout txns, as well
// as signing and broadcasting a sweep txn that spends
// from all kindergarten outputs at this height.
height := uint32(epoch.Height)
// Update best known block height for late registrations
// to be scheduled properly.
atomic.StoreUint32(&u.bestHeight, height)
if err := u.graduateClass(height); err != nil {
utxnLog.Errorf("error while graduating "+
"class at height=%d: %v", height, err)
// TODO(conner): signal fatal error to daemon
}
case <-u.quit:
return
}
}
}
// graduateClass handles the steps involved in spending outputs whose CSV or
// CLTV delay expires at the nursery's current height. This method is called
// each time a new block arrives, or during startup to catch up on heights we
// may have missed while the nursery was offline.
func (u *UtxoNursery) graduateClass(classHeight uint32) error {
// Record this height as the nursery's current best height.
u.mu.Lock()
defer u.mu.Unlock()
// Fetch all information about the crib and kindergarten outputs at
// this height.
kgtnOutputs, cribOutputs, err := u.cfg.Store.FetchClass(
classHeight,
)
if err != nil {
return err
}
utxnLog.Debugf("Attempting to graduate height=%v: num_kids=%v, "+
"num_babies=%v", classHeight, len(kgtnOutputs), len(cribOutputs))
// Offer the outputs to the sweeper and set up notifications that will
// transition the swept kindergarten outputs and cltvCrib into graduated
// outputs.
if len(kgtnOutputs) > 0 {
if err := u.sweepMatureOutputs(classHeight, kgtnOutputs); err != nil {
utxnLog.Errorf("Failed to sweep %d kindergarten "+
"outputs at height=%d: %v",
len(kgtnOutputs), classHeight, err)
return err
}
}
// Now, we broadcast all pre-signed htlc txns from the csv crib outputs
// at this height.
for i := range cribOutputs {
err := u.sweepCribOutput(classHeight, &cribOutputs[i])
if err != nil {
utxnLog.Errorf("Failed to sweep first-stage HTLC "+
"(CLTV-delayed) output %v",
cribOutputs[i].OutPoint())
return err
}
}
return nil
}
// decideDeadlineAndBudget returns the deadline and budget for a given output.
func (u *UtxoNursery) decideDeadlineAndBudget(k kidOutput) (fn.Option[int32],
btcutil.Amount) {
// Assume this is a to_local output and use a None deadline.
deadline := fn.None[int32]()
// Exit early if this is not an HTLC.
if !k.isHtlc {
budget := calculateBudget(
k.amt, u.cfg.Budget.ToLocalRatio, u.cfg.Budget.ToLocal,
)
return deadline, budget
}
// Otherwise it's the first-level HTLC output, we'll use the
// time-sensitive settings for it.
budget := calculateBudget(
k.amt, u.cfg.Budget.DeadlineHTLCRatio,
u.cfg.Budget.DeadlineHTLC,
)
return k.deadlineHeight, budget
}
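// For intuition, assuming calculateBudget scales the output value by the
// configured ratio (its exact semantics live elsewhere in this package), a
// hypothetical 1,000,000-satoshi to_local output with ToLocalRatio = 0.5
// would yield a 500,000-satoshi budget and a None deadline, since to_local
// outputs are not time-sensitive.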
// sweepMatureOutputs generates and broadcasts the transaction that transfers
// control of funds from a prior channel commitment transaction to the user's
// wallet. The outputs swept were previously time locked (either absolute or
// relative), but are now mature enough to be swept into the wallet.
func (u *UtxoNursery) sweepMatureOutputs(classHeight uint32,
kgtnOutputs []kidOutput) error {
utxnLog.Infof("Sweeping %v CSV-delayed outputs with sweep tx for "+
"height %v", len(kgtnOutputs), classHeight)
for _, output := range kgtnOutputs {
// Create a local copy to prevent a pointer to the loop variable
// from being passed in with disastrous consequences.
local := output
// Calculate the deadline height and budget for this output.
deadline, budget := u.decideDeadlineAndBudget(local)
resultChan, err := u.cfg.SweepInput(&local, sweep.Params{
DeadlineHeight: deadline,
Budget: budget,
})
if err != nil {
return err
}
u.wg.Add(1)
go u.waitForSweepConf(classHeight, &local, resultChan)
}
return nil
}
// waitForSweepConf watches for the confirmation of a sweep transaction
// containing a batch of kindergarten outputs. Once confirmation has been
// received, the nursery will mark those outputs as fully graduated, and proceed
// to mark any mature channels as fully closed in channeldb.
// NOTE(conner): this method MUST be called as a go routine.
func (u *UtxoNursery) waitForSweepConf(classHeight uint32,
output *kidOutput, resultChan chan sweep.Result) {
defer u.wg.Done()
select {
case result, ok := <-resultChan:
if !ok {
utxnLog.Errorf("Notification chan closed, can't" +
" advance graduating output")
return
}
// In case of a remote spend, still graduate the output. There
// is no way to sweep it anymore.
if result.Err == sweep.ErrRemoteSpend {
utxnLog.Infof("Output %v was spend by remote party",
output.OutPoint())
break
}
if result.Err != nil {
utxnLog.Errorf("Failed to sweep %v at "+
"height=%d", output.OutPoint(),
classHeight)
return
}
case <-u.quit:
return
}
u.mu.Lock()
defer u.mu.Unlock()
// TODO(conner): add retry logic?
// Mark the confirmed kindergarten output as graduated.
if err := u.cfg.Store.GraduateKinder(classHeight, output); err != nil {
utxnLog.Errorf("Unable to graduate kindergarten output %v: %v",
output.OutPoint(), err)
return
}
utxnLog.Infof("Graduated kindergarten output from height=%d",
classHeight)
// Attempt to close the channel, only doing so if all of the channel's
// outputs have been graduated.
chanPoint := output.OriginChanPoint()
if err := u.closeAndRemoveIfMature(chanPoint); err != nil {
utxnLog.Errorf("Failed to close and remove channel %v",
*chanPoint)
return
}
}
// sweepCribOutput broadcasts the crib output's htlc timeout txn, and sets up a
// notification that will advance it to the kindergarten bucket upon
// confirmation.
func (u *UtxoNursery) sweepCribOutput(classHeight uint32, baby *babyOutput) error {
utxnLog.Infof("Publishing CLTV-delayed HTLC output using timeout tx "+
"(txid=%v): %v", baby.timeoutTx.TxHash(),
lnutils.SpewLogClosure(baby.timeoutTx))
// We'll now broadcast the HTLC transaction, then wait for it to be
// confirmed before transitioning it to kindergarten.
label := labels.MakeLabel(labels.LabelTypeSweepTransaction, nil)
err := u.cfg.PublishTransaction(baby.timeoutTx, label)
// In case the tx does not meet mempool fee requirements, we continue
// because the tx is rebroadcast in the background and there is
// nothing we can do to bump this transaction anyway.
if err != nil && !errors.Is(err, lnwallet.ErrDoubleSpend) &&
!errors.Is(err, lnwallet.ErrMempoolFee) {
utxnLog.Errorf("Unable to broadcast baby tx: "+
"%v, %v", err, spew.Sdump(baby.timeoutTx))
return err
}
return u.registerTimeoutConf(baby, classHeight)
}
// registerTimeoutConf is responsible for subscribing to confirmation
// notification for an htlc timeout transaction. If successful, a goroutine
// will be spawned that will transition the provided baby output into the
// kindergarten state within the nursery store.
func (u *UtxoNursery) registerTimeoutConf(baby *babyOutput,
heightHint uint32) error {
birthTxID := baby.timeoutTx.TxHash()
// Register for the confirmation of presigned htlc txn.
confChan, err := u.cfg.Notifier.RegisterConfirmationsNtfn(
&birthTxID, baby.timeoutTx.TxOut[0].PkScript, u.cfg.ConfDepth,
heightHint,
)
if err != nil {
return err
}
utxnLog.Infof("Htlc output %v registered for promotion "+
"notification.", baby.OutPoint())
u.wg.Add(1)
go u.waitForTimeoutConf(baby, confChan)
return nil
}
// waitForTimeoutConf watches for the confirmation of an htlc timeout
// transaction, and attempts to move the htlc output from the crib bucket to the
// kindergarten bucket upon success.
func (u *UtxoNursery) waitForTimeoutConf(baby *babyOutput,
confChan *chainntnfs.ConfirmationEvent) {
defer u.wg.Done()
select {
case txConfirmation, ok := <-confChan.Confirmed:
if !ok {
utxnLog.Debugf("Notification chan "+
"closed, can't advance baby output %v",
baby.OutPoint())
return
}
baby.SetConfHeight(txConfirmation.BlockHeight)
case <-u.quit:
return
}
u.mu.Lock()
defer u.mu.Unlock()
// TODO(conner): add retry logic?
err := u.cfg.Store.CribToKinder(baby)
if err != nil {
utxnLog.Errorf("Unable to move htlc output from "+
"crib to kindergarten bucket: %v", err)
return
}
utxnLog.Infof("Htlc output %v promoted to "+
"kindergarten", baby.OutPoint())
}
// registerPreschoolConf is responsible for subscribing to the confirmation of
// a commitment transaction, or an htlc success transaction for an incoming
// HTLC on our commitment transaction. If successful, the provided preschool
// output will be moved persistently into the kindergarten state within the
// nursery store.
func (u *UtxoNursery) registerPreschoolConf(kid *kidOutput, heightHint uint32) error {
txID := kid.OutPoint().Hash
// TODO(roasbeef): ensure we don't already have one waiting, need to
// de-duplicate
// * need to do above?
pkScript := kid.signDesc.Output.PkScript
confChan, err := u.cfg.Notifier.RegisterConfirmationsNtfn(
&txID, pkScript, u.cfg.ConfDepth, heightHint,
)
if err != nil {
return err
}
var outputType string
if kid.isHtlc {
outputType = "HTLC"
} else {
outputType = "Commitment"
}
utxnLog.Infof("%v outpoint %v registered for "+
"confirmation notification.", outputType, kid.OutPoint())
u.wg.Add(1)
go u.waitForPreschoolConf(kid, confChan)
return nil
}
// waitForPreschoolConf is intended to be run as a goroutine that will wait until
// a channel force close commitment transaction, or a second layer HTLC success
// transaction has been included in a confirmed block. Once the transaction has
// been confirmed (as reported by the Chain Notifier), waitForPreschoolConf
// will delete the output from the "preschool" database bucket and atomically
// add it to the "kindergarten" database bucket. This is the second step in
// the output incubation process.
func (u *UtxoNursery) waitForPreschoolConf(kid *kidOutput,
confChan *chainntnfs.ConfirmationEvent) {
defer u.wg.Done()
select {
case txConfirmation, ok := <-confChan.Confirmed:
if !ok {
utxnLog.Errorf("Notification chan "+
"closed, can't advance output %v",
kid.OutPoint())
return
}
kid.SetConfHeight(txConfirmation.BlockHeight)
case <-u.quit:
return
}
u.mu.Lock()
defer u.mu.Unlock()
// TODO(conner): add retry logic?
var outputType string
if kid.isHtlc {
outputType = "HTLC"
} else {
outputType = "Commitment"
}
bestHeight := atomic.LoadUint32(&u.bestHeight)
err := u.cfg.Store.PreschoolToKinder(kid, bestHeight)
if err != nil {
utxnLog.Errorf("Unable to move %v output "+
"from preschool to kindergarten bucket: %v",
outputType, err)
return
}
}
// RemoveChannel erases all entries from the channel bucket for the
// provided channel point.
func (u *UtxoNursery) RemoveChannel(op *wire.OutPoint) error {
return u.cfg.Store.RemoveChannel(op)
}
// ContractMaturityReport is a report that details the maturity progress of a
// particular force closed contract.
type ContractMaturityReport struct {
// LimboBalance is the total number of frozen coins within this
// contract.
LimboBalance btcutil.Amount
// RecoveredBalance is the total value that has been successfully swept
// back to the user's wallet.
RecoveredBalance btcutil.Amount
// Htlcs records a maturity report for each htlc output in this channel.
Htlcs []HtlcMaturityReport
}
// HtlcMaturityReport provides a summary of a single htlc output, and is
// embedded as part of the overarching ContractMaturityReport.
type HtlcMaturityReport struct {
// Outpoint is the final output that will be swept back to the wallet.
Outpoint wire.OutPoint
// Amount is the final value that will be swept back to the wallet.
Amount btcutil.Amount
// MaturityHeight is the absolute block height that this output will
// mature at.
MaturityHeight uint32
// Stage indicates whether the htlc is in the CLTV-timeout stage (1) or
// the CSV-delay stage (2). A stage 1 htlc's maturity height will be set
// to its expiry height, while a stage 2 htlc's maturity height will be
// set to its confirmation height plus the maturity requirement.
Stage uint32
}
// AddLimboStage1TimeoutHtlc adds an htlc crib output to the maturity report's
// htlcs, and contributes its amount to the limbo balance.
func (c *ContractMaturityReport) AddLimboStage1TimeoutHtlc(baby *babyOutput) {
c.LimboBalance += baby.Amount()
// TODO(roasbeef): bool to indicate stage 1 vs stage 2?
c.Htlcs = append(c.Htlcs, HtlcMaturityReport{
Outpoint: baby.OutPoint(),
Amount: baby.Amount(),
MaturityHeight: baby.expiry,
Stage: 1,
})
}
// AddLimboDirectHtlc adds a direct HTLC on the commitment transaction of the
// remote party to the maturity report. This is a CLTV time-locked output that
// has or hasn't expired yet.
func (c *ContractMaturityReport) AddLimboDirectHtlc(kid *kidOutput) {
c.LimboBalance += kid.Amount()
htlcReport := HtlcMaturityReport{
Outpoint: kid.OutPoint(),
Amount: kid.Amount(),
MaturityHeight: kid.absoluteMaturity,
Stage: 2,
}
c.Htlcs = append(c.Htlcs, htlcReport)
}
// AddLimboStage1SuccessHtlc adds an htlc crib output to the maturity
// report's set of HTLC's. We'll use this to report any incoming HTLC sweeps
// where the second level transaction hasn't yet confirmed.
func (c *ContractMaturityReport) AddLimboStage1SuccessHtlc(kid *kidOutput) {
c.LimboBalance += kid.Amount()
c.Htlcs = append(c.Htlcs, HtlcMaturityReport{
Outpoint: kid.OutPoint(),
Amount: kid.Amount(),
Stage: 1,
})
}
// AddLimboStage2Htlc adds an htlc kindergarten output to the maturity report's
// htlcs, and contributes its amount to the limbo balance.
func (c *ContractMaturityReport) AddLimboStage2Htlc(kid *kidOutput) {
c.LimboBalance += kid.Amount()
htlcReport := HtlcMaturityReport{
Outpoint: kid.OutPoint(),
Amount: kid.Amount(),
Stage: 2,
}
// If the confirmation height is set, then this means the first stage
// has been confirmed, and we know the final maturity height of the CSV
// delay.
if kid.ConfHeight() != 0 {
htlcReport.MaturityHeight = kid.ConfHeight() + kid.BlocksToMaturity()
}
c.Htlcs = append(c.Htlcs, htlcReport)
}
// AddRecoveredHtlc adds a graduate output to the maturity report's htlcs, and
// contributes its amount to the recovered balance.
func (c *ContractMaturityReport) AddRecoveredHtlc(kid *kidOutput) {
c.RecoveredBalance += kid.Amount()
c.Htlcs = append(c.Htlcs, HtlcMaturityReport{
Outpoint: kid.OutPoint(),
Amount: kid.Amount(),
MaturityHeight: kid.ConfHeight() + kid.BlocksToMaturity(),
})
}
// closeAndRemoveIfMature removes a particular channel from the channel index
// if and only if all of its outputs have been marked graduated. If the channel
// still has ungraduated outputs, the method will succeed without altering the
// database state.
func (u *UtxoNursery) closeAndRemoveIfMature(chanPoint *wire.OutPoint) error {
isMature, err := u.cfg.Store.IsMatureChannel(chanPoint)
if err == ErrContractNotFound {
return nil
} else if err != nil {
utxnLog.Errorf("Unable to determine maturity of "+
"channel=%s", chanPoint)
return err
}
// Nothing to do if we are still incubating.
if !isMature {
return nil
}
// Now that the channel is fully closed, we remove the channel from the
// nursery store here. This preserves the invariant that we never remove
// a channel unless it is mature, as this is the only place the utxo
// nursery removes a channel.
if err := u.cfg.Store.RemoveChannel(chanPoint); err != nil {
utxnLog.Errorf("Unable to remove channel=%s from "+
"nursery store: %v", chanPoint, err)
return err
}
utxnLog.Infof("Removed channel %v from nursery store", chanPoint)
return nil
}
// babyOutput represents a two-stage CSV locked output, and is used to track
// htlc outputs through incubation. The first stage requires broadcasting a
// presigned timeout txn that spends from the CLTV locked output on the
// commitment txn. A babyOutput is treated as a subset of CsvSpendableOutputs,
// with the additional constraint that a transaction must be broadcast before
// it can be spent. Each baby transaction embeds the kidOutput that can later
// be used to spend the CSV output contained in the timeout txn.
//
// TODO(roasbeef): re-rename to timeout tx
// - create CltvCsvSpendableOutput
type babyOutput struct {
// expiry is the absolute block height at which the secondLevelTx
// should be broadcast to the network.
//
// NOTE: This value will be zero if this is a baby output for a prior
// incoming HTLC.
expiry uint32
// timeoutTx is a fully-signed transaction that, upon confirmation,
// transitions the htlc into the delay+claim stage.
timeoutTx *wire.MsgTx
// kidOutput represents the CSV output to be swept from the
// secondLevelTx after it has been broadcast and confirmed.
kidOutput
}
// makeBabyOutput constructs a baby output that wraps a future kidOutput. The
// provided sign descriptors and witness types will be used once the output
// reaches the delay and claim stage.
func makeBabyOutput(chanPoint *wire.OutPoint,
htlcResolution *lnwallet.OutgoingHtlcResolution,
deadlineHeight fn.Option[int32]) babyOutput {
htlcOutpoint := htlcResolution.ClaimOutpoint
blocksToMaturity := htlcResolution.CsvDelay
isTaproot := txscript.IsPayToTaproot(
htlcResolution.SweepSignDesc.Output.PkScript,
)
var witnessType input.StandardWitnessType
if isTaproot {
witnessType = input.TaprootHtlcOfferedTimeoutSecondLevel
} else {
witnessType = input.HtlcOfferedTimeoutSecondLevel
}
kid := makeKidOutput(
&htlcOutpoint, chanPoint, blocksToMaturity, witnessType,
&htlcResolution.SweepSignDesc, 0, deadlineHeight,
)
return babyOutput{
kidOutput: kid,
expiry: htlcResolution.Expiry,
timeoutTx: htlcResolution.SignedTimeoutTx,
}
}
// Encode writes the baby output to the given io.Writer.
func (bo *babyOutput) Encode(w io.Writer) error {
var scratch [4]byte
byteOrder.PutUint32(scratch[:], bo.expiry)
if _, err := w.Write(scratch[:]); err != nil {
return err
}
if err := bo.timeoutTx.Serialize(w); err != nil {
return err
}
return bo.kidOutput.Encode(w)
}
// Decode reconstructs a baby output using the provided io.Reader.
func (bo *babyOutput) Decode(r io.Reader) error {
var scratch [4]byte
if _, err := r.Read(scratch[:]); err != nil {
return err
}
bo.expiry = byteOrder.Uint32(scratch[:])
bo.timeoutTx = new(wire.MsgTx)
if err := bo.timeoutTx.Deserialize(r); err != nil {
return err
}
return bo.kidOutput.Decode(r)
}
// kidOutput represents an output that's waiting for a required block height
// before its funds will be available to be moved into the user's wallet. The
// struct includes a WitnessGenerator closure which will be used to generate
// the witness required to sweep the output once it's mature.
//
// TODO(roasbeef): rename to immatureOutput?
type kidOutput struct {
breachedOutput
originChanPoint wire.OutPoint
// isHtlc denotes if this kid output is an HTLC output or not. This
// value will be used to determine how to report this output within the
// nursery report.
isHtlc bool
// blocksToMaturity is the relative CSV delay required after initial
// confirmation of the commitment transaction before we can sweep this
// output.
//
// NOTE: This will be set for commitment outputs and incoming HTLCs.
// Otherwise, this will be zero. It will also be non-zero for
// commitment types which require confirmed spends.
blocksToMaturity uint32
// absoluteMaturity is the absolute height that this output will be
// mature at. In order to sweep the output after this height, the
// locktime of sweep transaction will need to be set to this value.
//
// NOTE: This will only be set for outgoing HTLCs on the commitment
// transaction of the remote party.
absoluteMaturity uint32
// deadlineHeight is the absolute height that this output should be
// confirmed at. For an incoming HTLC, this is the CLTV expiry height.
// For an outgoing HTLC, this is its corresponding incoming HTLC's CLTV
// expiry height.
deadlineHeight fn.Option[int32]
}
func makeKidOutput(outpoint, originChanPoint *wire.OutPoint,
blocksToMaturity uint32, witnessType input.StandardWitnessType,
signDescriptor *input.SignDescriptor, absoluteMaturity uint32,
deadlineHeight fn.Option[int32]) kidOutput {
// This is an HTLC if it's either an incoming HTLC on our commitment
// transaction, or an outgoing HTLC on the commitment transaction of
// the remote peer.
isHtlc := (witnessType == input.HtlcAcceptedSuccessSecondLevel ||
witnessType == input.TaprootHtlcAcceptedSuccessSecondLevel ||
witnessType == input.TaprootHtlcOfferedRemoteTimeout ||
witnessType == input.HtlcOfferedRemoteTimeout)
// heightHint can safely be set to zero here because, after this
// function returns, the nursery will set a proper confirmation height
// in waitForTimeoutConf or waitForPreschoolConf.
heightHint := uint32(0)
return kidOutput{
breachedOutput: makeBreachedOutput(
outpoint, witnessType, nil, signDescriptor, heightHint,
fn.None[tlv.Blob](),
),
isHtlc: isHtlc,
originChanPoint: *originChanPoint,
blocksToMaturity: blocksToMaturity,
absoluteMaturity: absoluteMaturity,
deadlineHeight: deadlineHeight,
}
}
func (k *kidOutput) OriginChanPoint() *wire.OutPoint {
return &k.originChanPoint
}
func (k *kidOutput) BlocksToMaturity() uint32 {
return k.blocksToMaturity
}
func (k *kidOutput) SetConfHeight(height uint32) {
k.confHeight = height
}
func (k *kidOutput) ConfHeight() uint32 {
return k.confHeight
}
func (k *kidOutput) RequiredLockTime() (uint32, bool) {
return k.absoluteMaturity, k.absoluteMaturity > 0
}
// Encode converts a kidOutput struct into a form suitable for on-disk database
// storage. Note that the signDescriptor struct field is included so that the
// output's witness can be generated by createSweepTx() when the output becomes
// spendable.
func (k *kidOutput) Encode(w io.Writer) error {
var scratch [8]byte
byteOrder.PutUint64(scratch[:], uint64(k.Amount()))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
op := k.OutPoint()
if err := graphdb.WriteOutpoint(w, &op); err != nil {
return err
}
if err := graphdb.WriteOutpoint(w, k.OriginChanPoint()); err != nil {
return err
}
if err := binary.Write(w, byteOrder, k.isHtlc); err != nil {
return err
}
byteOrder.PutUint32(scratch[:4], k.BlocksToMaturity())
if _, err := w.Write(scratch[:4]); err != nil {
return err
}
byteOrder.PutUint32(scratch[:4], k.absoluteMaturity)
if _, err := w.Write(scratch[:4]); err != nil {
return err
}
byteOrder.PutUint32(scratch[:4], k.ConfHeight())
if _, err := w.Write(scratch[:4]); err != nil {
return err
}
byteOrder.PutUint16(scratch[:2], uint16(k.witnessType))
if _, err := w.Write(scratch[:2]); err != nil {
return err
}
if err := input.WriteSignDescriptor(w, k.SignDesc()); err != nil {
return err
}
if k.SignDesc().ControlBlock == nil {
return nil
}
// If this is a taproot output, then it'll also have a control block,
// so we'll go ahead and write that now.
return wire.WriteVarBytes(w, 1000, k.SignDesc().ControlBlock)
}
// Decode takes a byte array representation of a kidOutput and converts it to
// a struct. Note that the witnessFunc method isn't added during deserialization
// and must be added later based on the value of the witnessType field.
func (k *kidOutput) Decode(r io.Reader) error {
var scratch [8]byte
if _, err := r.Read(scratch[:]); err != nil {
return err
}
k.amt = btcutil.Amount(byteOrder.Uint64(scratch[:]))
err := graphdb.ReadOutpoint(io.LimitReader(r, 40), &k.outpoint)
if err != nil {
return err
}
err = graphdb.ReadOutpoint(io.LimitReader(r, 40), &k.originChanPoint)
if err != nil {
return err
}
if err := binary.Read(r, byteOrder, &k.isHtlc); err != nil {
return err
}
if _, err := r.Read(scratch[:4]); err != nil {
return err
}
k.blocksToMaturity = byteOrder.Uint32(scratch[:4])
if _, err := r.Read(scratch[:4]); err != nil {
return err
}
k.absoluteMaturity = byteOrder.Uint32(scratch[:4])
if _, err := r.Read(scratch[:4]); err != nil {
return err
}
k.confHeight = byteOrder.Uint32(scratch[:4])
if _, err := r.Read(scratch[:2]); err != nil {
return err
}
k.witnessType = input.StandardWitnessType(byteOrder.Uint16(scratch[:2]))
if err := input.ReadSignDescriptor(r, &k.signDesc); err != nil {
return err
}
// If there's anything left in the reader, then this is a taproot
// output that also wrote a control block.
ctrlBlock, err := wire.ReadVarBytes(r, 0, 1000, "control block")
switch {
// If there're no bytes remaining, then we'll return early.
case errors.Is(err, io.EOF):
fallthrough
case errors.Is(err, io.ErrUnexpectedEOF):
return nil
case err != nil:
return err
}
k.signDesc.ControlBlock = ctrlBlock
return nil
}
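// Illustrative sketch (not part of lnd): a round trip of the fixed-width,
// big-endian field encoding used by Encode/Decode above, assuming the
// standard library bytes package is imported alongside io.
func exampleFixedWidthRoundTrip() (uint32, error) {
var (
buf bytes.Buffer
scratch [4]byte
)
// Encode: serialize the uint32 into a fixed 4-byte big-endian field.
byteOrder.PutUint32(scratch[:], 754_321)
if _, err := buf.Write(scratch[:]); err != nil {
return 0, err
}
// Decode: read exactly 4 bytes back and reinterpret them.
if _, err := io.ReadFull(&buf, scratch[:]); err != nil {
return 0, err
}
return byteOrder.Uint32(scratch[:]), nil // 754_321
}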
// Compile-time constraint to ensure kidOutput implements the
// Input interface.
var _ input.Input = (*kidOutput)(nil)
package feature
import (
"fmt"
"github.com/lightningnetwork/lnd/lnwire"
)
type (
// featureSet contains a set of feature bits.
featureSet map[lnwire.FeatureBit]struct{}
// supportedFeatures maps the feature bit from a feature vector to a
// boolean indicating if this feature's dependencies have already been
// verified. This allows us to short circuit verification if multiple
// features have common dependencies, or map traversal starts verifying
// from the bottom up.
supportedFeatures map[lnwire.FeatureBit]bool
// depDesc maps a feature to its set of dependent features, which must
// also be present for the vector to be valid. This can be used to
// recursively check the dependency chain for features in a feature
// vector.
depDesc map[lnwire.FeatureBit]featureSet
)
// ErrMissingFeatureDep is an error signaling that a transitive dependency in a
// feature vector is not set properly.
type ErrMissingFeatureDep struct {
dep lnwire.FeatureBit
}
// NewErrMissingFeatureDep creates a new ErrMissingFeatureDep error.
func NewErrMissingFeatureDep(dep lnwire.FeatureBit) ErrMissingFeatureDep {
return ErrMissingFeatureDep{dep: dep}
}
// Error returns a human-readable description of the missing dep error.
func (e ErrMissingFeatureDep) Error() string {
return fmt.Sprintf("missing feature dependency: %v", e.dep)
}
// deps is the default set of dependencies for assigned feature bits. If a
// feature is not present in the depDesc it is assumed to have no dependencies.
//
// NOTE: For proper functioning, only the optional variant of feature bits
// should be used in the following descriptor. In the future it may be necessary
// to distinguish the dependencies for optional and required bits, but for now
// the validation code maps required bits to optional ones since it simplifies
// the number of constraints.
var deps = depDesc{
lnwire.PaymentAddrOptional: {
lnwire.TLVOnionPayloadOptional: {},
},
lnwire.MPPOptional: {
lnwire.PaymentAddrOptional: {},
},
lnwire.AnchorsOptional: {
lnwire.StaticRemoteKeyOptional: {},
},
lnwire.AnchorsZeroFeeHtlcTxOptional: {
lnwire.StaticRemoteKeyOptional: {},
},
lnwire.AMPOptional: {
lnwire.PaymentAddrOptional: {},
},
lnwire.ExplicitChannelTypeOptional: {},
lnwire.ScriptEnforcedLeaseOptional: {
lnwire.ExplicitChannelTypeOptional: {},
lnwire.AnchorsZeroFeeHtlcTxOptional: {},
},
lnwire.KeysendOptional: {
lnwire.TLVOnionPayloadOptional: {},
},
lnwire.ZeroConfOptional: {
lnwire.ScidAliasOptional: {},
},
lnwire.SimpleTaprootChannelsOptionalStaging: {
lnwire.AnchorsZeroFeeHtlcTxOptional: {},
lnwire.ExplicitChannelTypeOptional: {},
},
lnwire.SimpleTaprootOverlayChansOptional: {
lnwire.SimpleTaprootChannelsOptionalStaging: {},
lnwire.TLVOnionPayloadOptional: {},
lnwire.ScidAliasOptional: {},
},
lnwire.RouteBlindingOptional: {
lnwire.TLVOnionPayloadOptional: {},
},
lnwire.Bolt11BlindedPathsOptional: {
lnwire.RouteBlindingOptional: {},
},
}
// ValidateDeps asserts that a feature vector sets all features and their
// transitive dependencies properly. It assumes that the dependencies between
// optional and required features are identical, e.g. if a feature is required
// but its dependency is optional, that is sufficient.
func ValidateDeps(fv *lnwire.FeatureVector) error {
features := fv.Features()
supported := initSupported(features)
return validateDeps(features, supported)
}
// SetBit sets the given feature bit on the given feature bit vector along with
// any of its dependencies. If the bit is required, then all the dependencies
// are also set to required, otherwise, the optional dependency bits are set.
// Existing bits are only upgraded from optional to required but never
// downgraded from required to optional.
func SetBit(vector *lnwire.FeatureVector,
bit lnwire.FeatureBit) *lnwire.FeatureVector {
fv := vector.Clone()
// Get the optional version of the bit since that is what the deps map
// uses.
optBit := mapToOptional(bit)
// If the bit we are setting is optional, then we set it (in its
// optional form) and also set all its dependents as optional if they
// are not already set (they may already be set in a required form in
// which case they should not be overridden).
if !bit.IsRequired() {
// Set the bit itself if it does not already exist. We use
// SafeSet here so that if the bit already exists in the
// required form, then this is not overwritten.
_ = fv.SafeSet(bit)
// Do the same for all the dependent bits.
for depBit := range deps[optBit] {
fv = SetBit(fv, depBit)
}
return fv
}
// The bit is required. In this case, we do want to override any
// existing optional bit for both the bit itself and for the dependent
// bits.
fv.Unset(optBit)
fv.Set(bit)
// Do the same for all the dependent bits.
for depBit := range deps[optBit] {
// The deps map only contains the optional versions of bits, so
// there is no need to first map the bit to the optional
// version.
fv.Unset(depBit)
// Set the required version of the bit instead.
fv = SetBit(fv, mapToRequired(depBit))
}
return fv
}
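// Illustrative usage sketch (not part of lnd): setting a required bit
// upgrades its dependency chain to required as well.
func exampleSetRequiredBit() {
fv := lnwire.NewFeatureVector(nil, lnwire.Features)
fv = SetBit(fv, lnwire.PaymentAddrRequired)
// Per the deps map above, fv now carries PaymentAddrRequired along
// with TLVOnionPayloadRequired, even though the dependency is only
// described via its optional variant.
_ = fv
}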
// validateDeps is a subroutine that recursively checks that the passed features
// have all of their associated dependencies in the supported map.
func validateDeps(features featureSet, supported supportedFeatures) error {
for bit := range features {
// Convert any required bits to optional.
bit = mapToOptional(bit)
// If the supported feature set doesn't contain the dependency, this
// vector is invalid.
checked, ok := supported[bit]
if !ok {
return NewErrMissingFeatureDep(bit)
}
// Alternatively, if we know that this dependency is valid, we
// can short circuit and continue verifying other bits.
if checked {
continue
}
// Recursively validate dependencies, since this method ranges
// over the subDeps. This call succeeds even if subDeps is nil,
// since ranging over a nil map is a no-op.
subDeps := deps[bit]
if err := validateDeps(subDeps, supported); err != nil {
return err
}
// Once we've confirmed that this feature's dependencies, if
// any, are sound, we record this so other paths taken through
// `bit` return early when inspecting the supported map.
supported[bit] = true
}
return nil
}
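// Illustrative usage sketch (not part of lnd): a vector advertising MPP
// without payment_addr fails validation with the missing dependency.
func exampleValidateDeps() {
raw := lnwire.NewRawFeatureVector(lnwire.MPPOptional)
fv := lnwire.NewFeatureVector(raw, lnwire.Features)
err := ValidateDeps(fv)
// err is an ErrMissingFeatureDep naming PaymentAddrOptional, the
// first unmet dependency discovered by the recursion above.
_ = err
}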
// initSupported sets all bits from the feature vector as supported but not
// checked. This signals that the validity of their dependencies has not been
// verified. All required bits are mapped to optional to simplify the DAG.
func initSupported(features featureSet) supportedFeatures {
supported := make(supportedFeatures)
for bit := range features {
bit = mapToOptional(bit)
supported[bit] = false
}
return supported
}
// mapToOptional returns the optional variant of a given feature bit pair. Our
// dependency graph is described using only optional feature bits, which
// reduces the number of constraints we need to express in the descriptor.
func mapToOptional(bit lnwire.FeatureBit) lnwire.FeatureBit {
if bit.IsRequired() {
bit ^= 0x01
}
return bit
}
// mapToRequired returns the required variant of a given feature bit pair.
func mapToRequired(bit lnwire.FeatureBit) lnwire.FeatureBit {
if bit.IsRequired() {
return bit
}
bit ^= 0x01
return bit
}
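// Illustrative sketch (not part of lnd): feature bits come in even/odd pairs
// per BOLT 9, with the required variant at the even position, so toggling
// the low-order bit moves between the two variants of the same feature.
func examplePairToggle() {
opt := mapToOptional(lnwire.PaymentAddrRequired)
req := mapToRequired(lnwire.PaymentAddrOptional)
_ = opt // == lnwire.PaymentAddrOptional (required bit + 1)
_ = req // == lnwire.PaymentAddrRequired (optional bit - 1)
}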
package feature
import (
"errors"
"fmt"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrUnknownSet is returned if a proposed feature vector contains a
// set that is unknown to LND.
ErrUnknownSet = errors.New("unknown feature bit set")
// ErrFeatureConfigured is returned if an attempt is made to unset a
// feature that was configured at startup.
ErrFeatureConfigured = errors.New("can't unset configured feature")
)
// Config houses any runtime modifications to the default set descriptors. For
// our purposes, this typically means disabling certain features to test legacy
// protocol interoperability or functionality.
type Config struct {
// NoTLVOnion unsets any optional or required TLVOnionPayload bits from
// all feature sets.
NoTLVOnion bool
// NoStaticRemoteKey unsets any optional or required StaticRemoteKey
// bits from all feature sets.
NoStaticRemoteKey bool
// NoAnchors unsets any bits signaling support for anchor outputs.
NoAnchors bool
// NoWumbo unsets any bits signalling support for wumbo channels.
NoWumbo bool
// NoTaprootChans unsets any bits signaling support for taproot
// channels.
NoTaprootChans bool
// NoScriptEnforcementLease unsets any bits signaling support for script
// enforced leases.
NoScriptEnforcementLease bool
// NoKeysend unsets any bits signaling support for accepting keysend
// payments.
NoKeysend bool
// NoOptionScidAlias unsets any bits signalling support for
// option_scid_alias. This also implicitly disables zero-conf channels.
NoOptionScidAlias bool
// NoZeroConf unsets any bits signalling support for zero-conf
// channels. This should be used instead of NoOptionScidAlias to still
// keep option-scid-alias support.
NoZeroConf bool
// NoAnySegwit unsets any bits that signal support for using other
// segwit witness versions for co-op closes.
NoAnySegwit bool
// NoRouteBlinding unsets route blinding feature bits.
NoRouteBlinding bool
// NoQuiescence unsets quiescence feature bits.
NoQuiescence bool
// NoTaprootOverlay unsets the taproot overlay channel feature bits.
NoTaprootOverlay bool
// NoExperimentalEndorsement unsets any bits that signal support for
// forwarding experimental endorsement.
NoExperimentalEndorsement bool
// NoRbfCoopClose unsets any bits that signal support for using RBF for
// coop close.
NoRbfCoopClose bool
// CustomFeatures is a set of custom features to advertise in each
// set.
CustomFeatures map[Set][]lnwire.FeatureBit
}
// Manager is responsible for generating feature vectors for different requested
// feature sets.
type Manager struct {
// fsets is a static map of feature set to raw feature vectors. Requests
// are fulfilled by cloning these internal feature vectors.
fsets map[Set]*lnwire.RawFeatureVector
// configFeatures is a set of custom features that were "hard set" in
// lnd's config that cannot be updated at runtime (as is the case with
// our "standard" features that are defined in LND).
configFeatures map[Set]*lnwire.FeatureVector
}
// NewManager creates a new feature Manager, applying any custom modifications
// to its feature sets before returning.
func NewManager(cfg Config) (*Manager, error) {
return newManager(cfg, defaultSetDesc)
}
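// Illustrative usage sketch (not part of lnd): advertise a custom optional
// feature bit in the node announcement set only. Bit 1055 is hypothetical;
// custom bits must not collide with default bits and must not exceed the
// target set's Maximum().
func exampleCustomFeatureManager() (*Manager, error) {
return NewManager(Config{
CustomFeatures: map[Set][]lnwire.FeatureBit{
SetNodeAnn: {lnwire.FeatureBit(1055)},
},
})
}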
// newManager creates a new feature Manager, applying any custom modifications
// to its feature sets before returning. This method accepts the setDesc as its
// own parameter so that it can be unit tested.
func newManager(cfg Config, desc setDesc) (*Manager, error) {
// First build the default feature vector for all known sets.
fsets := make(map[Set]*lnwire.RawFeatureVector)
for bit, sets := range desc {
for set := range sets {
// Fetch the feature vector for this set, allocating a
// new one if it doesn't exist.
fv, ok := fsets[set]
if !ok {
fv = lnwire.NewRawFeatureVector()
}
// Set the configured bit on the feature vector,
// ensuring that we don't set two feature bits for the
// same pair.
err := fv.SafeSet(bit)
if err != nil {
return nil, fmt.Errorf("unable to set "+
"%v in %v: %v", bit, set, err)
}
// Write the updated feature vector under its set.
fsets[set] = fv
}
}
// Now, remove any features as directed by the config.
configFeatures := make(map[Set]*lnwire.FeatureVector)
for set, raw := range fsets {
if cfg.NoTLVOnion {
raw.Unset(lnwire.TLVOnionPayloadOptional)
raw.Unset(lnwire.TLVOnionPayloadRequired)
raw.Unset(lnwire.PaymentAddrOptional)
raw.Unset(lnwire.PaymentAddrRequired)
raw.Unset(lnwire.MPPOptional)
raw.Unset(lnwire.MPPRequired)
raw.Unset(lnwire.RouteBlindingOptional)
raw.Unset(lnwire.RouteBlindingRequired)
raw.Unset(lnwire.Bolt11BlindedPathsOptional)
raw.Unset(lnwire.Bolt11BlindedPathsRequired)
raw.Unset(lnwire.AMPOptional)
raw.Unset(lnwire.AMPRequired)
raw.Unset(lnwire.KeysendOptional)
raw.Unset(lnwire.KeysendRequired)
}
if cfg.NoStaticRemoteKey {
raw.Unset(lnwire.StaticRemoteKeyOptional)
raw.Unset(lnwire.StaticRemoteKeyRequired)
}
if cfg.NoAnchors {
raw.Unset(lnwire.AnchorsZeroFeeHtlcTxOptional)
raw.Unset(lnwire.AnchorsZeroFeeHtlcTxRequired)
// If anchors are disabled, then we also need to
// disable all other features that depend on it as
// well, as otherwise we may create an invalid feature
// bit set.
for bit, depFeatures := range deps {
for depFeature := range depFeatures {
switch {
case depFeature == lnwire.AnchorsZeroFeeHtlcTxRequired:
fallthrough
case depFeature == lnwire.AnchorsZeroFeeHtlcTxOptional:
raw.Unset(bit)
}
}
}
}
if cfg.NoWumbo {
raw.Unset(lnwire.WumboChannelsOptional)
raw.Unset(lnwire.WumboChannelsRequired)
}
if cfg.NoScriptEnforcementLease {
raw.Unset(lnwire.ScriptEnforcedLeaseOptional)
raw.Unset(lnwire.ScriptEnforcedLeaseRequired)
}
if cfg.NoKeysend {
raw.Unset(lnwire.KeysendOptional)
raw.Unset(lnwire.KeysendRequired)
}
if cfg.NoOptionScidAlias {
raw.Unset(lnwire.ScidAliasOptional)
raw.Unset(lnwire.ScidAliasRequired)
}
if cfg.NoZeroConf {
raw.Unset(lnwire.ZeroConfOptional)
raw.Unset(lnwire.ZeroConfRequired)
}
if cfg.NoAnySegwit {
raw.Unset(lnwire.ShutdownAnySegwitOptional)
raw.Unset(lnwire.ShutdownAnySegwitRequired)
}
if cfg.NoTaprootChans {
raw.Unset(lnwire.SimpleTaprootChannelsOptionalStaging)
raw.Unset(lnwire.SimpleTaprootChannelsRequiredStaging)
}
if cfg.NoRouteBlinding {
raw.Unset(lnwire.RouteBlindingOptional)
raw.Unset(lnwire.RouteBlindingRequired)
raw.Unset(lnwire.Bolt11BlindedPathsOptional)
raw.Unset(lnwire.Bolt11BlindedPathsRequired)
}
if cfg.NoQuiescence {
raw.Unset(lnwire.QuiescenceOptional)
}
if cfg.NoTaprootOverlay {
raw.Unset(lnwire.SimpleTaprootOverlayChansOptional)
raw.Unset(lnwire.SimpleTaprootOverlayChansRequired)
}
if cfg.NoExperimentalEndorsement {
raw.Unset(lnwire.ExperimentalEndorsementOptional)
raw.Unset(lnwire.ExperimentalEndorsementRequired)
}
if cfg.NoRbfCoopClose {
raw.Unset(lnwire.RbfCoopCloseOptionalStaging)
}
for _, custom := range cfg.CustomFeatures[set] {
if custom > set.Maximum() {
return nil, fmt.Errorf("feature bit: %v "+
"exceeds set: %v maximum: %v", custom,
set, set.Maximum())
}
if raw.IsSet(custom) {
return nil, fmt.Errorf("feature bit: %v "+
"already set", custom)
}
if err := raw.SafeSet(custom); err != nil {
return nil, fmt.Errorf("%w: could not set "+
"feature: %d", err, custom)
}
}
// Track custom features separately so that we can check that
// they aren't unset in subsequent updates. If there is no
// entry for the set, the vector will just be empty.
configFeatures[set] = lnwire.NewFeatureVector(
lnwire.NewRawFeatureVector(cfg.CustomFeatures[set]...),
lnwire.Features,
)
// Ensure that all of our feature sets properly set any
// dependent features.
fv := lnwire.NewFeatureVector(raw, lnwire.Features)
err := ValidateDeps(fv)
if err != nil {
return nil, fmt.Errorf("invalid feature set %v: %w",
set, err)
}
}
return &Manager{
fsets: fsets,
configFeatures: configFeatures,
}, nil
}
// GetRaw returns a raw feature vector for the passed set. If no set is known,
// an empty raw feature vector is returned.
func (m *Manager) GetRaw(set Set) *lnwire.RawFeatureVector {
if fv, ok := m.fsets[set]; ok {
return fv.Clone()
}
return lnwire.NewRawFeatureVector()
}
// setRaw sets a new raw feature vector for the given set.
func (m *Manager) setRaw(set Set, raw *lnwire.RawFeatureVector) {
m.fsets[set] = raw
}
// Get returns a feature vector for the passed set. If no set is known, an empty
// feature vector is returned.
func (m *Manager) Get(set Set) *lnwire.FeatureVector {
raw := m.GetRaw(set)
return lnwire.NewFeatureVector(raw, lnwire.Features)
}
// ListSets returns a list of the feature sets that our node supports.
func (m *Manager) ListSets() []Set {
var sets []Set
for set := range m.fsets {
sets = append(sets, set)
}
return sets
}
// UpdateFeatureSets accepts a map of new feature vectors for each of the
// manager's known sets, validates that the update can be applied and modifies
// the feature manager's internal state. If a set is not included in the update
// map, it is left unchanged. The feature vectors provided are expected to
// include the current set of features, updated with desired bits added/removed.
func (m *Manager) UpdateFeatureSets(
updates map[Set]*lnwire.RawFeatureVector) error {
for set, newFeatures := range updates {
if !set.valid() {
return fmt.Errorf("%w: set: %d", ErrUnknownSet, set)
}
if err := newFeatures.ValidatePairs(); err != nil {
return err
}
if err := m.Get(set).ValidateUpdate(
newFeatures, set.Maximum(),
); err != nil {
return err
}
// If any features were configured for this set, ensure that
// they are still set in the new feature vector.
if cfgFeat, haveCfgFeat := m.configFeatures[set]; haveCfgFeat {
for feature := range cfgFeat.Features() {
if !newFeatures.IsSet(feature) {
return fmt.Errorf("%w: can't unset: "+
"%d", ErrFeatureConfigured,
feature)
}
}
}
fv := lnwire.NewFeatureVector(newFeatures, lnwire.Features)
if err := ValidateDeps(fv); err != nil {
return err
}
}
// Only update the current feature sets once every proposed set has
// passed validation so that we don't partially update any sets then
// fail out on a later set's validation.
for set, features := range updates {
m.setRaw(set, features.Clone())
}
return nil
}
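// Illustrative usage sketch (not part of lnd): runtime feature updates must
// carry the full desired vector for a set, so we clone the current one and
// layer the new bit on top before committing.
func exampleUpdateNodeAnnSet(m *Manager) error {
raw := m.GetRaw(SetNodeAnn)
if err := raw.SafeSet(lnwire.KeysendOptional); err != nil {
return err
}
return m.UpdateFeatureSets(map[Set]*lnwire.RawFeatureVector{
SetNodeAnn: raw,
})
}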
package feature
import (
"fmt"
"github.com/lightningnetwork/lnd/lnwire"
)
// ErrUnknownRequired signals that a feature vector requires certain features
// that our node is unaware of or does not implement.
type ErrUnknownRequired struct {
unknown []lnwire.FeatureBit
}
// NewErrUnknownRequired initializes an ErrUnknownRequired with the unknown
// feature bits.
func NewErrUnknownRequired(unknown []lnwire.FeatureBit) ErrUnknownRequired {
return ErrUnknownRequired{
unknown: unknown,
}
}
// Error returns a human-readable description of the error.
func (e ErrUnknownRequired) Error() string {
return fmt.Sprintf("feature vector contains unknown required "+
"features: %v", e.unknown)
}
// ValidateRequired returns an error if the feature vector contains a non-zero
// number of unknown, required feature bits.
func ValidateRequired(fv *lnwire.FeatureVector) error {
unknown := fv.UnknownRequiredFeatures()
if len(unknown) > 0 {
return NewErrUnknownRequired(unknown)
}
return nil
}
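// Illustrative usage sketch (not part of lnd): an unknown even bit is an
// unknown *required* feature and trips validation, while unknown odd
// (optional) bits are tolerated. Bit 2000 is hypothetical and assumed to be
// unknown to lnd.
func exampleValidateRequired() {
raw := lnwire.NewRawFeatureVector(lnwire.FeatureBit(2000))
fv := lnwire.NewFeatureVector(raw, lnwire.Features)
err := ValidateRequired(fv)
_ = err // non-nil: ErrUnknownRequired listing bit 2000.
}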
package feature
import (
"math"
"github.com/lightningnetwork/lnd/lnwire"
)
// Set is an enum identifying various feature sets, which separates the single
// feature namespace into distinct categories depending on the context in
// which a feature vector is being used.
type Set uint8
const (
// SetInit identifies features that should be sent in an Init message to
// a remote peer.
SetInit Set = iota
// SetLegacyGlobal identifies features that should be set in the legacy
// GlobalFeatures field of an Init message, which maintains backwards
// compatibility with nodes that haven't implemented flat features.
SetLegacyGlobal
// SetNodeAnn identifies features that should be advertised on node
// announcements.
SetNodeAnn
// SetInvoice identifies features that should be advertised on invoices
// generated by the daemon.
SetInvoice
// SetInvoiceAmp identifies the features that should be advertised on
// AMP invoices generated by the daemon.
SetInvoiceAmp
// setSentinel is used to mark the end of our known sets. This enum
// member must *always* be the last item in the iota list to ensure
// that validation works as expected.
setSentinel
)
// valid returns a boolean indicating whether a set value is one of our
// predefined feature sets.
func (s Set) valid() bool {
return s < setSentinel
}
// String returns a human-readable description of a Set.
func (s Set) String() string {
switch s {
case SetInit:
return "SetInit"
case SetLegacyGlobal:
return "SetLegacyGlobal"
case SetNodeAnn:
return "SetNodeAnn"
case SetInvoice:
return "SetInvoice"
case SetInvoiceAmp:
return "SetInvoiceAmp"
default:
return "SetUnknown"
}
}
// Maximum returns the maximum allowable value for a feature bit in the context
// of a set. The maximum feature value we can express differs by set context
// because the amount of space available varies between protocol messages. In
// practice this should never be a problem (reasonably one would never hit
// these high ranges), but we enforce these maximums for the sake of sane
// validation.
func (s Set) Maximum() lnwire.FeatureBit {
switch s {
case SetInvoice, SetInvoiceAmp:
return lnwire.MaxBolt11Feature
// The space available in other sets is > math.MaxUint16, so we just
// return the maximum value our expression of a feature bit allows so
// that any value will pass.
default:
return math.MaxUint16
}
}
package graph
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/batch"
"github.com/lightningnetwork/lnd/chainntnfs"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/multimutex"
"github.com/lightningnetwork/lnd/netann"
"github.com/lightningnetwork/lnd/routing/chainview"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/ticker"
)
const (
// DefaultChannelPruneExpiry is the default duration used to determine
// if a channel should be pruned or not.
DefaultChannelPruneExpiry = time.Hour * 24 * 14
// DefaultFirstTimePruneDelay is the time we'll wait after startup
// before attempting to prune the graph for zombie channels. We don't
// do it immediately after startup to allow lnd to start up without
// getting blocked by this job.
DefaultFirstTimePruneDelay = 30 * time.Second
// defaultStatInterval governs how often the router will log non-empty
// stats related to processing new channels, updates, or node
// announcements.
defaultStatInterval = time.Minute
)
var (
// ErrGraphBuilderShuttingDown is returned if the graph builder is in
// the process of shutting down.
ErrGraphBuilderShuttingDown = fmt.Errorf("graph builder shutting down")
)
// Config holds the configuration required by the Builder.
type Config struct {
// SelfNode is the public key of the node that this channel router
// belongs to.
SelfNode route.Vertex
// Graph is the channel graph that the ChannelRouter will use to gather
// metrics from and also to carry out path finding queries.
Graph DB
// Chain is the router's source to the most up-to-date blockchain data.
// All incoming advertised channels will be checked against the chain
// to ensure that the channels advertised are still open.
Chain lnwallet.BlockChainIO
// ChainView is an instance of a FilteredChainView which is used to
// watch the sub-set of the UTXO set (the set of active channels) that
// we need in order to properly maintain the channel graph.
ChainView chainview.FilteredChainView
// Notifier is a reference to the ChainNotifier, used to grab
// the latest blocks if the router is missing any.
Notifier chainntnfs.ChainNotifier
// ChannelPruneExpiry is the duration used to determine if a channel
// should be pruned or not. If the delta between now and when the
// channel was last updated is greater than ChannelPruneExpiry, then
// the channel is marked as a zombie channel eligible for pruning.
ChannelPruneExpiry time.Duration
// GraphPruneInterval is used as an interval to determine how often we
// should examine the channel graph to garbage collect zombie channels.
GraphPruneInterval time.Duration
// FirstTimePruneDelay is the time we'll wait after startup before
// attempting to prune the graph for zombie channels. We don't do it
// immediately after startup to allow lnd to start up without getting
// blocked by this job.
FirstTimePruneDelay time.Duration
// AssumeChannelValid toggles whether the builder will prune channels
// based on their spentness vs using the fact that they are considered
// zombies.
AssumeChannelValid bool
// StrictZombiePruning determines if we attempt to prune zombie
// channels according to a stricter criteria. If true, then we'll prune
// a channel if only *one* of the edges is considered a zombie.
// Otherwise, we'll only prune the channel when both edges have a very
// dated last update.
StrictZombiePruning bool
// IsAlias returns whether a passed ShortChannelID is an alias. This is
// only used for our local channels.
IsAlias func(scid lnwire.ShortChannelID) bool
}
// Builder builds and maintains a view of the Lightning Network graph.
type Builder struct {
started atomic.Bool
stopped atomic.Bool
bestHeight atomic.Uint32
cfg *Config
// newBlocks is a channel over which new blocks connected to the end of
// the main chain are sent, along with blocks re-filtered after a call
// to UpdateFilter.
newBlocks <-chan *chainview.FilteredBlock
// staleBlocks is a channel in which blocks disconnected from the end
// of our currently known best chain are sent over.
staleBlocks <-chan *chainview.FilteredBlock
// channelEdgeMtx is a mutex we use to make sure we process only one
// ChannelEdgePolicy at a time for a given channelID, to ensure
// consistency between the various database accesses.
channelEdgeMtx *multimutex.Mutex[uint64]
// statTicker is a resumable ticker that logs the router's progress as
// it discovers channels or receives updates.
statTicker ticker.Ticker
// stats tracks newly processed channels, updates, and node
// announcements over a window of defaultStatInterval.
stats *builderStats
quit chan struct{}
wg sync.WaitGroup
}
// A compile time check to ensure Builder implements the
// ChannelGraphSource interface.
var _ ChannelGraphSource = (*Builder)(nil)
// NewBuilder constructs a new Builder.
func NewBuilder(cfg *Config) (*Builder, error) {
return &Builder{
cfg: cfg,
channelEdgeMtx: multimutex.NewMutex[uint64](),
statTicker: ticker.New(defaultStatInterval),
stats: new(builderStats),
quit: make(chan struct{}),
}, nil
}
// Start launches all the goroutines the Builder requires to carry out its
// duties. If the builder has already been started, then this method is a noop.
func (b *Builder) Start() error {
if !b.started.CompareAndSwap(false, true) {
return nil
}
log.Info("Builder starting")
bestHash, bestHeight, err := b.cfg.Chain.GetBestBlock()
if err != nil {
return err
}
// If the graph has never been pruned, or hasn't fully been created yet,
// then we don't treat this as an explicit error.
if _, _, err := b.cfg.Graph.PruneTip(); err != nil {
switch {
case errors.Is(err, graphdb.ErrGraphNeverPruned):
fallthrough
case errors.Is(err, graphdb.ErrGraphNotFound):
// If the graph has never been pruned, then we'll set
// the prune height to the current best height of the
// chain backend.
_, err = b.cfg.Graph.PruneGraph(
nil, bestHash, uint32(bestHeight),
)
if err != nil {
return err
}
default:
return err
}
}
// If AssumeChannelValid is present, then we won't rely on pruning
// channels from the graph based on their spentness, but whether they
// are considered zombies or not. We will start zombie pruning after a
// small delay, to avoid slowing down startup of lnd.
if b.cfg.AssumeChannelValid { //nolint:nestif
time.AfterFunc(b.cfg.FirstTimePruneDelay, func() {
select {
case <-b.quit:
return
default:
}
log.Info("Initial zombie prune starting")
if err := b.pruneZombieChans(); err != nil {
log.Errorf("Unable to prune zombies: %v", err)
}
})
} else {
// Otherwise, we'll use our filtered chain view to prune
// channels as soon as they are detected as spent on-chain.
if err := b.cfg.ChainView.Start(); err != nil {
return err
}
// Once the instance is active, we'll fetch the channel we'll
// receive notifications over.
b.newBlocks = b.cfg.ChainView.FilteredBlocks()
b.staleBlocks = b.cfg.ChainView.DisconnectedBlocks()
// Before we perform our manual block pruning, we'll construct
// and apply a fresh chain filter to the active
// FilteredChainView instance. We do this before, as otherwise
// we may miss on-chain events as the filter hasn't properly
// been applied.
channelView, err := b.cfg.Graph.ChannelView()
if err != nil && !errors.Is(
err, graphdb.ErrGraphNoEdgesFound,
) {
return err
}
log.Infof("Filtering chain using %v channels active",
len(channelView))
if len(channelView) != 0 {
err = b.cfg.ChainView.UpdateFilter(
channelView, uint32(bestHeight),
)
if err != nil {
return err
}
}
// The graph pruning might have taken a while and there could be
// new blocks available.
_, bestHeight, err = b.cfg.Chain.GetBestBlock()
if err != nil {
return err
}
b.bestHeight.Store(uint32(bestHeight))
// Before we begin normal operation of the router, we first need
// to synchronize the channel graph to the latest state of the
// UTXO set.
if err := b.syncGraphWithChain(); err != nil {
return err
}
// Finally, before we proceed, we'll prune any unconnected nodes
// from the graph in order to ensure we maintain a tight graph
// of "useful" nodes.
err = b.cfg.Graph.PruneGraphNodes()
if err != nil &&
!errors.Is(err, graphdb.ErrGraphNodesNotFound) {
return err
}
}
b.wg.Add(1)
go b.networkHandler()
log.Debug("Builder started")
return nil
}
// Stop signals to the Builder that it should halt all routines. This method
// will *block* until all goroutines have exited. If the builder has already
// stopped then this method will return immediately.
func (b *Builder) Stop() error {
if !b.stopped.CompareAndSwap(false, true) {
return nil
}
log.Info("Builder shutting down...")
// Our filtered chain view could've only been started if
// AssumeChannelValid isn't present.
if !b.cfg.AssumeChannelValid {
if err := b.cfg.ChainView.Stop(); err != nil {
return err
}
}
close(b.quit)
b.wg.Wait()
log.Debug("Builder shutdown complete")
return nil
}
// syncGraphWithChain attempts to synchronize the current channel graph with
// the latest UTXO set state. This process involves pruning from the channel
// graph any channels which have been closed by spending their funding output
// since we've been down.
func (b *Builder) syncGraphWithChain() error {
// First, we'll need to check to see if we're already in sync with the
// latest state of the UTXO set.
bestHash, bestHeight, err := b.cfg.Chain.GetBestBlock()
if err != nil {
return err
}
b.bestHeight.Store(uint32(bestHeight))
pruneHash, pruneHeight, err := b.cfg.Graph.PruneTip()
if err != nil {
switch {
// If the graph has never been pruned, or hasn't fully been
// created yet, then we don't treat this as an explicit error.
case errors.Is(err, graphdb.ErrGraphNeverPruned):
case errors.Is(err, graphdb.ErrGraphNotFound):
default:
return err
}
}
log.Infof("Prune tip for Channel Graph: height=%v, hash=%v",
pruneHeight, pruneHash)
switch {
// If the graph has never been pruned, then we can exit early as this
// entails it's being created for the first time and hasn't seen any
// block or created channels.
case pruneHeight == 0 || pruneHash == nil:
return nil
// If the block hashes and heights match exactly, then we don't need to
// prune the channel graph as we're already fully in sync.
case bestHash.IsEqual(pruneHash) && uint32(bestHeight) == pruneHeight:
return nil
}
// If the main chain blockhash at prune height is different from the
// prune hash, this might indicate the database is on a stale branch.
mainBlockHash, err := b.cfg.Chain.GetBlockHash(int64(pruneHeight))
if err != nil {
return err
}
// While we are on a stale branch of the chain, walk backwards to find
// first common block.
for !pruneHash.IsEqual(mainBlockHash) {
log.Infof("channel graph is stale. Disconnecting block %v "+
"(hash=%v)", pruneHeight, pruneHash)
// Prune the graph for every channel that was opened at height
// >= pruneHeight.
_, err := b.cfg.Graph.DisconnectBlockAtHeight(pruneHeight)
if err != nil {
return err
}
pruneHash, pruneHeight, err = b.cfg.Graph.PruneTip()
switch {
// If at this point the graph has never been pruned, we can exit
// as this entails we are back to the point where it hasn't seen
// any block or created channels, so there's nothing left to
// prune.
case errors.Is(err, graphdb.ErrGraphNeverPruned):
return nil
case errors.Is(err, graphdb.ErrGraphNotFound):
return nil
case err != nil:
return err
default:
}
mainBlockHash, err = b.cfg.Chain.GetBlockHash(
int64(pruneHeight),
)
if err != nil {
return err
}
}
log.Infof("Syncing channel graph from height=%v (hash=%v) to "+
"height=%v (hash=%v)", pruneHeight, pruneHash, bestHeight,
bestHash)
// If we're not yet caught up, then we'll walk forward in the chain
// pruning the channel graph with each new block that hasn't yet been
// consumed by the channel graph.
var spentOutputs []*wire.OutPoint
for nextHeight := pruneHeight + 1; nextHeight <= uint32(bestHeight); nextHeight++ { //nolint:ll
// Break out of the rescan early if a shutdown has been
// requested, otherwise long rescans will block the daemon from
// shutting down promptly.
select {
case <-b.quit:
return ErrGraphBuilderShuttingDown
default:
}
// Using the next height, request a manual block pruning from
// the chainview for the particular block hash.
log.Infof("Filtering block for closed channels, at height: %v",
int64(nextHeight))
nextHash, err := b.cfg.Chain.GetBlockHash(int64(nextHeight))
if err != nil {
return err
}
log.Tracef("Running block filter on block with hash: %v",
nextHash)
filterBlock, err := b.cfg.ChainView.FilterBlock(nextHash)
if err != nil {
return err
}
// We're only interested in all prior outputs that have been
// spent in the block, so collate all the referenced previous
// outpoints within each tx and input.
for _, tx := range filterBlock.Transactions {
for _, txIn := range tx.TxIn {
spentOutputs = append(spentOutputs,
&txIn.PreviousOutPoint)
}
}
}
// With the spent outputs gathered, attempt to prune the channel graph,
// also passing in the best hash+height so the prune tip can be updated.
closedChans, err := b.cfg.Graph.PruneGraph(
spentOutputs, bestHash, uint32(bestHeight),
)
if err != nil {
return err
}
log.Infof("Graph pruning complete: %v channels were closed since "+
"height %v", len(closedChans), pruneHeight)
return nil
}
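// Illustrative sketch (not part of lnd) of the stale-branch walk-back above:
// keep disconnecting our recorded tip until its hash matches the main chain
// at the same height. mainChainHash and disconnectTip are hypothetical
// helpers standing in for GetBlockHash and DisconnectBlockAtHeight.
func rewindToCommonAncestor(tipHeight uint32, tipHash [32]byte,
mainChainHash func(height uint32) [32]byte,
disconnectTip func() (uint32, [32]byte)) (uint32, [32]byte) {
for tipHash != mainChainHash(tipHeight) {
// Our tip is on a stale branch: drop it and re-check one block
// earlier.
tipHeight, tipHash = disconnectTip()
}
return tipHeight, tipHash
}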
// isZombieChannel takes two edge policy updates and determines if the
// corresponding channel should be considered a zombie. The first boolean is
// true if the policy update from node 1 is considered a zombie, the second
// boolean is that of node 2, and the final boolean is true if the channel
// is considered a zombie.
func (b *Builder) isZombieChannel(e1,
e2 *models.ChannelEdgePolicy) (bool, bool, bool) {
chanExpiry := b.cfg.ChannelPruneExpiry
e1Zombie := e1 == nil || time.Since(e1.LastUpdate) >= chanExpiry
e2Zombie := e2 == nil || time.Since(e2.LastUpdate) >= chanExpiry
var e1Time, e2Time time.Time
if e1 != nil {
e1Time = e1.LastUpdate
}
if e2 != nil {
e2Time = e2.LastUpdate
}
return e1Zombie, e2Zombie, b.IsZombieChannel(e1Time, e2Time)
}
// IsZombieChannel takes the timestamps of the latest channel updates for a
// channel and returns true if the channel should be considered a zombie based
// on these timestamps.
func (b *Builder) IsZombieChannel(updateTime1,
updateTime2 time.Time) bool {
chanExpiry := b.cfg.ChannelPruneExpiry
e1Zombie := updateTime1.IsZero() ||
time.Since(updateTime1) >= chanExpiry
e2Zombie := updateTime2.IsZero() ||
time.Since(updateTime2) >= chanExpiry
// If we're using strict zombie pruning, then a channel is only
// considered live if both edges have a recent update we know of.
if b.cfg.StrictZombiePruning {
return e1Zombie || e2Zombie
}
// Otherwise, if we're using the less strict variant, then a channel is
// considered live if either of the edges have a recent update.
return e1Zombie && e2Zombie
}
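// Illustrative sketch (not part of lnd) of the two pruning policies above,
// given one fresh edge and one stale edge:
func exampleZombiePolicies() {
e1Zombie, e2Zombie := false, true // edge 1 fresh, edge 2 stale
strict := e1Zombie || e2Zombie    // true: prune on any single stale edge
lenient := e1Zombie && e2Zombie   // false: keep while one edge is fresh
_, _ = strict, lenient
}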
// pruneZombieChans is a method that will be called periodically to prune out
// any "zombie" channels. We consider channels zombies if *both* edges haven't
// been updated since our zombie horizon. If AssumeChannelValid is present,
// we'll also consider channels zombies if *both* edges are disabled. This
// usually signals that a channel has been closed on-chain. We do this
// periodically to keep a healthy, lively routing table.
func (b *Builder) pruneZombieChans() error {
chansToPrune := make(map[uint64]struct{})
chanExpiry := b.cfg.ChannelPruneExpiry
log.Infof("Examining channel graph for zombie channels")
// A helper method to detect if the channel belongs to this node
isSelfChannelEdge := func(info *models.ChannelEdgeInfo) bool {
return info.NodeKey1Bytes == b.cfg.SelfNode ||
info.NodeKey2Bytes == b.cfg.SelfNode
}
// First, we'll collect all the channels which are eligible for garbage
// collection due to being zombies.
filterPruneChans := func(info *models.ChannelEdgeInfo,
e1, e2 *models.ChannelEdgePolicy) error {
// Exit early in case this channel is already marked to be
// pruned
_, markedToPrune := chansToPrune[info.ChannelID]
if markedToPrune {
return nil
}
// We'll ensure that we don't attempt to prune our *own*
// channels from the graph, as in any case this should be
// re-advertised by the sub-system above us.
if isSelfChannelEdge(info) {
return nil
}
e1Zombie, e2Zombie, isZombieChan := b.isZombieChannel(e1, e2)
if e1Zombie {
log.Tracef("Node1 pubkey=%x of chan_id=%v is zombie",
info.NodeKey1Bytes, info.ChannelID)
}
if e2Zombie {
log.Tracef("Node2 pubkey=%x of chan_id=%v is zombie",
info.NodeKey2Bytes, info.ChannelID)
}
// If either edge hasn't been updated for a period of
// chanExpiry, then we'll mark the channel itself as eligible
// for graph pruning.
if !isZombieChan {
return nil
}
log.Debugf("ChannelID(%v) is a zombie, collecting to prune",
info.ChannelID)
// TODO(roasbeef): add ability to delete single directional edge
chansToPrune[info.ChannelID] = struct{}{}
return nil
}
// If AssumeChannelValid is present we'll look at the disabled bit for
// both edges. If they're both disabled, then we can interpret this as
// the channel being closed and can prune it from our graph.
if b.cfg.AssumeChannelValid {
disabledChanIDs, err := b.cfg.Graph.DisabledChannelIDs()
if err != nil {
return fmt.Errorf("unable to get disabled channels "+
"ids chans: %v", err)
}
disabledEdges, err := b.cfg.Graph.FetchChanInfos(
disabledChanIDs,
)
if err != nil {
return fmt.Errorf("unable to fetch disabled channels "+
"edges chans: %v", err)
}
// Ensuring we won't prune our own channel from the graph.
for _, disabledEdge := range disabledEdges {
if !isSelfChannelEdge(disabledEdge.Info) {
chansToPrune[disabledEdge.Info.ChannelID] =
struct{}{}
}
}
}
startTime := time.Unix(0, 0)
endTime := time.Now().Add(-1 * chanExpiry)
oldEdges, err := b.cfg.Graph.ChanUpdatesInHorizon(startTime, endTime)
if err != nil {
return fmt.Errorf("unable to fetch expired channel updates "+
"chans: %v", err)
}
for _, u := range oldEdges {
err = filterPruneChans(u.Info, u.Policy1, u.Policy2)
if err != nil {
return fmt.Errorf("error filtering channels to "+
"prune: %w", err)
}
}
log.Infof("Pruning %v zombie channels", len(chansToPrune))
if len(chansToPrune) == 0 {
return nil
}
// With the set of zombie-like channels obtained, we'll do another pass
// to delete them from the channel graph.
toPrune := make([]uint64, 0, len(chansToPrune))
for chanID := range chansToPrune {
toPrune = append(toPrune, chanID)
log.Tracef("Pruning zombie channel with ChannelID(%v)", chanID)
}
err = b.cfg.Graph.DeleteChannelEdges(
b.cfg.StrictZombiePruning, true, toPrune...,
)
if err != nil {
return fmt.Errorf("unable to delete zombie channels: %w", err)
}
// With the channels pruned, we'll also attempt to prune any nodes that
// were a part of them.
err = b.cfg.Graph.PruneGraphNodes()
if err != nil && !errors.Is(err, graphdb.ErrGraphNodesNotFound) {
return fmt.Errorf("unable to prune graph nodes: %w", err)
}
return nil
}
// networkHandler is the primary goroutine for the Builder. The roles of
// this goroutine include answering queries related to the state of the
// network, pruning the graph on new block notification, applying network
// updates, and registering new topology clients.
//
// NOTE: This MUST be run as a goroutine.
func (b *Builder) networkHandler() {
defer b.wg.Done()
graphPruneTicker := time.NewTicker(b.cfg.GraphPruneInterval)
defer graphPruneTicker.Stop()
defer b.statTicker.Stop()
b.stats.Reset()
for {
// If there are stats, resume the statTicker.
if !b.stats.Empty() {
b.statTicker.Resume()
}
select {
case chainUpdate, ok := <-b.staleBlocks:
// If the channel has been closed, then this indicates
// the daemon is shutting down, so we exit ourselves.
if !ok {
return
}
// Since this block is stale, we update our best height
// to the previous block.
blockHeight := chainUpdate.Height
b.bestHeight.Store(blockHeight - 1)
// Update the channel graph to reflect that this block
// was disconnected.
_, err := b.cfg.Graph.DisconnectBlockAtHeight(
blockHeight,
)
if err != nil {
log.Errorf("unable to prune graph with stale "+
"block: %v", err)
continue
}
// TODO(halseth): notify client about the reorg?
// A new block has arrived, so we can prune the channel graph
// of any channels which were closed in the block.
case chainUpdate, ok := <-b.newBlocks:
// If the channel has been closed, then this indicates
// the daemon is shutting down, so we exit ourselves.
if !ok {
return
}
// We'll ensure that any new blocks received attach
// directly to the end of our main chain. If not, then
// we've somehow missed some blocks. Here we'll catch
// up the chain with the latest blocks.
currentHeight := b.bestHeight.Load()
switch {
case chainUpdate.Height == currentHeight+1:
err := b.updateGraphWithClosedChannels(
chainUpdate,
)
if err != nil {
log.Errorf("unable to prune graph "+
"with closed channels: %v", err)
}
case chainUpdate.Height > currentHeight+1:
log.Errorf("out of order block: expecting "+
"height=%v, got height=%v",
currentHeight+1, chainUpdate.Height)
err := b.getMissingBlocks(
currentHeight, chainUpdate,
)
if err != nil {
log.Errorf("unable to retrieve missing"+
"blocks: %v", err)
}
case chainUpdate.Height < currentHeight+1:
log.Errorf("out of order block: expecting "+
"height=%v, got height=%v",
currentHeight+1, chainUpdate.Height)
log.Infof("Skipping channel pruning since "+
"received block height %v was already"+
" processed.", chainUpdate.Height)
}
// The graph prune ticker has ticked, so we'll examine the
// state of the known graph to filter out any zombie channels
// for pruning.
case <-graphPruneTicker.C:
if err := b.pruneZombieChans(); err != nil {
log.Errorf("Unable to prune zombies: %v", err)
}
// Log any stats if we've processed a non-empty number of
// channels, updates, or nodes. We'll only pause the ticker if
// the last window contained no updates to avoid resuming and
// pausing while consecutive windows contain new info.
case <-b.statTicker.Ticks():
if !b.stats.Empty() {
log.Infof(b.stats.String())
} else {
b.statTicker.Pause()
}
b.stats.Reset()
// The router has been signalled to exit, so we exit our main
// loop so the wait group can be decremented.
case <-b.quit:
return
}
}
}
// getMissingBlocks walks through all missing blocks and updates the graph
// with any closed channels found in them.
func (b *Builder) getMissingBlocks(currentHeight uint32,
chainUpdate *chainview.FilteredBlock) error {
outdatedHash, err := b.cfg.Chain.GetBlockHash(int64(currentHeight))
if err != nil {
return err
}
outdatedBlock := &chainntnfs.BlockEpoch{
Height: int32(currentHeight),
Hash: outdatedHash,
}
epochClient, err := b.cfg.Notifier.RegisterBlockEpochNtfn(
outdatedBlock,
)
if err != nil {
return err
}
defer epochClient.Cancel()
blockDifference := int(chainUpdate.Height - currentHeight)
// We'll walk through all the outdated blocks and make sure we're able
// to update the graph with any closed channels from them.
for i := 0; i < blockDifference; i++ {
var (
missingBlock *chainntnfs.BlockEpoch
ok bool
)
select {
case missingBlock, ok = <-epochClient.Epochs:
if !ok {
return nil
}
case <-b.quit:
return nil
}
filteredBlock, err := b.cfg.ChainView.FilterBlock(
missingBlock.Hash,
)
if err != nil {
return err
}
err = b.updateGraphWithClosedChannels(
filteredBlock,
)
if err != nil {
return err
}
}
return nil
}
// updateGraphWithClosedChannels prunes the channel graph of closed channels
// that are no longer needed.
func (b *Builder) updateGraphWithClosedChannels(
chainUpdate *chainview.FilteredBlock) error {
// Once a new block arrives, we update our running track of the height
// of the chain tip.
blockHeight := chainUpdate.Height
b.bestHeight.Store(blockHeight)
log.Infof("Pruning channel graph using block %v (height=%v)",
chainUpdate.Hash, blockHeight)
// We're only interested in all prior outputs that have been spent in
// the block, so collate all the referenced previous outpoints within
// each tx and input.
var spentOutputs []*wire.OutPoint
for _, tx := range chainUpdate.Transactions {
for _, txIn := range tx.TxIn {
spentOutputs = append(spentOutputs,
&txIn.PreviousOutPoint)
}
}
// With the spent outputs gathered, attempt to prune the channel graph,
// also passing in the hash+height of the block being pruned so the
// prune tip can be updated.
chansClosed, err := b.cfg.Graph.PruneGraph(spentOutputs,
&chainUpdate.Hash, chainUpdate.Height)
if err != nil {
log.Errorf("unable to prune routing table: %v", err)
return err
}
log.Infof("Block %v (height=%v) closed %v channels", chainUpdate.Hash,
blockHeight, len(chansClosed))
return nil
}
// assertNodeAnnFreshness returns a non-nil error if we have an announcement in
// the database for the passed node with a timestamp newer than the passed
// timestamp. ErrIgnored will be returned if we already have the node, and
// ErrOutdated will be returned if we have a timestamp that's after the new
// timestamp.
func (b *Builder) assertNodeAnnFreshness(node route.Vertex,
msgTimestamp time.Time) error {
// If we are not already aware of this node, it means that we don't
// know about any channel using this node. To avoid a DoS attack by
// node announcements, we will ignore such nodes. If we do know about
// this node, check that this update brings info newer than what we
// already have.
lastUpdate, exists, err := b.cfg.Graph.HasLightningNode(node)
if err != nil {
return errors.Errorf("unable to query for the "+
"existence of node: %v", err)
}
if !exists {
return NewErrf(ErrIgnored, "Ignoring node announcement"+
" for node not found in channel graph (%x)",
node[:])
}
// If we've reached this point then we're aware of the vertex being
// advertised. So we now check if the new message has a new time stamp,
// if not then we won't accept the new data as it would override newer
// data.
if !lastUpdate.Before(msgTimestamp) {
return NewErrf(ErrOutdated, "Ignoring outdated "+
"announcement for %x", node[:])
}
return nil
}
// MarkZombieEdge adds a channel that failed complete validation into the zombie
// index so we can avoid having to re-validate it in the future.
func (b *Builder) MarkZombieEdge(chanID uint64) error {
// If the edge fails validation we'll mark the edge itself as a zombie
// so we don't continue to request it. We use the "zero key" for both
// node pubkeys so this edge can't be resurrected.
var zeroKey [33]byte
err := b.cfg.Graph.MarkEdgeZombie(chanID, zeroKey, zeroKey)
if err != nil {
return fmt.Errorf("unable to mark spent chan(id=%v) as a "+
"zombie: %w", chanID, err)
}
return nil
}
// ApplyChannelUpdate validates a channel update and, if valid, applies it to
// the database. It returns a bool indicating whether the update was applied
// successfully.
func (b *Builder) ApplyChannelUpdate(msg *lnwire.ChannelUpdate1) bool {
ch, _, _, err := b.GetChannelByID(msg.ShortChannelID)
if err != nil {
log.Errorf("Unable to retrieve channel by id: %v", err)
return false
}
var pubKey *btcec.PublicKey
switch msg.ChannelFlags & lnwire.ChanUpdateDirection {
case 0:
pubKey, _ = ch.NodeKey1()
case 1:
pubKey, _ = ch.NodeKey2()
}
// Exit early if the pubkey cannot be decided.
if pubKey == nil {
log.Errorf("Unable to decide pubkey with ChannelFlags=%v",
msg.ChannelFlags)
return false
}
err = netann.ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg)
if err != nil {
log.Errorf("Unable to validate channel update: %v", err)
return false
}
err = b.UpdateEdge(&models.ChannelEdgePolicy{
SigBytes: msg.Signature.ToSignatureBytes(),
ChannelID: msg.ShortChannelID.ToUint64(),
LastUpdate: time.Unix(int64(msg.Timestamp), 0),
MessageFlags: msg.MessageFlags,
ChannelFlags: msg.ChannelFlags,
TimeLockDelta: msg.TimeLockDelta,
MinHTLC: msg.HtlcMinimumMsat,
MaxHTLC: msg.HtlcMaximumMsat,
FeeBaseMSat: lnwire.MilliSatoshi(msg.BaseFee),
FeeProportionalMillionths: lnwire.MilliSatoshi(msg.FeeRate),
ExtraOpaqueData: msg.ExtraOpaqueData,
})
if err != nil && !IsError(err, ErrIgnored, ErrOutdated) {
log.Errorf("Unable to apply channel update: %v", err)
return false
}
return true
}
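// The direction bit determines which node signed the update, and thus which
// key is used to validate it. A short sketch of the mapping (using msg as in
// ApplyChannelUpdate above):
//
//	if msg.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
//		// Update from the first node; verify against NodeKey1.
//	} else {
//		// Update from the second node; verify against NodeKey2.
//	}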
// AddNode is used to add information about a node to the router database. If
// the node with this pubkey is not present in an existing channel, it will
// be ignored.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) AddNode(node *models.LightningNode,
op ...batch.SchedulerOption) error {
err := b.addNode(node, op...)
if err != nil {
logNetworkMsgProcessError(err)
return err
}
return nil
}
// addNode does some basic checks on the given LightningNode against what we
// currently have persisted in the graph, and then adds it to the graph. If we
// already know about the node, then we only update our DB if the new update
// has a newer timestamp than the last one we received.
func (b *Builder) addNode(node *models.LightningNode,
op ...batch.SchedulerOption) error {
// Before we add the node to the database, we'll check to see if the
// announcement is "fresh" or not. If it isn't, then we'll return an
// error.
err := b.assertNodeAnnFreshness(node.PubKeyBytes, node.LastUpdate)
if err != nil {
return err
}
if err := b.cfg.Graph.AddLightningNode(node, op...); err != nil {
return errors.Errorf("unable to add node %x to the "+
"graph: %v", node.PubKeyBytes, err)
}
log.Tracef("Updated vertex data for node=%x", node.PubKeyBytes)
b.stats.incNumNodeUpdates()
return nil
}
// AddEdge is used to add an edge/channel to the topology of the router. Once
// all information about the channel has been gathered, this edge/channel
// might be used in the construction of payment paths.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) AddEdge(edge *models.ChannelEdgeInfo,
op ...batch.SchedulerOption) error {
err := b.addEdge(edge, op...)
if err != nil {
logNetworkMsgProcessError(err)
return err
}
return nil
}
// addEdge does some validation on the new channel edge against what we
// currently have persisted in the graph, and then adds it to the graph. The
// Chain View is updated with the new edge if it is successfully added to the
// graph. We only persist the channel if we currently don't have it at all in
// our graph.
//
// TODO(elle): this currently also does funding-transaction validation. But this
// should be moved to the gossiper instead.
func (b *Builder) addEdge(edge *models.ChannelEdgeInfo,
op ...batch.SchedulerOption) error {
log.Debugf("Received ChannelEdgeInfo for channel %v", edge.ChannelID)
// Prior to processing the announcement we first check if we
// already know of this channel; if so, then we can exit early.
_, _, exists, isZombie, err := b.cfg.Graph.HasChannelEdge(
edge.ChannelID,
)
if err != nil && !errors.Is(err, graphdb.ErrGraphNoEdgesFound) {
return errors.Errorf("unable to check for edge existence: %v",
err)
}
if isZombie {
return NewErrf(ErrIgnored, "ignoring msg for zombie chan_id=%v",
edge.ChannelID)
}
if exists {
return NewErrf(ErrIgnored, "ignoring msg for known chan_id=%v",
edge.ChannelID)
}
if err := b.cfg.Graph.AddChannelEdge(edge, op...); err != nil {
return fmt.Errorf("unable to add edge: %w", err)
}
b.stats.incNumEdgesDiscovered()
// If AssumeChannelValid is set, or if the SCID is an alias, then the
// gossiper would not have done the expensive work of fetching and
// validating the funding transaction, so we have neither the channel
// capacity nor the funding script. In that case we just log and return.
scid := lnwire.NewShortChanIDFromInt(edge.ChannelID)
if b.cfg.AssumeChannelValid || b.cfg.IsAlias(scid) {
log.Tracef("New channel discovered! Link connects %x and %x "+
"with ChannelID(%v)", edge.NodeKey1Bytes,
edge.NodeKey2Bytes, edge.ChannelID)
return nil
}
log.Debugf("New channel discovered! Link connects %x and %x with "+
"ChannelPoint(%v): chan_id=%v, capacity=%v", edge.NodeKey1Bytes,
edge.NodeKey2Bytes, edge.ChannelPoint, edge.ChannelID,
edge.Capacity)
// Otherwise, we expect the funding script to be present on the edge
// since it would have been fetched when the gossiper validated the
// announcement.
fundingPkScript, err := edge.FundingScript.UnwrapOrErr(fmt.Errorf(
"expected the funding transaction script to be set",
))
if err != nil {
return err
}
// As a new edge has been added to the channel graph, we'll update the
// current UTXO filter within our active FilteredChainView so we are
// notified if/when this channel is closed.
filterUpdate := []graphdb.EdgePoint{
{
FundingPkScript: fundingPkScript,
OutPoint: edge.ChannelPoint,
},
}
err = b.cfg.ChainView.UpdateFilter(filterUpdate, b.bestHeight.Load())
if err != nil {
return errors.Errorf("unable to update chain "+
"view: %v", err)
}
return nil
}
// UpdateEdge is used to update edge information. Without this message, the
// edge is considered not fully constructed.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) UpdateEdge(update *models.ChannelEdgePolicy,
op ...batch.SchedulerOption) error {
err := b.updateEdge(update, op...)
if err != nil {
logNetworkMsgProcessError(err)
return err
}
return nil
}
// updateEdge validates the new edge policy against what we currently have
// persisted in the graph, and then applies it to the graph if the update is
// considered fresh enough and if we actually have a channel persisted for the
// given update.
func (b *Builder) updateEdge(policy *models.ChannelEdgePolicy,
op ...batch.SchedulerOption) error {
log.Debugf("Received ChannelEdgePolicy for channel %v",
policy.ChannelID)
// We make sure to hold the mutex for this channel ID, such that no
// other goroutine is concurrently doing database accesses for the same
// channel ID.
b.channelEdgeMtx.Lock(policy.ChannelID)
defer b.channelEdgeMtx.Unlock(policy.ChannelID)
edge1Timestamp, edge2Timestamp, exists, isZombie, err :=
b.cfg.Graph.HasChannelEdge(policy.ChannelID)
if err != nil && !errors.Is(err, graphdb.ErrGraphNoEdgesFound) {
return errors.Errorf("unable to check for edge existence: %v",
err)
}
// If the channel is marked as a zombie in our database, and
// we consider this a stale update, then we should not apply the
// policy.
isStaleUpdate := time.Since(policy.LastUpdate) >
b.cfg.ChannelPruneExpiry
if isZombie && isStaleUpdate {
return NewErrf(ErrIgnored, "ignoring stale update "+
"(flags=%v|%v) for zombie chan_id=%v",
policy.MessageFlags, policy.ChannelFlags,
policy.ChannelID)
}
// If the channel doesn't exist in our database, we cannot apply the
// updated policy.
if !exists {
return NewErrf(ErrIgnored, "ignoring update (flags=%v|%v) for "+
"unknown chan_id=%v", policy.MessageFlags,
policy.ChannelFlags, policy.ChannelID)
}
log.Debugf("Found edge1Timestamp=%v, edge2Timestamp=%v",
edge1Timestamp, edge2Timestamp)
// As edges are directional, each node has a unique policy for the
// direction of the edge it controls. Therefore, we first check if we
// already have the most up-to-date information for that edge. If this
// message has a timestamp that isn't strictly newer than what we
// already know of, we can exit early.
switch policy.ChannelFlags & lnwire.ChanUpdateDirection {
// A flag set of 0 indicates this is an announcement for the "first"
// node in the channel.
case 0:
// Ignore outdated message.
if !edge1Timestamp.Before(policy.LastUpdate) {
return NewErrf(ErrOutdated, "Ignoring "+
"outdated update (flags=%v|%v) for "+
"known chan_id=%v", policy.MessageFlags,
policy.ChannelFlags, policy.ChannelID)
}
// Similarly, a flag set of 1 indicates this is an announcement
// for the "second" node in the channel.
case 1:
// Ignore outdated message.
if !edge2Timestamp.Before(policy.LastUpdate) {
return NewErrf(ErrOutdated, "Ignoring "+
"outdated update (flags=%v|%v) for "+
"known chan_id=%v", policy.MessageFlags,
policy.ChannelFlags, policy.ChannelID)
}
}
// Now that we know this isn't a stale update, we'll apply the new edge
// policy to the proper directional edge within the channel graph.
if err = b.cfg.Graph.UpdateEdgePolicy(policy, op...); err != nil {
err := errors.Errorf("unable to add channel: %v", err)
log.Error(err)
return err
}
log.Tracef("New channel update applied: %v",
lnutils.SpewLogClosure(policy))
b.stats.incNumChannelUpdates()
return nil
}
// logNetworkMsgProcessError logs the error received from processing a network
// message. It logs as a debug message if the error is not critical.
func logNetworkMsgProcessError(err error) {
if IsError(err, ErrIgnored, ErrOutdated) {
log.Debugf("process network updates got: %v", err)
return
}
log.Errorf("process network updates got: %v", err)
}
// CurrentBlockHeight returns the block height from the point of view of the
// router subsystem.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) CurrentBlockHeight() (uint32, error) {
_, height, err := b.cfg.Chain.GetBestBlock()
return uint32(height), err
}
// SyncedHeight returns the block height that the router subsystem is currently
// synced to. This can differ from the chain height reported by
// CurrentBlockHeight if the goroutine responsible for processing the blocks
// isn't yet up to speed.
func (b *Builder) SyncedHeight() uint32 {
return b.bestHeight.Load()
}
// GetChannelByID returns the channel by the channel ID.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) GetChannelByID(chanID lnwire.ShortChannelID) (
*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
return b.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64())
}
// FetchLightningNode attempts to look up a target node by its identity public
// key. graphdb.ErrGraphNodeNotFound is returned if the node doesn't exist
// within the graph.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) FetchLightningNode(
node route.Vertex) (*models.LightningNode, error) {
return b.cfg.Graph.FetchLightningNode(node)
}
// ForAllOutgoingChannels is used to iterate over all outgoing channels owned by
// the router.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) ForAllOutgoingChannels(cb func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy) error) error {
return b.cfg.Graph.ForEachNodeChannel(b.cfg.SelfNode,
func(_ kvdb.RTx, c *models.ChannelEdgeInfo,
e *models.ChannelEdgePolicy,
_ *models.ChannelEdgePolicy) error {
if e == nil {
return fmt.Errorf("channel from self node " +
"has no policy")
}
return cb(c, e)
},
)
}
// AddProof updates the channel edge info with proof which is needed to
// properly announce the edge to the rest of the network.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) AddProof(chanID lnwire.ShortChannelID,
proof *models.ChannelAuthProof) error {
return b.cfg.Graph.AddEdgeProof(chanID, proof)
}
// IsStaleNode returns true if the graph source has a node announcement for the
// target node with a more recent timestamp.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) IsStaleNode(node route.Vertex,
timestamp time.Time) bool {
// If our attempt to assert that the node announcement is fresh fails,
// then we know that this is actually a stale announcement.
err := b.assertNodeAnnFreshness(node, timestamp)
if err != nil {
log.Debugf("Checking stale node %x got %v", node, err)
return true
}
return false
}
// IsPublicNode determines whether the given vertex is seen as a public node in
// the graph from the graph's source node's point of view.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) IsPublicNode(node route.Vertex) (bool, error) {
return b.cfg.Graph.IsPublicNode(node)
}
// IsKnownEdge returns true if the graph source already knows of the passed
// channel ID either as a live or zombie edge.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) IsKnownEdge(chanID lnwire.ShortChannelID) bool {
_, _, exists, isZombie, _ := b.cfg.Graph.HasChannelEdge(
chanID.ToUint64(),
)
return exists || isZombie
}
// IsZombieEdge returns true if the graph source has marked the given channel ID
// as a zombie edge.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) IsZombieEdge(chanID lnwire.ShortChannelID) (bool, error) {
_, _, _, isZombie, err := b.cfg.Graph.HasChannelEdge(chanID.ToUint64())
return isZombie, err
}
// IsStaleEdgePolicy returns true if the graph source has a channel edge for
// the passed channel ID (and flags) that has a more recent timestamp.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) IsStaleEdgePolicy(chanID lnwire.ShortChannelID,
timestamp time.Time, flags lnwire.ChanUpdateChanFlags) bool {
edge1Timestamp, edge2Timestamp, exists, isZombie, err :=
b.cfg.Graph.HasChannelEdge(chanID.ToUint64())
if err != nil {
log.Debugf("Check stale edge policy got error: %v", err)
return false
}
// If we know of the edge as a zombie, then we'll make some additional
// checks to determine if the new policy is fresh.
if isZombie {
// When running with AssumeChannelValid, we also prune channels
// if both of their edges are disabled. We'll mark the new
// policy as stale if it remains disabled.
if b.cfg.AssumeChannelValid {
isDisabled := flags&lnwire.ChanUpdateDisabled ==
lnwire.ChanUpdateDisabled
if isDisabled {
return true
}
}
// Otherwise, we'll fall back to our usual ChannelPruneExpiry.
return time.Since(timestamp) > b.cfg.ChannelPruneExpiry
}
// If we don't know of the edge, then it means it's fresh (thus not
// stale).
if !exists {
return false
}
// As edges are directional, each node has a unique policy for the
// direction of the edge it controls. Therefore, we first check if we
// already have the most up-to-date information for that edge. If so,
// then we can exit early.
switch {
// A flag set of 0 indicates this is an announcement for the "first"
// node in the channel.
case flags&lnwire.ChanUpdateDirection == 0:
return !edge1Timestamp.Before(timestamp)
// Similarly, a flag set of 1 indicates this is an announcement for the
// "second" node in the channel.
case flags&lnwire.ChanUpdateDirection == 1:
return !edge2Timestamp.Before(timestamp)
}
return false
}
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
//
// NOTE: This method is part of the ChannelGraphSource interface.
func (b *Builder) MarkEdgeLive(chanID lnwire.ShortChannelID) error {
return b.cfg.Graph.MarkEdgeLive(chanID.ToUint64())
}
package graphdb
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tor"
)
// addressType specifies the network protocol and version that should be used
// when connecting to a node at a particular address.
type addressType uint8
const (
// tcp4Addr denotes an IPv4 TCP address.
tcp4Addr addressType = 0
// tcp6Addr denotes an IPv6 TCP address.
tcp6Addr addressType = 1
// v2OnionAddr denotes a version 2 Tor onion service address.
v2OnionAddr addressType = 2
// v3OnionAddr denotes a version 3 Tor (prop224) onion service address.
v3OnionAddr addressType = 3
// opaqueAddrs denotes an address (or a set of addresses) that LND was
// not able to parse since LND is not yet aware of the address type.
opaqueAddrs addressType = 4
)
// encodeTCPAddr serializes a TCP address into its compact raw bytes
// representation.
func encodeTCPAddr(w io.Writer, addr *net.TCPAddr) error {
var (
addrType byte
ip []byte
)
if addr.IP.To4() != nil {
addrType = byte(tcp4Addr)
ip = addr.IP.To4()
} else {
addrType = byte(tcp6Addr)
ip = addr.IP.To16()
}
if ip == nil {
return fmt.Errorf("unable to encode IP %v", addr.IP)
}
if _, err := w.Write([]byte{addrType}); err != nil {
return err
}
if _, err := w.Write(ip); err != nil {
return err
}
var port [2]byte
byteOrder.PutUint16(port[:], uint16(addr.Port))
if _, err := w.Write(port[:]); err != nil {
return err
}
return nil
}
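// The resulting layout is one type byte, the raw IP, then a 2-byte big-endian
// port. For example (hypothetical address), 127.0.0.1:9735 encodes to the 7
// bytes:
//
//	0x00                // tcp4Addr type byte
//	0x7f 0x00 0x00 0x01 // 127.0.0.1
//	0x26 0x07           // port 9735 (big-endian)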
// encodeOnionAddr serializes an onion address into its compact raw bytes
// representation.
func encodeOnionAddr(w io.Writer, addr *tor.OnionAddr) error {
var suffixIndex int
hostLen := len(addr.OnionService)
switch hostLen {
case tor.V2Len:
if _, err := w.Write([]byte{byte(v2OnionAddr)}); err != nil {
return err
}
suffixIndex = tor.V2Len - tor.OnionSuffixLen
case tor.V3Len:
if _, err := w.Write([]byte{byte(v3OnionAddr)}); err != nil {
return err
}
suffixIndex = tor.V3Len - tor.OnionSuffixLen
default:
return errors.New("unknown onion service length")
}
suffix := addr.OnionService[suffixIndex:]
if suffix != tor.OnionSuffix {
return fmt.Errorf("invalid suffix \"%v\"", suffix)
}
host, err := tor.Base32Encoding.DecodeString(
addr.OnionService[:suffixIndex],
)
if err != nil {
return err
}
// Sanity check the decoded length.
switch {
case hostLen == tor.V2Len && len(host) != tor.V2DecodedLen:
return fmt.Errorf("onion service %v decoded to invalid host %x",
addr.OnionService, host)
case hostLen == tor.V3Len && len(host) != tor.V3DecodedLen:
return fmt.Errorf("onion service %v decoded to invalid host %x",
addr.OnionService, host)
}
if _, err := w.Write(host); err != nil {
return err
}
var port [2]byte
byteOrder.PutUint16(port[:], uint16(addr.Port))
if _, err := w.Write(port[:]); err != nil {
return err
}
return nil
}
// encodeOpaqueAddrs serializes the lnwire.OpaqueAddrs type to a raw set of
// bytes that we will persist.
func encodeOpaqueAddrs(w io.Writer, addr *lnwire.OpaqueAddrs) error {
// Write the type byte.
if _, err := w.Write([]byte{byte(opaqueAddrs)}); err != nil {
return err
}
// Write the length of the payload.
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(addr.Payload)))
if _, err := w.Write(l[:]); err != nil {
return err
}
// Write the payload.
_, err := w.Write(addr.Payload)
return err
}
// DeserializeAddr reads the serialized raw representation of an address and
// deserializes it into the actual address. This allows us to avoid address
// resolution within this package.
func DeserializeAddr(r io.Reader) (net.Addr, error) {
var addrType [1]byte
if _, err := r.Read(addrType[:]); err != nil {
return nil, err
}
var address net.Addr
switch addressType(addrType[0]) {
case tcp4Addr:
var ip [4]byte
if _, err := r.Read(ip[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
address = &net.TCPAddr{
IP: net.IP(ip[:]),
Port: int(binary.BigEndian.Uint16(port[:])),
}
case tcp6Addr:
var ip [16]byte
if _, err := r.Read(ip[:]); err != nil {
return nil, err
}
var port [2]byte
if _, err := r.Read(port[:]); err != nil {
return nil, err
}
address = &net.TCPAddr{
IP: net.IP(ip[:]),
Port: int(binary.BigEndian.Uint16(port[:])),
}
case v2OnionAddr:
var h [tor.V2DecodedLen]byte
if _, err := r.Read(h[:]); err != nil {
return nil, err
}
var p [2]byte
if _, err := r.Read(p[:]); err != nil {
return nil, err
}
onionService := tor.Base32Encoding.EncodeToString(h[:])
onionService += tor.OnionSuffix
port := int(binary.BigEndian.Uint16(p[:]))
address = &tor.OnionAddr{
OnionService: onionService,
Port: port,
}
case v3OnionAddr:
var h [tor.V3DecodedLen]byte
if _, err := r.Read(h[:]); err != nil {
return nil, err
}
var p [2]byte
if _, err := r.Read(p[:]); err != nil {
return nil, err
}
onionService := tor.Base32Encoding.EncodeToString(h[:])
onionService += tor.OnionSuffix
port := int(binary.BigEndian.Uint16(p[:]))
address = &tor.OnionAddr{
OnionService: onionService,
Port: port,
}
case opaqueAddrs:
// Read the length of the payload.
var l [2]byte
if _, err := r.Read(l[:]); err != nil {
return nil, err
}
// Read the payload.
payload := make([]byte, binary.BigEndian.Uint16(l[:]))
if _, err := r.Read(payload); err != nil {
return nil, err
}
address = &lnwire.OpaqueAddrs{
Payload: payload,
}
default:
return nil, ErrUnknownAddressType
}
return address, nil
}
// SerializeAddr serializes an address into its raw bytes representation so that
// it can be deserialized without requiring address resolution.
func SerializeAddr(w io.Writer, address net.Addr) error {
switch addr := address.(type) {
case *net.TCPAddr:
return encodeTCPAddr(w, addr)
case *tor.OnionAddr:
return encodeOnionAddr(w, addr)
case *lnwire.OpaqueAddrs:
return encodeOpaqueAddrs(w, addr)
default:
return ErrUnknownAddressType
}
}
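// A minimal round-trip sketch (hypothetical caller; assumes "bytes" is
// imported):
//
//	var buf bytes.Buffer
//	addr := &net.TCPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 9735}
//	if err := SerializeAddr(&buf, addr); err != nil {
//		// Handle the error.
//	}
//	decoded, err := DeserializeAddr(&buf)
//	if err != nil {
//		// Handle the error.
//	}
//	// decoded is a *net.TCPAddr equal to addr.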
package graphdb
// channelCache is an in-memory cache used to improve the performance of
// ChanUpdatesInHorizon. It caches the chan info and edge policies for a
// particular channel.
type channelCache struct {
n int
channels map[uint64]ChannelEdge
}
// newChannelCache creates a new channelCache with maximum capacity of n
// channels.
func newChannelCache(n int) *channelCache {
return &channelCache{
n: n,
channels: make(map[uint64]ChannelEdge),
}
}
// get returns the channel from the cache, if it exists.
func (c *channelCache) get(chanid uint64) (ChannelEdge, bool) {
channel, ok := c.channels[chanid]
return channel, ok
}
// insert adds the entry to the channel cache. If an entry for chanid already
// exists, it will be replaced with the new entry. If the entry doesn't exist,
// it will be inserted to the cache, performing a random eviction if the cache
// is at capacity.
func (c *channelCache) insert(chanid uint64, channel ChannelEdge) {
// If entry exists, replace it.
if _, ok := c.channels[chanid]; ok {
c.channels[chanid] = channel
return
}
// Otherwise, evict an entry at random and insert.
if len(c.channels) == c.n {
for id := range c.channels {
delete(c.channels, id)
break
}
}
c.channels[chanid] = channel
}
// remove deletes an edge for chanid from the cache, if it exists.
func (c *channelCache) remove(chanid uint64) {
delete(c.channels, chanid)
}
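// A minimal usage sketch (hypothetical channel IDs):
//
//	cache := newChannelCache(2)
//	cache.insert(1, ChannelEdge{})
//	cache.insert(2, ChannelEdge{})
//	cache.insert(3, ChannelEdge{}) // At capacity: evicts 1 or 2 at random.
//	if _, ok := cache.get(3); ok {
//		// The most recent insert is always present; only one of the
//		// first two entries survived the eviction.
//	}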
package graphdb
import (
"encoding/binary"
"io"
"github.com/btcsuite/btcd/wire"
)
var (
// byteOrder defines the preferred byte order, which is Big Endian.
byteOrder = binary.BigEndian
)
// WriteOutpoint writes an outpoint to the passed writer using the minimal
// amount of bytes possible.
func WriteOutpoint(w io.Writer, o *wire.OutPoint) error {
if _, err := w.Write(o.Hash[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, o.Index); err != nil {
return err
}
return nil
}
// ReadOutpoint reads an outpoint from the passed reader that was previously
// written using the WriteOutpoint function.
func ReadOutpoint(r io.Reader, o *wire.OutPoint) error {
if _, err := io.ReadFull(r, o.Hash[:]); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &o.Index); err != nil {
return err
}
return nil
}
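// A minimal round-trip sketch (hypothetical outpoint; assumes "bytes" is
// imported): the serialization is the 32-byte hash followed by the big-endian
// uint32 index, 36 bytes in total.
//
//	var buf bytes.Buffer
//	op := wire.OutPoint{Index: 1}
//	if err := WriteOutpoint(&buf, &op); err != nil {
//		// Handle the error.
//	}
//	var decoded wire.OutPoint
//	if err := ReadOutpoint(&buf, &decoded); err != nil {
//		// Handle the error.
//	}
//	// decoded now equals op.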
package graphdb
import (
"errors"
"fmt"
)
var (
// ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel
// policy field is not found in the db even though its message flags
// indicate it should be.
ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " +
"present")
// ErrGraphNotFound is returned when at least one of the components of
// the graph doesn't exist.
ErrGraphNotFound = fmt.Errorf("graph bucket not initialized")
// ErrGraphNeverPruned is returned when graph was never pruned.
ErrGraphNeverPruned = fmt.Errorf("graph never pruned")
// ErrSourceNodeNotSet is returned if the source node of the graph
// hasn't been added. The source node is the center node within a
// star-graph.
ErrSourceNodeNotSet = fmt.Errorf("source node does not exist")
// ErrGraphNodesNotFound is returned in case none of the nodes have
// been added to the graph node bucket.
ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist")
// ErrGraphNoEdgesFound is returned in case none of the channels/edges
// have been added to the graph edge bucket.
ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist")
// ErrGraphNodeNotFound is returned when we're unable to find the target
// node.
ErrGraphNodeNotFound = fmt.Errorf("unable to find node")
// ErrZombieEdge is an error returned when we attempt to look up an edge
// but it is marked as a zombie within the zombie index.
ErrZombieEdge = errors.New("edge marked as zombie")
// ErrEdgeNotFound is returned when an edge for the target chanID
// can't be found.
ErrEdgeNotFound = fmt.Errorf("edge not found")
// ErrEdgeAlreadyExist is returned when an edge with the specific
// channel ID can't be added because it already exists.
ErrEdgeAlreadyExist = fmt.Errorf("edge already exist")
// ErrNodeAliasNotFound is returned when the alias for a node can't be
// found.
ErrNodeAliasNotFound = fmt.Errorf("alias for node not found")
// ErrClosedScidsNotFound is returned when the closed scid bucket
// hasn't been created.
ErrClosedScidsNotFound = fmt.Errorf("closed scid bucket doesn't exist")
// ErrZombieEdgeNotFound is an error returned when we attempt to find an
// edge in the zombie index which is not there.
ErrZombieEdgeNotFound = errors.New("edge not found in zombie index")
// ErrUnknownAddressType is returned when a node's addressType is not
// an expected value.
ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved")
)
// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the
// caller attempts to write an announcement message which bears too many extra
// opaque bytes. We limit this value in order to ensure that we don't waste
// disk space due to nodes unnecessarily padding out their announcements with
// garbage data.
func ErrTooManyExtraOpaqueBytes(numBytes int) error {
return fmt.Errorf("max allowed number of opaque bytes is %v, received "+
"%v bytes", MaxAllowedExtraOpaqueBytes, numBytes)
}
package graphdb
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/batch"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// ErrChanGraphShuttingDown indicates that the ChannelGraph has shutdown or is
// busy shutting down.
var ErrChanGraphShuttingDown = fmt.Errorf("ChannelGraph shutting down")
// Config is a struct that holds all the necessary dependencies for a
// ChannelGraph.
type Config struct {
// KVDB is the kvdb.Backend that will be used for initializing the
// KVStore CRUD layer.
KVDB kvdb.Backend
// KVStoreOpts is a list of functional options that will be used when
// initializing the KVStore.
KVStoreOpts []KVStoreOptionModifier
}
// ChannelGraph is a layer above the graph's CRUD layer.
//
// NOTE: currently, this is purely a pass-through layer directly to the backing
// KVStore. Upcoming commits will move the graph cache out of the KVStore and
// into this layer so that the KVStore is only responsible for CRUD operations.
type ChannelGraph struct {
started atomic.Bool
stopped atomic.Bool
// cacheMu guards any writes to the graphCache. It should be held
// across the DB write call and the graphCache update to make the
// two updates as atomic as possible.
cacheMu sync.Mutex
graphCache *GraphCache
*KVStore
*topologyManager
quit chan struct{}
wg sync.WaitGroup
}
// NewChannelGraph creates a new ChannelGraph instance with the given backend.
func NewChannelGraph(cfg *Config, options ...ChanGraphOption) (*ChannelGraph,
error) {
opts := defaultChanGraphOptions()
for _, o := range options {
o(opts)
}
store, err := NewKVStore(cfg.KVDB, cfg.KVStoreOpts...)
if err != nil {
return nil, err
}
g := &ChannelGraph{
KVStore: store,
topologyManager: newTopologyManager(),
quit: make(chan struct{}),
}
// The graph cache can be turned off (e.g. for mobile users) for a
// speed/memory usage tradeoff.
if opts.useGraphCache {
g.graphCache = NewGraphCache(opts.preAllocCacheNumNodes)
}
return g, nil
}
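// A minimal construction sketch (hypothetical kvdb backend; error handling
// elided):
//
//	graph, err := NewChannelGraph(&Config{KVDB: backend})
//	if err != nil {
//		// Handle the error.
//	}
//	if err := graph.Start(); err != nil {
//		// Handle the error.
//	}
//	defer graph.Stop()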
// Start kicks off any goroutines required for the ChannelGraph to function.
// If the graph cache is enabled, then it will be populated with the contents of
// the database.
func (c *ChannelGraph) Start() error {
if !c.started.CompareAndSwap(false, true) {
return nil
}
log.Debugf("ChannelGraph starting")
defer log.Debug("ChannelGraph started")
if c.graphCache != nil {
if err := c.populateCache(); err != nil {
return fmt.Errorf("could not populate the graph "+
"cache: %w", err)
}
}
c.wg.Add(1)
go c.handleTopologySubscriptions()
return nil
}
// Stop signals any active goroutines for a graceful closure.
func (c *ChannelGraph) Stop() error {
if !c.stopped.CompareAndSwap(false, true) {
return nil
}
log.Debugf("ChannelGraph shutting down...")
defer log.Debug("ChannelGraph shutdown complete")
close(c.quit)
c.wg.Wait()
return nil
}
// handleTopologySubscriptions ensures that topology client subscriptions,
// subscription cancellations and topology notifications are handled
// synchronously.
//
// NOTE: this MUST be run in a goroutine.
func (c *ChannelGraph) handleTopologySubscriptions() {
defer c.wg.Done()
for {
select {
// A new fully validated topology update has just arrived.
// We'll notify any registered clients.
case update := <-c.topologyUpdate:
// TODO(elle): change topology handling to be handled
// synchronously so that we can guarantee the order of
// notification delivery.
c.wg.Add(1)
go c.handleTopologyUpdate(update)
// TODO(roasbeef): remove all unconnected vertexes
// after N blocks pass with no corresponding
// announcements.
// A new notification client update has arrived. We're either
// gaining a new client, or cancelling notifications for an
// existing client.
case ntfnUpdate := <-c.ntfnClientUpdates:
clientID := ntfnUpdate.clientID
if ntfnUpdate.cancel {
client, ok := c.topologyClients.LoadAndDelete(
clientID,
)
if ok {
close(client.exit)
client.wg.Wait()
close(client.ntfnChan)
}
continue
}
c.topologyClients.Store(clientID, &topologyClient{
ntfnChan: ntfnUpdate.ntfnChan,
exit: make(chan struct{}),
})
case <-c.quit:
return
}
}
}
// populateCache loads the entire channel graph into the in-memory graph cache.
//
// NOTE: This should only be called if the graphCache has been constructed.
func (c *ChannelGraph) populateCache() error {
startTime := time.Now()
log.Info("Populating in-memory channel graph, this might take a " +
"while...")
err := c.KVStore.ForEachNodeCacheable(func(node route.Vertex,
features *lnwire.FeatureVector) error {
c.graphCache.AddNodeFeatures(node, features)
return nil
})
if err != nil {
return err
}
err = c.KVStore.ForEachChannel(func(info *models.ChannelEdgeInfo,
policy1, policy2 *models.ChannelEdgePolicy) error {
c.graphCache.AddChannel(info, policy1, policy2)
return nil
})
if err != nil {
return err
}
log.Infof("Finished populating in-memory channel graph (took %v, %s)",
time.Since(startTime), c.graphCache.Stats())
return nil
}
// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller. If the graphCache
// is available, then it will be used to retrieve the node's channels instead
// of the database.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (c *ChannelGraph) ForEachNodeDirectedChannel(node route.Vertex,
cb func(channel *DirectedChannel) error) error {
if c.graphCache != nil {
return c.graphCache.ForEachChannel(node, cb)
}
return c.KVStore.ForEachNodeDirectedChannel(node, cb)
}
// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
// If the graphCache is available, then it will be used to retrieve the node's
// features instead of the database.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (c *ChannelGraph) FetchNodeFeatures(node route.Vertex) (
*lnwire.FeatureVector, error) {
if c.graphCache != nil {
return c.graphCache.GetFeatures(node), nil
}
return c.KVStore.FetchNodeFeatures(node)
}
// GraphSession will provide the call-back with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph. If
// the graph cache is not enabled, then the call-back will be provided with
// access to the graph via a consistent read-only transaction.
func (c *ChannelGraph) GraphSession(cb func(graph NodeTraverser) error) error {
if c.graphCache != nil {
return cb(c)
}
return c.KVStore.GraphSession(cb)
}
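// A minimal query sketch (hypothetical graph and node vertex): the same
// callback works whether it is backed by the cache or by a read-only
// transaction.
//
//	err := graph.GraphSession(func(g NodeTraverser) error {
//		return g.ForEachNodeDirectedChannel(node,
//			func(ch *DirectedChannel) error {
//				// Inspect ch.Capacity, ch.InPolicy, etc.
//				return nil
//			},
//		)
//	})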
// ForEachNodeCached iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered.
//
// NOTE: The contents passed to the callback MUST NOT be modified.
func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex,
chans map[uint64]*DirectedChannel) error) error {
if c.graphCache != nil {
return c.graphCache.ForEachNode(cb)
}
return c.KVStore.ForEachNodeCached(cb)
}
// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information. Note that this method is expected to only be called to update an
// already present node from a node announcement, or to insert a node found in a
// channel update.
func (c *ChannelGraph) AddLightningNode(node *models.LightningNode,
op ...batch.SchedulerOption) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
err := c.KVStore.AddLightningNode(node, op...)
if err != nil {
return err
}
if c.graphCache != nil {
c.graphCache.AddNodeFeatures(
node.PubKeyBytes, node.Features,
)
}
select {
case c.topologyUpdate <- node:
case <-c.quit:
return ErrChanGraphShuttingDown
}
return nil
}
// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
err := c.KVStore.DeleteLightningNode(nodePub)
if err != nil {
return err
}
if c.graphCache != nil {
c.graphCache.RemoveNode(nodePub)
}
return nil
}
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge between the two target nodes is created. The information
// stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
func (c *ChannelGraph) AddChannelEdge(edge *models.ChannelEdgeInfo,
op ...batch.SchedulerOption) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
err := c.KVStore.AddChannelEdge(edge, op...)
if err != nil {
return err
}
if c.graphCache != nil {
c.graphCache.AddChannel(edge, nil, nil)
}
select {
case c.topologyUpdate <- edge:
case <-c.quit:
return ErrChanGraphShuttingDown
}
return nil
}
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
// If the cache is enabled, the edge will be added back to the graph cache if
// we still have a record of this channel in the DB.
func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
err := c.KVStore.MarkEdgeLive(chanID)
if err != nil {
return err
}
if c.graphCache != nil {
// We need to add the channel back into our graph cache,
// otherwise we won't use it for path finding.
infos, err := c.KVStore.FetchChanInfos([]uint64{chanID})
if err != nil {
return err
}
if len(infos) == 0 {
return nil
}
info := infos[0]
c.graphCache.AddChannel(info.Info, info.Policy1, info.Policy2)
}
return nil
}
// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to re-add
// them to our database once again. If an edge does not exist within the
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
// true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state. The markZombie bool
// denotes whether to mark the channel as a zombie.
func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning, markZombie bool,
chanIDs ...uint64) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
infos, err := c.KVStore.DeleteChannelEdges(
strictZombiePruning, markZombie, chanIDs...,
)
if err != nil {
return err
}
if c.graphCache != nil {
for _, info := range infos {
c.graphCache.RemoveChannel(
info.NodeKey1Bytes, info.NodeKey2Bytes,
info.ChannelID,
)
}
}
return err
}
// DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels
// that are no longer confirmed from the graph. The prune log will be
// set to the last prune height valid for the remaining chain.
// Channels that were removed from the graph resulting from the
// disconnected block are returned.
func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) (
[]*models.ChannelEdgeInfo, error) {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
edges, err := c.KVStore.DisconnectBlockAtHeight(height)
if err != nil {
return nil, err
}
if c.graphCache != nil {
for _, edge := range edges {
c.graphCache.RemoveChannel(
edge.NodeKey1Bytes, edge.NodeKey2Bytes,
edge.ChannelID,
)
}
}
return edges, nil
}
// PruneGraph prunes newly closed channels from the channel graph in response
// to a new block being solved on the network. Any transactions which spend the
// funding output of any known channels within the graph will be deleted.
// Additionally, the "prune tip", or the last block which has been used to
// prune the graph is stored so callers can ensure the graph is fully in sync
// with the current UTXO state. A slice of channels that have been closed by
// the target block are returned if the function succeeds without error.
func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
blockHash *chainhash.Hash, blockHeight uint32) (
[]*models.ChannelEdgeInfo, error) {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
edges, nodes, err := c.KVStore.PruneGraph(
spentOutputs, blockHash, blockHeight,
)
if err != nil {
return nil, err
}
if c.graphCache != nil {
for _, edge := range edges {
c.graphCache.RemoveChannel(
edge.NodeKey1Bytes, edge.NodeKey2Bytes,
edge.ChannelID,
)
}
for _, node := range nodes {
c.graphCache.RemoveNode(node)
}
log.Debugf("Pruned graph, cache now has %s",
c.graphCache.Stats())
}
if len(edges) != 0 {
// Notify all currently registered clients of the newly closed
// channels.
closeSummaries := createCloseSummaries(
blockHeight, edges...,
)
c.notifyTopologyChange(&TopologyChange{
ClosedChannels: closeSummaries,
})
}
return edges, nil
}
// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This ensures
// that we only maintain a graph of reachable nodes. In the event that a pruned
// node gains more channels, it will be re-added back to the graph.
func (c *ChannelGraph) PruneGraphNodes() error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
nodes, err := c.KVStore.PruneGraphNodes()
if err != nil {
return err
}
if c.graphCache != nil {
for _, node := range nodes {
c.graphCache.RemoveNode(node)
}
}
return nil
}
// FilterKnownChanIDs takes a set of channel IDs and returns the subset of the
// passed chan IDs that we don't know of and that are not known zombies. In
// other words, we perform a set difference of our set of chan IDs and the
// ones passed in. This method can be used by callers to determine the set of
// channels another peer knows of that we don't.
func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo,
isZombieChan func(time.Time, time.Time) bool) ([]uint64, error) {
unknown, knownZombies, err := c.KVStore.FilterKnownChanIDs(chansInfo)
if err != nil {
return nil, err
}
for _, info := range knownZombies {
// TODO(ziggie): Make sure that for the strict pruning case we
// compare the pubkeys and whether the right timestamp is not
// older than the `ChannelPruneExpiry`.
//
// NOTE: The timestamp data has no verification attached to it
// in the `ReplyChannelRange` msg, so we are trusting this data
// at this point. However, it is not critical because we only
// remove the channel from the db when the timestamps are more
// recent. During the querying of the gossip msg, verification
// happens as usual. That said, we should consider punishing
// peers when they don't provide us with honest data.
isStillZombie := isZombieChan(
info.Node1UpdateTimestamp, info.Node2UpdateTimestamp,
)
if isStillZombie {
continue
}
// If we have marked it as a zombie but the latest update
// timestamps could bring it back from the dead, then we mark it
// alive, and we let it be added to the set of IDs to query our
// peer for.
err := c.KVStore.MarkEdgeLive(
info.ShortChannelID.ToUint64(),
)
// Since there is a chance that the edge could have been marked
// as "live" between the FilterKnownChanIDs call and the
// MarkEdgeLive call, we ignore the error if the edge is already
// marked as live.
if err != nil && !errors.Is(err, ErrZombieEdgeNotFound) {
return nil, err
}
}
return unknown, nil
}
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
// zombie. This method is used on an ad-hoc basis, when channels need to be
// marked as zombies outside the normal pruning cycle.
func (c *ChannelGraph) MarkEdgeZombie(chanID uint64,
pubKey1, pubKey2 [33]byte) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
err := c.KVStore.MarkEdgeZombie(chanID, pubKey1, pubKey2)
if err != nil {
return err
}
if c.graphCache != nil {
c.graphCache.RemoveChannel(pubKey1, pubKey2, chanID)
}
return nil
}
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The direction bit of the
// `flags` attribute within the ChannelEdgePolicy determines which of the
// directed edges is being updated: if the bit is 0, the first node's policy is
// being updated, otherwise the second node's. The node ordering is determined
// by the lexicographical ordering of the identity public keys of the nodes on
// either side of the channel.
func (c *ChannelGraph) UpdateEdgePolicy(edge *models.ChannelEdgePolicy,
op ...batch.SchedulerOption) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
from, to, err := c.KVStore.UpdateEdgePolicy(edge, op...)
if err != nil {
return err
}
if c.graphCache != nil {
var isUpdate1 bool
if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
isUpdate1 = true
}
c.graphCache.UpdatePolicy(edge, from, to, isUpdate1)
}
select {
case c.topologyUpdate <- edge:
case <-c.quit:
return ErrChanGraphShuttingDown
}
return nil
}
package graphdb
import (
"fmt"
"sync"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// DirectedChannel is a type that stores the channel information as seen from
// one side of the channel.
type DirectedChannel struct {
// ChannelID is the unique identifier of this channel.
ChannelID uint64
// IsNode1 indicates if this is the node with the smaller public key.
IsNode1 bool
// OtherNode is the public key of the node on the other end of this
// channel.
OtherNode route.Vertex
// Capacity is the announced capacity of this channel in satoshis.
Capacity btcutil.Amount
// OutPolicySet is a boolean that indicates whether the node has an
// outgoing policy set. For pathfinding only the existence of the policy
// is important to know, not the actual content.
OutPolicySet bool
// InPolicy is the incoming policy *from* the other node to this node.
// In path finding, we're walking backward from the destination to the
// source, so we're always interested in the edge that arrives to us
// from the other node.
InPolicy *models.CachedEdgePolicy
// Inbound fees of this node.
InboundFee lnwire.Fee
}
// DeepCopy creates a deep copy of the channel, including the incoming policy.
func (c *DirectedChannel) DeepCopy() *DirectedChannel {
channelCopy := *c
if channelCopy.InPolicy != nil {
inPolicyCopy := *channelCopy.InPolicy
channelCopy.InPolicy = &inPolicyCopy
// The fields for the ToNode can be overwritten by the path
// finding algorithm, which is why we need a deep copy in the
// first place. So we always start out with nil values, just to
// be sure they don't contain any old data.
channelCopy.InPolicy.ToNodePubKey = nil
channelCopy.InPolicy.ToNodeFeatures = nil
}
return &channelCopy
}
// GraphCache is a type that holds a minimal set of information of the public
// channel graph that can be used for pathfinding.
type GraphCache struct {
nodeChannels map[route.Vertex]map[uint64]*DirectedChannel
nodeFeatures map[route.Vertex]*lnwire.FeatureVector
mtx sync.RWMutex
}
// NewGraphCache creates a new GraphCache.
func NewGraphCache(preAllocNumNodes int) *GraphCache {
return &GraphCache{
nodeChannels: make(
map[route.Vertex]map[uint64]*DirectedChannel,
// A channel connects two nodes, so we can look it up
// from both sides, meaning we get double the number of
// entries.
preAllocNumNodes*2,
),
nodeFeatures: make(
map[route.Vertex]*lnwire.FeatureVector,
preAllocNumNodes,
),
}
}
// Stats returns statistics about the current cache size.
func (c *GraphCache) Stats() string {
c.mtx.RLock()
defer c.mtx.RUnlock()
numChannels := 0
for node := range c.nodeChannels {
numChannels += len(c.nodeChannels[node])
}
return fmt.Sprintf("num_node_features=%d, num_nodes=%d, "+
"num_channels=%d", len(c.nodeFeatures), len(c.nodeChannels),
numChannels)
}
// AddNodeFeatures adds a graph node and its features to the cache.
func (c *GraphCache) AddNodeFeatures(node route.Vertex,
features *lnwire.FeatureVector) {
c.mtx.Lock()
defer c.mtx.Unlock()
c.nodeFeatures[node] = features
}
// AddChannel adds a non-directed channel, meaning that the order of policy 1
// and policy 2 does not matter; the directionality is extracted from the info
// and policy flags automatically. Each policy will be set as the outgoing
// policy on one node and the incoming policy on the peer's side.
func (c *GraphCache) AddChannel(info *models.ChannelEdgeInfo,
policy1 *models.ChannelEdgePolicy, policy2 *models.ChannelEdgePolicy) {
if info == nil {
return
}
if policy1 != nil && policy1.IsDisabled() &&
policy2 != nil && policy2.IsDisabled() {
return
}
// Create the edge entry for both nodes.
c.mtx.Lock()
c.updateOrAddEdge(info.NodeKey1Bytes, &DirectedChannel{
ChannelID: info.ChannelID,
IsNode1: true,
OtherNode: info.NodeKey2Bytes,
Capacity: info.Capacity,
})
c.updateOrAddEdge(info.NodeKey2Bytes, &DirectedChannel{
ChannelID: info.ChannelID,
IsNode1: false,
OtherNode: info.NodeKey1Bytes,
Capacity: info.Capacity,
})
c.mtx.Unlock()
// The policy's node is always the to_node. So if policy 1 has a to_node
// of node 2, then policy 1 is the policy as seen from node 1.
if policy1 != nil {
fromNode, toNode := info.NodeKey1Bytes, info.NodeKey2Bytes
if policy1.ToNode != info.NodeKey2Bytes {
fromNode, toNode = toNode, fromNode
}
isEdge1 := policy1.ChannelFlags&lnwire.ChanUpdateDirection == 0
c.UpdatePolicy(policy1, fromNode, toNode, isEdge1)
}
if policy2 != nil {
fromNode, toNode := info.NodeKey2Bytes, info.NodeKey1Bytes
if policy2.ToNode != info.NodeKey1Bytes {
fromNode, toNode = toNode, fromNode
}
isEdge1 := policy2.ChannelFlags&lnwire.ChanUpdateDirection == 0
c.UpdatePolicy(policy2, fromNode, toNode, isEdge1)
}
}
// updateOrAddEdge makes sure the edge information for a node is either updated
// if it already exists or is added to that node's list of channels.
func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) {
if len(c.nodeChannels[node]) == 0 {
c.nodeChannels[node] = make(map[uint64]*DirectedChannel)
}
c.nodeChannels[node][edge.ChannelID] = edge
}
// UpdatePolicy updates a single policy on both the from and to node. The order
// of the from and to node is not strictly important. But we assume that a
// channel edge was added beforehand so that the directed channel struct already
// exists in the cache.
func (c *GraphCache) UpdatePolicy(policy *models.ChannelEdgePolicy, fromNode,
toNode route.Vertex, edge1 bool) {
// Extract inbound fee if possible and available. If there is a decoding
// error, ignore this policy.
var inboundFee lnwire.Fee
_, err := policy.ExtraOpaqueData.ExtractRecords(&inboundFee)
if err != nil {
log.Errorf("Failed to extract records from edge policy %v: %v",
policy.ChannelID, err)
return
}
c.mtx.Lock()
defer c.mtx.Unlock()
updatePolicy := func(nodeKey route.Vertex) {
if len(c.nodeChannels[nodeKey]) == 0 {
log.Warnf("Node=%v not found in graph cache", nodeKey)
return
}
channel, ok := c.nodeChannels[nodeKey][policy.ChannelID]
if !ok {
log.Warnf("Channel=%v not found in graph cache",
policy.ChannelID)
return
}
// Edge 1 is defined as the policy for the direction of node1 to
// node2.
switch {
// This is node 1, and it is edge 1, so this is the outgoing
// policy for node 1.
case channel.IsNode1 && edge1:
channel.OutPolicySet = true
channel.InboundFee = inboundFee
// This is node 2, and it is edge 2, so this is the outgoing
// policy for node 2.
case !channel.IsNode1 && !edge1:
channel.OutPolicySet = true
channel.InboundFee = inboundFee
// The other two cases left mean it's the inbound policy for the
// node.
default:
channel.InPolicy = models.NewCachedPolicy(policy)
}
}
updatePolicy(fromNode)
updatePolicy(toNode)
}
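// To summarize the direction mapping (hypothetical policies and vertices),
// mirroring what AddChannel does internally:
//
//	// node1 -> node2 policy: node1's outgoing, node2's incoming.
//	cache.UpdatePolicy(policy1, node1, node2, true)
//	// node2 -> node1 policy: node2's outgoing, node1's incoming.
//	cache.UpdatePolicy(policy2, node2, node1, false)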
// RemoveNode completely removes a node and all its channels (including the
// peer's side).
func (c *GraphCache) RemoveNode(node route.Vertex) {
c.mtx.Lock()
defer c.mtx.Unlock()
delete(c.nodeFeatures, node)
// First remove all channels from the other nodes' lists.
for _, channel := range c.nodeChannels[node] {
c.removeChannelIfFound(channel.OtherNode, channel.ChannelID)
}
// Then remove our whole node completely.
delete(c.nodeChannels, node)
}
// RemoveChannel removes a single channel between two nodes.
func (c *GraphCache) RemoveChannel(node1, node2 route.Vertex, chanID uint64) {
c.mtx.Lock()
defer c.mtx.Unlock()
// Remove that one channel from both sides.
c.removeChannelIfFound(node1, chanID)
c.removeChannelIfFound(node2, chanID)
}
// removeChannelIfFound removes a single channel from one side.
func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) {
if len(c.nodeChannels[node]) == 0 {
return
}
delete(c.nodeChannels[node], chanID)
}
// UpdateChannel updates the channel edge information for a specific edge. We
// expect the edge to already exist and be known. If it does not yet exist, this
// call is a no-op.
func (c *GraphCache) UpdateChannel(info *models.ChannelEdgeInfo) {
c.mtx.Lock()
defer c.mtx.Unlock()
if len(c.nodeChannels[info.NodeKey1Bytes]) == 0 ||
len(c.nodeChannels[info.NodeKey2Bytes]) == 0 {
return
}
channel, ok := c.nodeChannels[info.NodeKey1Bytes][info.ChannelID]
if ok {
// We only expect to be called when the channel is already
// known.
channel.Capacity = info.Capacity
channel.OtherNode = info.NodeKey2Bytes
}
channel, ok = c.nodeChannels[info.NodeKey2Bytes][info.ChannelID]
if ok {
channel.Capacity = info.Capacity
channel.OtherNode = info.NodeKey1Bytes
}
}
// getChannels returns a copy of the passed node's channels, or nil if there
// aren't any.
func (c *GraphCache) getChannels(node route.Vertex) []*DirectedChannel {
c.mtx.RLock()
defer c.mtx.RUnlock()
channels, ok := c.nodeChannels[node]
if !ok {
return nil
}
features, ok := c.nodeFeatures[node]
if !ok {
// If the features were set to nil explicitly, that's fine here.
// The router will overwrite the features of the destination
// node with those found in the invoice if necessary. But if we
// didn't yet get a node announcement we want to mimic the
// behavior of the old DB based code that would always set an
// empty feature vector instead of leaving it nil.
features = lnwire.EmptyFeatureVector()
}
toNodeCallback := func() route.Vertex {
return node
}
i := 0
channelsCopy := make([]*DirectedChannel, len(channels))
for _, channel := range channels {
// We need to copy the channel and policy to avoid it being
// updated in the cache if the path finding algorithm sets
// fields on it (currently only the ToNodeFeatures of the
// policy).
channelCopy := channel.DeepCopy()
if channelCopy.InPolicy != nil {
channelCopy.InPolicy.ToNodePubKey = toNodeCallback
channelCopy.InPolicy.ToNodeFeatures = features
}
channelsCopy[i] = channelCopy
i++
}
return channelsCopy
}
// ForEachChannel invokes the given callback for each channel of the given node.
func (c *GraphCache) ForEachChannel(node route.Vertex,
cb func(channel *DirectedChannel) error) error {
// Obtain a copy of the node's channels. We need to do this in order to
// avoid deadlocks caused by interaction with the graph cache, channel
// state and the graph database from multiple goroutines. This snapshot
// is only used for path finding where being stale is acceptable since
// the real world graph and our representation may always become
// slightly out of sync for a short time and the actual channel state
// is stored separately.
channels := c.getChannels(node)
for _, channel := range channels {
if err := cb(channel); err != nil {
return err
}
}
return nil
}
// ForEachNode iterates over the adjacency list of the graph, executing the
// callback for each node and the set of channels that emanate from the given
// node.
//
// NOTE: This method should be considered _read only_, the channels or nodes
// passed in MUST NOT be modified.
func (c *GraphCache) ForEachNode(cb func(node route.Vertex,
channels map[uint64]*DirectedChannel) error) error {
c.mtx.RLock()
defer c.mtx.RUnlock()
for node, channels := range c.nodeChannels {
// We don't make a copy here since this is a read-only RPC
// call. We also don't need the node features either for this
// call.
if err := cb(node, channels); err != nil {
return err
}
}
return nil
}
// GetFeatures returns the features of the node with the given ID. If no
// features are known for the node, an empty feature vector is returned.
func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector {
c.mtx.RLock()
defer c.mtx.RUnlock()
features, ok := c.nodeFeatures[node]
if !ok || features == nil {
// The router expects the features to never be nil, so we return
// an empty feature set instead.
return lnwire.EmptyFeatureVector()
}
return features
}
package graphdb
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"net"
"sort"
"sync"
"testing"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/aliasmgr"
"github.com/lightningnetwork/lnd/batch"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/require"
)
var (
// nodeBucket is a bucket which houses all the vertices or nodes within
// the channel graph. This bucket has a single sub-bucket which adds an
// additional index from pubkey -> alias. Within the top-level of this
// bucket, the key space maps a node's compressed public key to the
// serialized information for that node. Additionally, there's a
// special key "source" which stores the pubkey of the source node. The
// source node is used as the starting point for all graph queries and
// traversals. The graph is formed as a star-graph with the source node
// at the center.
//
// maps: pubKey -> nodeInfo
// maps: source -> selfPubKey
nodeBucket = []byte("graph-node")
// nodeUpdateIndexBucket is a sub-bucket of the nodeBucket. This bucket
// will be used to quickly look up the "freshness" of a node's last
// update to the network. The bucket only contains keys, and no values,
// it's mapping:
//
// maps: updateTime || nodeID -> nil
nodeUpdateIndexBucket = []byte("graph-node-update-index")
// sourceKey is a special key that resides within the nodeBucket. The
// sourceKey maps a key to the public key of the "self node".
sourceKey = []byte("source")
// aliasIndexBucket is a sub-bucket that's nested within the main
// nodeBucket. This bucket maps the public key of a node to its
// current alias. This bucket is provided as it can be used within a
// future UI layer to add an additional degree of confirmation.
aliasIndexBucket = []byte("alias")
// edgeBucket is a bucket which houses all of the edge or channel
// information within the channel graph. This bucket essentially acts
// as an adjacency list, which in conjunction with a range scan, can be
// used to iterate over all the incoming and outgoing edges for a
// particular node. Key in the bucket use a prefix scheme which leads
// with the node's public key and sends with the compact edge ID.
// For each chanID, there will be two entries within the bucket, as the
// graph is directed: nodes may have different policies w.r.t to fees
// for their respective directions.
//
// maps: pubKey || chanID -> channel edge policy for node
edgeBucket = []byte("graph-edge")
// unknownPolicy is represented as an empty slice. It is
// used as the value in edgeBucket for unknown channel edge policies.
// Unknown policies are still stored in the database to enable efficient
// lookup of incoming channel edges.
unknownPolicy = []byte{}
// chanStart is an array of all zero bytes which is used to perform
// range scans within the edgeBucket to obtain all of the outgoing
// edges for a particular node.
chanStart [8]byte
// edgeIndexBucket is an index which can be used to iterate all edges
// in the bucket, grouping them according to their in/out nodes.
// Additionally, the items in this bucket also contain the complete
// edge information for a channel. The edge information includes the
// capacity of the channel, the nodes that made the channel, etc. This
// bucket resides within the edgeBucket above. Creation of an edge
// proceeds in two phases: first the edge is added to the edge index,
// afterwards the edgeBucket can be updated with the latest details of
// the edge as they are announced on the network.
//
// maps: chanID -> pubKey1 || pubKey2 || restofEdgeInfo
edgeIndexBucket = []byte("edge-index")
// edgeUpdateIndexBucket is a sub-bucket of the main edgeBucket. This
// bucket contains an index which allows us to gauge the "freshness" of
// a channel's last updates.
//
// maps: updateTime || chanID -> nil
edgeUpdateIndexBucket = []byte("edge-update-index")
// channelPointBucket maps a channel's full outpoint (txid:index) to
// its short 8-byte channel ID. This bucket resides within the
// edgeBucket above, and can be used to quickly remove an edge due to
// the outpoint being spent, or to query for existence of a channel.
//
// maps: outPoint -> chanID
channelPointBucket = []byte("chan-index")
// zombieBucket is a sub-bucket of the main edgeBucket bucket
// responsible for maintaining an index of zombie channels. Each entry
// exists within the bucket as follows:
//
// maps: chanID -> pubKey1 || pubKey2
//
// The chanID represents the channel ID of the edge that is marked as a
// zombie and is used as the key, which maps to the public keys of the
// edge's participants.
zombieBucket = []byte("zombie-index")
// disabledEdgePolicyBucket is a sub-bucket of the main edgeBucket
// bucket responsible for maintaining an index of disabled edge
// policies. Each entry exists within the bucket as follows:
//
// maps: <chanID><direction> -> []byte{}
//
// The chanID represents the channel ID of the edge and the direction is
// one byte representing the direction of the edge. The main purpose of
// this index is to allow pruning disabled channels in a fast way
// without the need to iterate all over the graph.
disabledEdgePolicyBucket = []byte("disabled-edge-policy-index")
// graphMetaBucket is a top-level bucket which stores various metadata
// related to the on-disk channel graph. Data stored in this bucket
// includes the block to which the graph has been synced, the total
// number of channels, etc.
graphMetaBucket = []byte("graph-meta")
// pruneLogBucket is a bucket within the graphMetaBucket that stores
// a mapping from the block height to the hash for the blocks used to
// prune the graph.
// Once a new block is discovered, any channels that have been closed
// (by spending the outpoint) can safely be removed from the graph, and
// the block is added to the prune log. We need to keep such a log for
// the case where a reorg happens, and we must "rewind" the state of the
// graph by removing channels that were previously confirmed. In such a
// case we'll remove all entries from the prune log with a block height
// that no longer exists.
pruneLogBucket = []byte("prune-log")
// closedScidBucket is a top-level bucket that stores scids for
// channels that we know to be closed. This is used so that we don't
// need to perform expensive validation checks if we receive a channel
// announcement for the channel again.
//
// maps: scid -> []byte{}
closedScidBucket = []byte("closed-scid")
)
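// Illustrative sketch (hypothetical helper, not part of the original
// file): assembles an edgeBucket key per the prefix scheme documented
// above (pubKey || chanID). Seeking a cursor to the 33-byte pubKey
// prefix then visits all of that node's outgoing edges in order.
func edgePolicyKey(nodePub [33]byte, chanID uint64) []byte {
	var key [33 + 8]byte
	copy(key[:33], nodePub[:])
	byteOrder.PutUint64(key[33:], chanID)
	return key[:]
}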
const (
// MaxAllowedExtraOpaqueBytes is the largest amount of opaque bytes that
// we'll permit to be written to disk. We limit this as otherwise, it
// would be possible for a node to create a ton of updates and slowly
// fill our disk, and also waste bandwidth due to relaying.
MaxAllowedExtraOpaqueBytes = 10000
)
// KVStore is a persistent, on-disk graph representation of the Lightning
// Network. This struct can be used to implement path finding algorithms on top
// of, and also to update a node's view based on information received from the
// p2p network. Internally, the graph is stored using a modified adjacency list
// representation with some added object interaction possible with each
// serialized edge/node. The stored graph is directed, meaning that two
// edges are stored for each channel: an inbound/outbound edge for each
// node pair.
// Nodes, edges, and edge information can all be added to the graph
// independently. Edge removal results in the deletion of all edge information
// for that edge.
type KVStore struct {
db kvdb.Backend
// cacheMu guards all caches (rejectCache and chanCache). If
// this mutex is to be acquired at the same time as the DB mutex then
// cacheMu MUST be acquired first to prevent deadlock.
cacheMu sync.RWMutex
rejectCache *rejectCache
chanCache *channelCache
chanScheduler batch.Scheduler
nodeScheduler batch.Scheduler
}
// NewKVStore allocates a new KVStore backed by a DB instance. The
// returned instance has its own unique reject cache and channel cache.
func NewKVStore(db kvdb.Backend, options ...KVStoreOptionModifier) (*KVStore,
error) {
opts := DefaultOptions()
for _, o := range options {
o(opts)
}
if !opts.NoMigration {
if err := initKVStore(db); err != nil {
return nil, err
}
}
g := &KVStore{
db: db,
rejectCache: newRejectCache(opts.RejectCacheSize),
chanCache: newChannelCache(opts.ChannelCacheSize),
}
g.chanScheduler = batch.NewTimeScheduler(
db, &g.cacheMu, opts.BatchCommitInterval,
)
g.nodeScheduler = batch.NewTimeScheduler(
db, nil, opts.BatchCommitInterval,
)
return g, nil
}
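// Illustrative sketch (hypothetical helper): the minimal way to bring
// up a store, assuming db is an already-opened kvdb.Backend. With no
// KVStoreOptionModifier arguments, DefaultOptions() is used as-is and
// the bucket initialization in initKVStore runs.
func newDefaultKVStore(db kvdb.Backend) (*KVStore, error) {
	return NewKVStore(db)
}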
// channelMapKey is the key structure used for storing channel edge policies.
type channelMapKey struct {
nodeKey route.Vertex
chanID [8]byte
}
// getChannelMap loads all channel edge policies from the database and stores
// them in a map.
func (c *KVStore) getChannelMap(edges kvdb.RBucket) (
map[channelMapKey]*models.ChannelEdgePolicy, error) {
// Create a map to store all channel edge policies.
channelMap := make(map[channelMapKey]*models.ChannelEdgePolicy)
err := kvdb.ForAll(edges, func(k, edgeBytes []byte) error {
// Skip embedded buckets.
if bytes.Equal(k, edgeIndexBucket) ||
bytes.Equal(k, edgeUpdateIndexBucket) ||
bytes.Equal(k, zombieBucket) ||
bytes.Equal(k, disabledEdgePolicyBucket) ||
bytes.Equal(k, channelPointBucket) {
return nil
}
// Validate key length.
if len(k) != 33+8 {
return fmt.Errorf("invalid edge key %x encountered", k)
}
var key channelMapKey
copy(key.nodeKey[:], k[:33])
copy(key.chanID[:], k[33:])
// No need to deserialize unknown policy.
if bytes.Equal(edgeBytes, unknownPolicy) {
return nil
}
edgeReader := bytes.NewReader(edgeBytes)
edge, err := deserializeChanEdgePolicyRaw(
edgeReader,
)
switch {
// If the db policy was missing an expected optional field, we
// return nil as if the policy was unknown.
case errors.Is(err, ErrEdgePolicyOptionalFieldNotFound):
return nil
case err != nil:
return err
}
channelMap[key] = edge
return nil
})
if err != nil {
return nil, err
}
return channelMap, nil
}
var graphTopLevelBuckets = [][]byte{
nodeBucket,
edgeBucket,
graphMetaBucket,
closedScidBucket,
}
// Wipe completely deletes all saved state within all used buckets within the
// database. The deletion is done in a single transaction, therefore this
// operation is fully atomic.
func (c *KVStore) Wipe() error {
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
for _, tlb := range graphTopLevelBuckets {
err := tx.DeleteTopLevelBucket(tlb)
if err != nil &&
!errors.Is(err, kvdb.ErrBucketNotFound) {
return err
}
}
return nil
}, func() {})
if err != nil {
return err
}
return initKVStore(c.db)
}
// initKVStore creates and initializes a fresh version of the channel
// graph database. In the case that any of the required top-level
// buckets do not yet exist, they are created.
func initKVStore(db kvdb.Backend) error {
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
for _, tlb := range graphTopLevelBuckets {
if _, err := tx.CreateTopLevelBucket(tlb); err != nil {
return err
}
}
nodes := tx.ReadWriteBucket(nodeBucket)
_, err := nodes.CreateBucketIfNotExists(aliasIndexBucket)
if err != nil {
return err
}
_, err = nodes.CreateBucketIfNotExists(nodeUpdateIndexBucket)
if err != nil {
return err
}
edges := tx.ReadWriteBucket(edgeBucket)
_, err = edges.CreateBucketIfNotExists(edgeIndexBucket)
if err != nil {
return err
}
_, err = edges.CreateBucketIfNotExists(edgeUpdateIndexBucket)
if err != nil {
return err
}
_, err = edges.CreateBucketIfNotExists(channelPointBucket)
if err != nil {
return err
}
_, err = edges.CreateBucketIfNotExists(zombieBucket)
if err != nil {
return err
}
graphMeta := tx.ReadWriteBucket(graphMetaBucket)
_, err = graphMeta.CreateBucketIfNotExists(pruneLogBucket)
return err
}, func() {})
if err != nil {
return fmt.Errorf("unable to create new channel graph: %w", err)
}
return nil
}
// AddrsForNode returns all known addresses for the target node public key that
// the graph DB is aware of. The returned boolean indicates if the given node is
// unknown to the graph DB or not.
//
// NOTE: this is part of the channeldb.AddrSource interface.
func (c *KVStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr,
error) {
pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed())
if err != nil {
return false, nil, err
}
node, err := c.FetchLightningNode(pubKey)
// We don't consider it an error if the graph is unaware of the node.
switch {
case err != nil && !errors.Is(err, ErrGraphNodeNotFound):
return false, nil, err
case errors.Is(err, ErrGraphNodeNotFound):
return false, nil, nil
}
return true, node.Addresses, nil
}
// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges since this is a directed graph: both the in/out edges are visited.
// If the callback returns an error, then the transaction is aborted and the
// iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
func (c *KVStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
return c.db.View(func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
// First, load all edges in memory indexed by node and channel
// id.
channelMap, err := c.getChannelMap(edges)
if err != nil {
return err
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
// Load edge index, recombine each channel with the policies
// loaded above and invoke the callback.
return kvdb.ForAll(
edgeIndex, func(k, edgeInfoBytes []byte) error {
var chanID [8]byte
copy(chanID[:], k)
edgeInfoReader := bytes.NewReader(edgeInfoBytes)
info, err := deserializeChanEdgeInfo(
edgeInfoReader,
)
if err != nil {
return err
}
policy1 := channelMap[channelMapKey{
nodeKey: info.NodeKey1Bytes,
chanID: chanID,
}]
policy2 := channelMap[channelMapKey{
nodeKey: info.NodeKey2Bytes,
chanID: chanID,
}]
return cb(&info, policy1, policy2)
},
)
}, func() {})
}
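// Illustrative sketch (hypothetical helper): counts channels for which
// both directed policies are known, demonstrating the NOTE above that
// unadvertised policies arrive as nil pointers and must be checked
// before use.
func countFullyAdvertisedChannels(c *KVStore) (int, error) {
	var n int
	err := c.ForEachChannel(func(_ *models.ChannelEdgeInfo,
		p1, p2 *models.ChannelEdgePolicy) error {

		if p1 != nil && p2 != nil {
			n++
		}
		return nil
	})
	return n, err
}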
// forEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller. An optional read
// transaction may be provided. If none is provided, a new one will be created.
//
// Unknown policies are passed into the callback as nil values.
func (c *KVStore) forEachNodeDirectedChannel(tx kvdb.RTx,
node route.Vertex, cb func(channel *DirectedChannel) error) error {
// Fallback that uses the database.
toNodeCallback := func() route.Vertex {
return node
}
toNodeFeatures, err := c.fetchNodeFeatures(tx, node)
if err != nil {
return err
}
dbCallback := func(tx kvdb.RTx, e *models.ChannelEdgeInfo, p1,
p2 *models.ChannelEdgePolicy) error {
var cachedInPolicy *models.CachedEdgePolicy
if p2 != nil {
cachedInPolicy = models.NewCachedPolicy(p2)
cachedInPolicy.ToNodePubKey = toNodeCallback
cachedInPolicy.ToNodeFeatures = toNodeFeatures
}
var inboundFee lnwire.Fee
if p1 != nil {
// Extract inbound fee. If there is a decoding error,
// skip this edge.
_, err := p1.ExtraOpaqueData.ExtractRecords(&inboundFee)
if err != nil {
return nil
}
}
directedChannel := &DirectedChannel{
ChannelID: e.ChannelID,
IsNode1: node == e.NodeKey1Bytes,
OtherNode: e.NodeKey2Bytes,
Capacity: e.Capacity,
OutPolicySet: p1 != nil,
InPolicy: cachedInPolicy,
InboundFee: inboundFee,
}
if node == e.NodeKey2Bytes {
directedChannel.OtherNode = e.NodeKey1Bytes
}
return cb(directedChannel)
}
return nodeTraversal(tx, node[:], c.db, dbCallback)
}
// fetchNodeFeatures returns the features of a given node. If no features are
// known for the node, an empty feature vector is returned. An optional read
// transaction may be provided. If none is provided, a new one will be created.
func (c *KVStore) fetchNodeFeatures(tx kvdb.RTx,
node route.Vertex) (*lnwire.FeatureVector, error) {
// Fallback that uses the database.
targetNode, err := c.FetchLightningNodeTx(tx, node)
switch {
// If the node exists and has features, return them directly.
case err == nil:
return targetNode.Features, nil
// If we couldn't find a node announcement, populate a blank feature
// vector.
case errors.Is(err, ErrGraphNodeNotFound):
return lnwire.EmptyFeatureVector(), nil
// Otherwise, bubble the error up.
default:
return nil, err
}
}
// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (c *KVStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *DirectedChannel) error) error {
return c.forEachNodeDirectedChannel(nil, nodePub, cb)
}
// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (c *KVStore) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
return c.fetchNodeFeatures(nil, nodePub)
}
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back.
//
// NOTE: The callback contents MUST not be modified.
func (c *KVStore) ForEachNodeCached(cb func(node route.Vertex,
chans map[uint64]*DirectedChannel) error) error {
// Fall back to a version that uses the database directly. We'll
// iterate over each node, then the set of channels for each node, and
// construct a callback signature similar to the one the main function
// expects.
return c.forEachNode(func(tx kvdb.RTx,
node *models.LightningNode) error {
channels := make(map[uint64]*DirectedChannel)
err := c.ForEachNodeChannelTx(tx, node.PubKeyBytes,
func(tx kvdb.RTx, e *models.ChannelEdgeInfo,
p1 *models.ChannelEdgePolicy,
p2 *models.ChannelEdgePolicy) error {
toNodeCallback := func() route.Vertex {
return node.PubKeyBytes
}
toNodeFeatures, err := c.fetchNodeFeatures(
tx, node.PubKeyBytes,
)
if err != nil {
return err
}
var cachedInPolicy *models.CachedEdgePolicy
if p2 != nil {
cachedInPolicy =
models.NewCachedPolicy(p2)
cachedInPolicy.ToNodePubKey =
toNodeCallback
cachedInPolicy.ToNodeFeatures =
toNodeFeatures
}
directedChannel := &DirectedChannel{
ChannelID: e.ChannelID,
IsNode1: node.PubKeyBytes ==
e.NodeKey1Bytes,
OtherNode: e.NodeKey2Bytes,
Capacity: e.Capacity,
OutPolicySet: p1 != nil,
InPolicy: cachedInPolicy,
}
if node.PubKeyBytes == e.NodeKey2Bytes {
directedChannel.OtherNode =
e.NodeKey1Bytes
}
channels[e.ChannelID] = directedChannel
return nil
})
if err != nil {
return err
}
return cb(node.PubKeyBytes, channels)
})
}
// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when both of the associated ChannelEdgePolicies
// have their disabled bit on.
func (c *KVStore) DisabledChannelIDs() ([]uint64, error) {
var disabledChanIDs []uint64
var chanEdgeFound map[uint64]struct{}
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
disabledEdgePolicyIndex := edges.NestedReadBucket(
disabledEdgePolicyBucket,
)
if disabledEdgePolicyIndex == nil {
return nil
}
// We iterate over all disabled policies and we add each channel
// that has more than one disabled policy to disabledChanIDs
// array.
return disabledEdgePolicyIndex.ForEach(
func(k, v []byte) error {
chanID := byteOrder.Uint64(k[:8])
_, edgeFound := chanEdgeFound[chanID]
if edgeFound {
delete(chanEdgeFound, chanID)
disabledChanIDs = append(
disabledChanIDs, chanID,
)
return nil
}
chanEdgeFound[chanID] = struct{}{}
return nil
},
)
}, func() {
disabledChanIDs = nil
chanEdgeFound = make(map[uint64]struct{})
})
if err != nil {
return nil, err
}
return disabledChanIDs, nil
}
// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early. Any operations performed on the NodeTx passed to the call-back are
// executed under the same read transaction and so, methods on the NodeTx object
// _MUST_ only be called from within the call-back.
func (c *KVStore) ForEachNode(cb func(tx NodeRTx) error) error {
return c.forEachNode(func(tx kvdb.RTx,
node *models.LightningNode) error {
return cb(newChanGraphNodeTx(tx, c, node))
})
}
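// Illustrative sketch (hypothetical helper, assuming the NodeRTx
// interface exposes the underlying node via a Node() method): copies
// the data it needs out of the callback rather than retaining the
// NodeRTx, which is only valid inside the read transaction.
func collectNodeAliases(c *KVStore) ([]string, error) {
	var aliases []string
	err := c.ForEachNode(func(nodeTx NodeRTx) error {
		aliases = append(aliases, nodeTx.Node().Alias)
		return nil
	})
	return aliases, err
}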
// forEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early.
//
// TODO(roasbeef): add iterator interface to allow for memory efficient graph
// traversal when graph gets mega.
func (c *KVStore) forEachNode(
cb func(kvdb.RTx, *models.LightningNode) error) error {
traversal := func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
// If this is the source key, then we skip this
// iteration as the value for this key is a pubKey
// rather than raw node information.
if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
return nil
}
nodeReader := bytes.NewReader(nodeBytes)
node, err := deserializeLightningNode(nodeReader)
if err != nil {
return err
}
// Execute the callback, the transaction will abort if
// this returns an error.
return cb(tx, &node)
})
}
return kvdb.View(c.db, traversal, func() {})
}
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (c *KVStore) ForEachNodeCacheable(cb func(route.Vertex,
*lnwire.FeatureVector) error) error {
traversal := func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
// If this is the source key, then we skip this
// iteration as the value for this key is a pubKey
// rather than raw node information.
if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
return nil
}
nodeReader := bytes.NewReader(nodeBytes)
node, features, err := deserializeLightningNodeCacheable( //nolint:ll
nodeReader,
)
if err != nil {
return err
}
// Execute the callback, the transaction will abort if
// this returns an error.
return cb(node, features)
})
}
return kvdb.View(c.db, traversal, func() {})
}
// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
func (c *KVStore) SourceNode() (*models.LightningNode, error) {
var source *models.LightningNode
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
node, err := c.sourceNode(nodes)
if err != nil {
return err
}
source = node
return nil
}, func() {
source = nil
})
if err != nil {
return nil, err
}
return source, nil
}
// sourceNode uses an existing database transaction and returns the source node
// of the graph. The source node is treated as the center node within a
// star-graph. This method may be used to kick off a path finding algorithm in
// order to explore the reachability of another node based off the source node.
func (c *KVStore) sourceNode(nodes kvdb.RBucket) (*models.LightningNode,
error) {
selfPub := nodes.Get(sourceKey)
if selfPub == nil {
return nil, ErrSourceNodeNotSet
}
// With the pubKey of the source node retrieved, we're able to
// fetch the full node information.
node, err := fetchLightningNode(nodes, selfPub)
if err != nil {
return nil, err
}
return &node, nil
}
// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
func (c *KVStore) SetSourceNode(node *models.LightningNode) error {
nodePubBytes := node.PubKeyBytes[:]
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return err
}
// Next we create the mapping from source to the targeted
// public key.
if err := nodes.Put(sourceKey, nodePubBytes); err != nil {
return err
}
// Finally, we commit the information of the lightning node
// itself.
return addLightningNode(tx, node)
}, func() {})
}
// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information. Note that this method is expected to only be called to update an
// already present node from a node announcement, or to insert a node found in a
// channel update.
//
// TODO(roasbeef): also need sig of announcement.
func (c *KVStore) AddLightningNode(node *models.LightningNode,
op ...batch.SchedulerOption) error {
r := &batch.Request{
Update: func(tx kvdb.RwTx) error {
return addLightningNode(tx, node)
},
}
for _, f := range op {
f(r)
}
return c.nodeScheduler.Execute(r)
}
func addLightningNode(tx kvdb.RwTx, node *models.LightningNode) error {
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return err
}
aliases, err := nodes.CreateBucketIfNotExists(aliasIndexBucket)
if err != nil {
return err
}
updateIndex, err := nodes.CreateBucketIfNotExists(
nodeUpdateIndexBucket,
)
if err != nil {
return err
}
return putLightningNode(nodes, aliases, updateIndex, node)
}
// LookupAlias attempts to return the alias as advertised by the target node.
// TODO(roasbeef): currently assumes that aliases are unique...
func (c *KVStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
var alias string
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNodesNotFound
}
aliases := nodes.NestedReadBucket(aliasIndexBucket)
if aliases == nil {
return ErrGraphNodesNotFound
}
nodePub := pub.SerializeCompressed()
a := aliases.Get(nodePub)
if a == nil {
return ErrNodeAliasNotFound
}
// TODO(roasbeef): should actually be using the utf-8
// package...
alias = string(a)
return nil
}, func() {
alias = ""
})
if err != nil {
return "", err
}
return alias, nil
}
// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
func (c *KVStore) DeleteLightningNode(nodePub route.Vertex) error {
// TODO(roasbeef): ensure dangling edges are removed...
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
nodes := tx.ReadWriteBucket(nodeBucket)
if nodes == nil {
return ErrGraphNodeNotFound
}
return c.deleteLightningNode(nodes, nodePub[:])
}, func() {})
}
// deleteLightningNode uses an existing database transaction to remove a
// vertex/node from the database according to the node's public key.
func (c *KVStore) deleteLightningNode(nodes kvdb.RwBucket,
compressedPubKey []byte) error {
aliases := nodes.NestedReadWriteBucket(aliasIndexBucket)
if aliases == nil {
return ErrGraphNodesNotFound
}
if err := aliases.Delete(compressedPubKey); err != nil {
return err
}
// Before we delete the node, we'll fetch its current state so we can
// determine when its last update was to clear out the node update
// index.
node, err := fetchLightningNode(nodes, compressedPubKey)
if err != nil {
return err
}
if err := nodes.Delete(compressedPubKey); err != nil {
return err
}
// Finally, we'll delete the index entry for the node within the
// nodeUpdateIndexBucket as this node is no longer active, so we don't
// need to track its last update.
nodeUpdateIndex := nodes.NestedReadWriteBucket(nodeUpdateIndexBucket)
if nodeUpdateIndex == nil {
return ErrGraphNodesNotFound
}
// In order to delete the entry, we'll need to reconstruct the key for
// its last update.
updateUnix := uint64(node.LastUpdate.Unix())
var indexKey [8 + 33]byte
byteOrder.PutUint64(indexKey[:8], updateUnix)
copy(indexKey[8:], compressedPubKey)
return nodeUpdateIndex.Delete(indexKey[:])
}
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge between the two target nodes is created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
func (c *KVStore) AddChannelEdge(edge *models.ChannelEdgeInfo,
op ...batch.SchedulerOption) error {
var alreadyExists bool
r := &batch.Request{
Reset: func() {
alreadyExists = false
},
Update: func(tx kvdb.RwTx) error {
err := c.addChannelEdge(tx, edge)
// Silence ErrEdgeAlreadyExist so that the batch can
// succeed, but propagate the error via local state.
if errors.Is(err, ErrEdgeAlreadyExist) {
alreadyExists = true
return nil
}
return err
},
OnCommit: func(err error) error {
switch {
case err != nil:
return err
case alreadyExists:
return ErrEdgeAlreadyExist
default:
c.rejectCache.remove(edge.ChannelID)
c.chanCache.remove(edge.ChannelID)
return nil
}
},
}
for _, f := range op {
if f == nil {
return fmt.Errorf("nil scheduler option was used")
}
f(r)
}
return c.chanScheduler.Execute(r)
}
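// Illustrative sketch (hypothetical helper): since the batch silences
// ErrEdgeAlreadyExist internally and re-surfaces it via OnCommit,
// callers that want idempotent inserts only need to filter that one
// error.
func addEdgeIdempotent(c *KVStore, edge *models.ChannelEdgeInfo) error {
	err := c.AddChannelEdge(edge)
	if errors.Is(err, ErrEdgeAlreadyExist) {
		// The edge was already present; treat this as success.
		return nil
	}
	return err
}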
// addChannelEdge is the private form of AddChannelEdge that allows callers to
// utilize an existing db transaction.
func (c *KVStore) addChannelEdge(tx kvdb.RwTx,
edge *models.ChannelEdgeInfo) error {
// Construct the channel's primary key which is the 8-byte channel ID.
var chanKey [8]byte
binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID)
nodes, err := tx.CreateTopLevelBucket(nodeBucket)
if err != nil {
return err
}
edges, err := tx.CreateTopLevelBucket(edgeBucket)
if err != nil {
return err
}
edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket)
if err != nil {
return err
}
chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket)
if err != nil {
return err
}
// First, attempt to check if this edge has already been created. If
// so, then we can exit early as this method is meant to be idempotent.
if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo != nil {
return ErrEdgeAlreadyExist
}
// Before we insert the channel into the database, we'll ensure that
// both nodes already exist in the channel graph. If either node
// doesn't, then we'll insert a "shell" node that just includes its
// public key, so subsequent validation and queries can work properly.
_, node1Err := fetchLightningNode(nodes, edge.NodeKey1Bytes[:])
switch {
case errors.Is(node1Err, ErrGraphNodeNotFound):
node1Shell := models.LightningNode{
PubKeyBytes: edge.NodeKey1Bytes,
HaveNodeAnnouncement: false,
}
err := addLightningNode(tx, &node1Shell)
if err != nil {
return fmt.Errorf("unable to create shell node "+
"for: %x: %w", edge.NodeKey1Bytes, err)
}
case node1Err != nil:
return node1Err
}
_, node2Err := fetchLightningNode(nodes, edge.NodeKey2Bytes[:])
switch {
case errors.Is(node2Err, ErrGraphNodeNotFound):
node2Shell := models.LightningNode{
PubKeyBytes: edge.NodeKey2Bytes,
HaveNodeAnnouncement: false,
}
err := addLightningNode(tx, &node2Shell)
if err != nil {
return fmt.Errorf("unable to create shell node "+
"for: %x: %w", edge.NodeKey2Bytes, err)
}
case node2Err != nil:
return node2Err
}
// If the edge hasn't been created yet, then we'll first add it to the
// edge index in order to associate the edge between two nodes and also
// store the static components of the channel.
if err := putChanEdgeInfo(edgeIndex, edge, chanKey); err != nil {
return err
}
// Mark edge policies for both sides as unknown. This is to enable
// efficient incoming channel lookup for a node.
keys := []*[33]byte{
&edge.NodeKey1Bytes,
&edge.NodeKey2Bytes,
}
for _, key := range keys {
err := putChanEdgePolicyUnknown(edges, edge.ChannelID, key[:])
if err != nil {
return err
}
}
// Finally we add it to the channel index which maps channel points
// (outpoints) to the shorter channel IDs.
var b bytes.Buffer
if err := WriteOutpoint(&b, &edge.ChannelPoint); err != nil {
return err
}
return chanIndex.Put(b.Bytes(), chanKey[:])
}
// HasChannelEdge returns true if the database knows of a channel edge with the
// passed channel ID, and false otherwise. If an edge with that ID is found
// within the graph, then two time stamps representing the last time the edge
// was updated for both directed edges are returned along with the boolean. If
// it is not found, then the zombie index is checked and its result is returned
// as the second boolean.
func (c *KVStore) HasChannelEdge(
chanID uint64) (time.Time, time.Time, bool, bool, error) {
var (
upd1Time time.Time
upd2Time time.Time
exists bool
isZombie bool
)
// We'll query the cache with the shared lock held to allow multiple
// readers to access values in the cache concurrently if they exist.
c.cacheMu.RLock()
if entry, ok := c.rejectCache.get(chanID); ok {
c.cacheMu.RUnlock()
upd1Time = time.Unix(entry.upd1Time, 0)
upd2Time = time.Unix(entry.upd2Time, 0)
exists, isZombie = entry.flags.unpack()
return upd1Time, upd2Time, exists, isZombie, nil
}
c.cacheMu.RUnlock()
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
// The item was not found with the shared lock, so we'll acquire the
// exclusive lock and check the cache again in case another method added
// the entry to the cache while no lock was held.
if entry, ok := c.rejectCache.get(chanID); ok {
upd1Time = time.Unix(entry.upd1Time, 0)
upd2Time = time.Unix(entry.upd2Time, 0)
exists, isZombie = entry.flags.unpack()
return upd1Time, upd2Time, exists, isZombie, nil
}
if err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
var channelID [8]byte
byteOrder.PutUint64(channelID[:], chanID)
// If the edge doesn't exist, then we'll also check our zombie
// index.
if edgeIndex.Get(channelID[:]) == nil {
exists = false
zombieIndex := edges.NestedReadBucket(zombieBucket)
if zombieIndex != nil {
isZombie, _, _ = isZombieEdge(
zombieIndex, chanID,
)
}
return nil
}
exists = true
isZombie = false
// If the channel has been found in the graph, then retrieve
// the edges themselves so we can return the last updated
// timestamps.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNodeNotFound
}
e1, e2, err := fetchChanEdgePolicies(
edgeIndex, edges, channelID[:],
)
if err != nil {
return err
}
// As we may have only one of the edges populated, only set the
// update time if the edge was found in the database.
if e1 != nil {
upd1Time = e1.LastUpdate
}
if e2 != nil {
upd2Time = e2.LastUpdate
}
return nil
}, func() {}); err != nil {
return time.Time{}, time.Time{}, exists, isZombie, err
}
c.rejectCache.insert(chanID, rejectCacheEntry{
upd1Time: upd1Time.Unix(),
upd2Time: upd2Time.Unix(),
flags: packRejectFlags(exists, isZombie),
})
return upd1Time, upd2Time, exists, isZombie, nil
}
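// Illustrative sketch (hypothetical helper): collapses the two booleans
// returned by HasChannelEdge into a single status, ignoring the update
// timestamps.
func edgeStatus(c *KVStore, chanID uint64) (string, error) {
	_, _, exists, isZombie, err := c.HasChannelEdge(chanID)
	switch {
	case err != nil:
		return "", err
	case isZombie:
		return "zombie", nil
	case exists:
		return "known", nil
	default:
		return "unknown", nil
	}
}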
// AddEdgeProof sets the proof of an existing edge in the graph database.
func (c *KVStore) AddEdgeProof(chanID lnwire.ShortChannelID,
proof *models.ChannelAuthProof) error {
// Construct the channel's primary key which is the 8-byte channel ID.
var chanKey [8]byte
binary.BigEndian.PutUint64(chanKey[:], chanID.ToUint64())
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return ErrEdgeNotFound
}
edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrEdgeNotFound
}
edge, err := fetchChanEdgeInfo(edgeIndex, chanKey[:])
if err != nil {
return err
}
edge.AuthProof = proof
return putChanEdgeInfo(edgeIndex, &edge, chanKey)
}, func() {})
}
const (
// pruneTipBytes is the total size of the value which stores a prune
// entry of the graph in the prune log. The "prune tip" is the last
// entry in the prune log, and indicates if the channel graph is in
// sync with the current UTXO state. The structure of the value
// is: blockHash, taking 32 bytes total.
pruneTipBytes = 32
)
// PruneGraph prunes newly closed channels from the channel graph in response
// to a new block being solved on the network. Any transactions which spend the
// funding output of any known channels within the graph will be deleted.
// Additionally, the "prune tip", or the last block which has been used to
// prune the graph is stored so callers can ensure the graph is fully in sync
// with the current UTXO state. A slice of channels that have been closed by
// the target block along with any pruned nodes are returned if the function
// succeeds without error.
func (c *KVStore) PruneGraph(spentOutputs []*wire.OutPoint,
blockHash *chainhash.Hash, blockHeight uint32) (
[]*models.ChannelEdgeInfo, []route.Vertex, error) {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
var (
chansClosed []*models.ChannelEdgeInfo
prunedNodes []route.Vertex
)
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
// First grab the edges bucket which houses the information
// we'd like to delete
edges, err := tx.CreateTopLevelBucket(edgeBucket)
if err != nil {
return err
}
// Next grab the two edge indexes which will also need to be
// updated.
edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket)
if err != nil {
return err
}
chanIndex, err := edges.CreateBucketIfNotExists(
channelPointBucket,
)
if err != nil {
return err
}
nodes := tx.ReadWriteBucket(nodeBucket)
if nodes == nil {
return ErrSourceNodeNotSet
}
zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket)
if err != nil {
return err
}
// For each of the outpoints that have been spent within the
// block, we attempt to delete them from the graph: if such an
// outpoint was a channel, then it has now been closed.
for _, chanPoint := range spentOutputs {
// TODO(roasbeef): load channel bloom filter, continue
// if NOT in filter
var opBytes bytes.Buffer
err := WriteOutpoint(&opBytes, chanPoint)
if err != nil {
return err
}
// First attempt to see if the channel exists within
// the database, if not, then we can exit early.
chanID := chanIndex.Get(opBytes.Bytes())
if chanID == nil {
continue
}
// Attempt to delete the channel, an ErrEdgeNotFound
// will be returned if that outpoint isn't known to be
// a channel. If no error is returned, then a channel
// was successfully pruned.
edgeInfo, err := c.delChannelEdgeUnsafe(
edges, edgeIndex, chanIndex, zombieIndex,
chanID, false, false,
)
if err != nil && !errors.Is(err, ErrEdgeNotFound) {
return err
}
chansClosed = append(chansClosed, edgeInfo)
}
metaBucket, err := tx.CreateTopLevelBucket(graphMetaBucket)
if err != nil {
return err
}
pruneBucket, err := metaBucket.CreateBucketIfNotExists(
pruneLogBucket,
)
if err != nil {
return err
}
// With the graph pruned, add a new entry to the prune log,
// which can be used to check if the graph is fully synced with
// the current UTXO state.
var blockHeightBytes [4]byte
byteOrder.PutUint32(blockHeightBytes[:], blockHeight)
var newTip [pruneTipBytes]byte
copy(newTip[:], blockHash[:])
err = pruneBucket.Put(blockHeightBytes[:], newTip[:])
if err != nil {
return err
}
// Now that the graph has been pruned, we'll also attempt to
// prune any nodes that have had a channel closed within the
// latest block.
prunedNodes, err = c.pruneGraphNodes(nodes, edgeIndex)
return err
}, func() {
chansClosed = nil
prunedNodes = nil
})
if err != nil {
return nil, nil, err
}
for _, channel := range chansClosed {
c.rejectCache.remove(channel.ChannelID)
c.chanCache.remove(channel.ChannelID)
}
return chansClosed, prunedNodes, nil
}
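// Illustrative sketch (hypothetical helper): wires PruneGraph to a
// newly connected block, assuming the caller has already extracted the
// spent outpoints from the block's transactions.
func pruneForBlock(c *KVStore, spent []*wire.OutPoint,
	blockHash *chainhash.Hash, height uint32) error {

	closed, prunedNodes, err := c.PruneGraph(spent, blockHash, height)
	if err != nil {
		return err
	}
	log.Infof("Block %d closed %d channels, pruned %d nodes",
		height, len(closed), len(prunedNodes))
	return nil
}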
// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This ensures
// that we only maintain a graph of reachable nodes. In the event that a pruned
// node gains more channels, it will be re-added to the graph.
func (c *KVStore) PruneGraphNodes() ([]route.Vertex, error) {
var prunedNodes []route.Vertex
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
nodes := tx.ReadWriteBucket(nodeBucket)
if nodes == nil {
return ErrGraphNodesNotFound
}
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return ErrGraphNotFound
}
edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
var err error
prunedNodes, err = c.pruneGraphNodes(nodes, edgeIndex)
if err != nil {
return err
}
return nil
}, func() {
prunedNodes = nil
})
return prunedNodes, err
}
// pruneGraphNodes attempts to remove any nodes from the graph who have had a
// channel closed within the current block. If the node still has existing
// channels in the graph, this will act as a no-op.
func (c *KVStore) pruneGraphNodes(nodes kvdb.RwBucket,
edgeIndex kvdb.RwBucket) ([]route.Vertex, error) {
log.Trace("Pruning nodes from graph with no open channels")
// We'll retrieve the graph's source node to ensure we don't remove it
// even if it no longer has any open channels.
sourceNode, err := c.sourceNode(nodes)
if err != nil {
return nil, err
}
// We'll use this map to keep count of the number of references to a
// node in the graph. A node should only be removed once it has no more
// references in the graph.
nodeRefCounts := make(map[[33]byte]int)
err = nodes.ForEach(func(pubKey, nodeBytes []byte) error {
// If this is the source key, then we skip this
// iteration as the value for this key is a pubKey
// rather than raw node information.
if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
return nil
}
var nodePub [33]byte
copy(nodePub[:], pubKey)
nodeRefCounts[nodePub] = 0
return nil
})
if err != nil {
return nil, err
}
// To ensure we never delete the source node, we'll start off by
// bumping its ref count to 1.
nodeRefCounts[sourceNode.PubKeyBytes] = 1
// Next, we'll run through the edgeIndex which maps a channel ID to the
// edge info. We'll use this scan to populate our reference count map
// above.
err = edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error {
// The first 66 bytes of the edge info contain the pubkeys of
// the nodes that this edge attaches. We'll extract them, and
// add them to the ref count map.
var node1, node2 [33]byte
copy(node1[:], edgeInfoBytes[:33])
copy(node2[:], edgeInfoBytes[33:])
// With the nodes extracted, we'll increase the ref count of
// each of the nodes.
nodeRefCounts[node1]++
nodeRefCounts[node2]++
return nil
})
if err != nil {
return nil, err
}
// Finally, we'll make a second pass over the set of nodes, and delete
// any nodes that have a ref count of zero.
var pruned []route.Vertex
for nodePubKey, refCount := range nodeRefCounts {
// If the ref count of the node isn't zero, then we can safely
// skip it as it still has edges to or from it within the
// graph.
if refCount != 0 {
continue
}
// If we reach this point, then there are no longer any edges
// that connect this node, so we can delete it.
err := c.deleteLightningNode(nodes, nodePubKey[:])
if err != nil {
if errors.Is(err, ErrGraphNodeNotFound) ||
errors.Is(err, ErrGraphNodesNotFound) {
log.Warnf("Unable to prune node %x from the "+
"graph: %v", nodePubKey, err)
continue
}
return nil, err
}
log.Infof("Pruned unconnected node %x from channel graph",
nodePubKey[:])
pruned = append(pruned, nodePubKey)
}
if len(pruned) > 0 {
log.Infof("Pruned %v unconnected nodes from the channel graph",
len(pruned))
}
return pruned, err
}
// DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels
// that are no longer confirmed from the graph. The prune log will be
// set to the last prune height valid for the remaining chain.
// Channels that were removed from the graph resulting from the
// disconnected block are returned.
func (c *KVStore) DisconnectBlockAtHeight(height uint32) (
[]*models.ChannelEdgeInfo, error) {
// Every channel having a ShortChannelID starting at 'height'
// will no longer be confirmed.
startShortChanID := lnwire.ShortChannelID{
BlockHeight: height,
}
// Delete everything after this height from the db up until the
// SCID alias range.
endShortChanID := aliasmgr.StartingAlias
// The block height will be the first 3 bytes of the channel IDs.
var chanIDStart [8]byte
byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
var chanIDEnd [8]byte
byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
// Keep track of the channels that are removed from the graph.
var removedChans []*models.ChannelEdgeInfo
if err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
edges, err := tx.CreateTopLevelBucket(edgeBucket)
if err != nil {
return err
}
edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket)
if err != nil {
return err
}
chanIndex, err := edges.CreateBucketIfNotExists(
channelPointBucket,
)
if err != nil {
return err
}
zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket)
if err != nil {
return err
}
// Scan from chanIDStart to chanIDEnd, deleting every
// found edge.
// NOTE: we must delete the edges after the cursor loop, since
// modifying the bucket while traversing is not safe.
// NOTE: We use a < comparison in bytes.Compare instead of <=
// so that the StartingAlias itself isn't deleted.
var keys [][]byte
cursor := edgeIndex.ReadWriteCursor()
//nolint:ll
for k, _ := cursor.Seek(chanIDStart[:]); k != nil &&
bytes.Compare(k, chanIDEnd[:]) < 0; k, _ = cursor.Next() {
keys = append(keys, k)
}
for _, k := range keys {
edgeInfo, err := c.delChannelEdgeUnsafe(
edges, edgeIndex, chanIndex, zombieIndex,
k, false, false,
)
if err != nil && !errors.Is(err, ErrEdgeNotFound) {
return err
}
removedChans = append(removedChans, edgeInfo)
}
// Delete all the entries in the prune log having a height
// greater or equal to the block disconnected.
metaBucket, err := tx.CreateTopLevelBucket(graphMetaBucket)
if err != nil {
return err
}
pruneBucket, err := metaBucket.CreateBucketIfNotExists(
pruneLogBucket,
)
if err != nil {
return err
}
var pruneKeyStart [4]byte
byteOrder.PutUint32(pruneKeyStart[:], height)
var pruneKeyEnd [4]byte
byteOrder.PutUint32(pruneKeyEnd[:], math.MaxUint32)
// To avoid modifying the bucket while traversing, we delete
// the keys in a second loop.
var pruneKeys [][]byte
pruneCursor := pruneBucket.ReadWriteCursor()
//nolint:ll
for k, _ := pruneCursor.Seek(pruneKeyStart[:]); k != nil &&
bytes.Compare(k, pruneKeyEnd[:]) <= 0; k, _ = pruneCursor.Next() {
pruneKeys = append(pruneKeys, k)
}
for _, k := range pruneKeys {
if err := pruneBucket.Delete(k); err != nil {
return err
}
}
return nil
}, func() {
removedChans = nil
}); err != nil {
return nil, err
}
for _, channel := range removedChans {
c.rejectCache.remove(channel.ChannelID)
c.chanCache.remove(channel.ChannelID)
}
return removedChans, nil
}
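// Illustrative sketch (hypothetical helper): the compact channel ID
// packs the confirmation height into its top 3 bytes (BOLT #7 layout:
// 3-byte block height, 3-byte tx index, 2-byte output index), which is
// what makes the byte-wise range scan above line up with block
// heights. This mirrors what lnwire.ShortChannelID.ToUint64 computes.
func packSCID(height, txIndex uint32, txPosition uint16) uint64 {
	return uint64(height)<<40 | uint64(txIndex)<<16 |
		uint64(txPosition)
}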
// PruneTip returns the block height and hash of the latest block that has been
// used to prune channels in the graph. Knowing the "prune tip" allows callers
// to tell if the graph is currently in sync with the current best known UTXO
// state.
func (c *KVStore) PruneTip() (*chainhash.Hash, uint32, error) {
var (
tipHash chainhash.Hash
tipHeight uint32
)
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
graphMeta := tx.ReadBucket(graphMetaBucket)
if graphMeta == nil {
return ErrGraphNotFound
}
pruneBucket := graphMeta.NestedReadBucket(pruneLogBucket)
if pruneBucket == nil {
return ErrGraphNeverPruned
}
pruneCursor := pruneBucket.ReadCursor()
// The prune key with the largest block height will be our
// prune tip.
k, v := pruneCursor.Last()
if k == nil {
return ErrGraphNeverPruned
}
// Once we have the prune tip, the value will be the block hash,
// and the key the block height.
copy(tipHash[:], v)
tipHeight = byteOrder.Uint32(k)
return nil
}, func() {})
if err != nil {
return nil, 0, err
}
return &tipHash, tipHeight, nil
}
// DeleteChannelEdges removes edges with the given channel IDs from the
// database and marks them as zombies. This ensures that we're unable to re-add
// them to our database once again. If an edge does not exist within the
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
// true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state. The markZombie bool
// denotes whether or not to mark the channel as a zombie.
func (c *KVStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
// TODO(roasbeef): possibly delete from node bucket if node has no more
// channels
// TODO(roasbeef): don't delete both edges?
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
var infos []*models.ChannelEdgeInfo
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return ErrEdgeNotFound
}
edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrEdgeNotFound
}
chanIndex := edges.NestedReadWriteBucket(channelPointBucket)
if chanIndex == nil {
return ErrEdgeNotFound
}
nodes := tx.ReadWriteBucket(nodeBucket)
if nodes == nil {
return ErrGraphNodeNotFound
}
zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket)
if err != nil {
return err
}
var rawChanID [8]byte
for _, chanID := range chanIDs {
byteOrder.PutUint64(rawChanID[:], chanID)
edgeInfo, err := c.delChannelEdgeUnsafe(
edges, edgeIndex, chanIndex, zombieIndex,
rawChanID[:], markZombie, strictZombiePruning,
)
if err != nil {
return err
}
infos = append(infos, edgeInfo)
}
return nil
}, func() {
infos = nil
})
if err != nil {
return nil, err
}
for _, chanID := range chanIDs {
c.rejectCache.remove(chanID)
c.chanCache.remove(chanID)
}
return infos, nil
}
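// Illustrative sketch (hypothetical helper): deletes an edge and marks
// it as a zombie so that a later stale announcement can be rejected
// cheaply, without re-validating the channel against the chain.
func removeAsZombie(c *KVStore, chanID uint64) error {
	_, err := c.DeleteChannelEdges(
		false, // strictZombiePruning
		true,  // markZombie
		chanID,
	)
	return err
}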
// ChannelID attempts to look up the 8-byte compact channel ID which maps to the
// passed channel point (outpoint). If the passed channel doesn't exist within
// the database, then ErrEdgeNotFound is returned.
func (c *KVStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
var chanID uint64
if err := kvdb.View(c.db, func(tx kvdb.RTx) error {
var err error
chanID, err = getChanID(tx, chanPoint)
return err
}, func() {
chanID = 0
}); err != nil {
return 0, err
}
return chanID, nil
}
// getChanID returns the assigned channel ID for a given channel point.
func getChanID(tx kvdb.RTx, chanPoint *wire.OutPoint) (uint64, error) {
var b bytes.Buffer
if err := WriteOutpoint(&b, chanPoint); err != nil {
return 0, err
}
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return 0, ErrGraphNoEdgesFound
}
chanIndex := edges.NestedReadBucket(channelPointBucket)
if chanIndex == nil {
return 0, ErrGraphNoEdgesFound
}
chanIDBytes := chanIndex.Get(b.Bytes())
if chanIDBytes == nil {
return 0, ErrEdgeNotFound
}
chanID := byteOrder.Uint64(chanIDBytes)
return chanID, nil
}
// TODO(roasbeef): allow updates to use Batch?
// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
func (c *KVStore) HighestChanID() (uint64, error) {
var cid uint64
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
// In order to find the highest chan ID, we'll fetch a cursor
// and use that to seek to the "end" of our known range.
cidCursor := edgeIndex.ReadCursor()
lastChanID, _ := cidCursor.Last()
// If there's no key, then this means that we don't actually
// know of any channels, so we'll return a predictable error.
if lastChanID == nil {
return ErrGraphNoEdgesFound
}
// Otherwise, we'll deserialize the channel ID and return it
// to the caller.
cid = byteOrder.Uint64(lastChanID)
return nil
}, func() {
cid = 0
})
if err != nil && !errors.Is(err, ErrGraphNoEdgesFound) {
return 0, err
}
return cid, nil
}
// ChannelEdge represents the complete set of information for a channel edge in
// the known channel graph. This struct couples the core information of the
// edge as well as each of the known advertised edge policies.
type ChannelEdge struct {
// Info contains all the static information describing the channel.
Info *models.ChannelEdgeInfo
// Policy1 points to the "first" edge policy of the channel containing
// the dynamic information required to properly route through the edge.
Policy1 *models.ChannelEdgePolicy
// Policy2 points to the "second" edge policy of the channel containing
// the dynamic information required to properly route through the edge.
Policy2 *models.ChannelEdgePolicy
// Node1 is "node 1" in the channel. This is the node that would have
// produced Policy1 if it exists.
Node1 *models.LightningNode
// Node2 is "node 2" in the channel. This is the node that would have
// produced Policy2 if it exists.
Node2 *models.LightningNode
}
// ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon.
func (c *KVStore) ChanUpdatesInHorizon(startTime,
endTime time.Time) ([]ChannelEdge, error) {
// To ensure we don't return duplicate ChannelEdges, we'll use an
// additional map to keep track of the edges already seen to prevent
// re-adding them.
var edgesSeen map[uint64]struct{}
var edgesToCache map[uint64]ChannelEdge
var edgesInHorizon []ChannelEdge
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
var hits int
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
edgeUpdateIndex := edges.NestedReadBucket(edgeUpdateIndexBucket)
if edgeUpdateIndex == nil {
return ErrGraphNoEdgesFound
}
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNodesNotFound
}
// We'll now obtain a cursor to perform a range query within
// the index to find all channels within the horizon.
updateCursor := edgeUpdateIndex.ReadCursor()
var startTimeBytes, endTimeBytes [8 + 8]byte
byteOrder.PutUint64(
startTimeBytes[:8], uint64(startTime.Unix()),
)
byteOrder.PutUint64(
endTimeBytes[:8], uint64(endTime.Unix()),
)
// With our start and end times constructed, we'll step through
// the index collecting the info and policy of each update of
// each channel that has a last update within the time range.
//
//nolint:ll
for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil &&
bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() {
// We have a new eligible entry, so we'll slice off the
// chan ID so we can query it in the DB.
chanID := indexKey[8:]
// If we've already retrieved the info and policies for
// this edge, then we can skip it as we don't need to do
// so again.
chanIDInt := byteOrder.Uint64(chanID)
if _, ok := edgesSeen[chanIDInt]; ok {
continue
}
if channel, ok := c.chanCache.get(chanIDInt); ok {
hits++
edgesSeen[chanIDInt] = struct{}{}
edgesInHorizon = append(edgesInHorizon, channel)
continue
}
// First, we'll fetch the static edge information.
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
if err != nil {
chanID := byteOrder.Uint64(chanID)
return fmt.Errorf("unable to fetch info for "+
"edge with chan_id=%v: %v", chanID, err)
}
// With the static information obtained, we'll now
// fetch the dynamic policy info.
edge1, edge2, err := fetchChanEdgePolicies(
edgeIndex, edges, chanID,
)
if err != nil {
chanID := byteOrder.Uint64(chanID)
return fmt.Errorf("unable to fetch policies "+
"for edge with chan_id=%v: %v", chanID,
err)
}
node1, err := fetchLightningNode(
nodes, edgeInfo.NodeKey1Bytes[:],
)
if err != nil {
return err
}
node2, err := fetchLightningNode(
nodes, edgeInfo.NodeKey2Bytes[:],
)
if err != nil {
return err
}
// Finally, we'll collate this edge with the rest of the
// edges to be returned.
edgesSeen[chanIDInt] = struct{}{}
channel := ChannelEdge{
Info: &edgeInfo,
Policy1: edge1,
Policy2: edge2,
Node1: &node1,
Node2: &node2,
}
edgesInHorizon = append(edgesInHorizon, channel)
edgesToCache[chanIDInt] = channel
}
return nil
}, func() {
edgesSeen = make(map[uint64]struct{})
edgesToCache = make(map[uint64]ChannelEdge)
edgesInHorizon = nil
})
switch {
case errors.Is(err, ErrGraphNoEdgesFound):
fallthrough
case errors.Is(err, ErrGraphNodesNotFound):
break
case err != nil:
return nil, err
}
// Insert any edges loaded from disk into the cache.
for chanid, channel := range edgesToCache {
c.chanCache.insert(chanid, channel)
}
log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
float64(hits)/float64(len(edgesInHorizon)), hits,
len(edgesInHorizon))
return edgesInHorizon, nil
}
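// NOTE: The edge update index scanned above keys its entries as
// updateTime (8 bytes, big-endian) || chanID (8 bytes). As an illustrative
// sketch only (this helper is hypothetical and not used elsewhere), a fully
// populated index key could be built as:
//
//	func updateIndexKey(t time.Time, chanID uint64) [16]byte {
//		var k [16]byte
//		byteOrder.PutUint64(k[:8], uint64(t.Unix()))
//		byteOrder.PutUint64(k[8:], chanID)
//		return k
//	}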
// NodeUpdatesInHorizon returns all the known lightning nodes which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up-to-date node
// announcements.
func (c *KVStore) NodeUpdatesInHorizon(startTime,
endTime time.Time) ([]models.LightningNode, error) {
var nodesInHorizon []models.LightningNode
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNodesNotFound
}
nodeUpdateIndex := nodes.NestedReadBucket(nodeUpdateIndexBucket)
if nodeUpdateIndex == nil {
return ErrGraphNodesNotFound
}
// We'll now obtain a cursor to perform a range query within
// the index to find all node announcements within the horizon.
updateCursor := nodeUpdateIndex.ReadCursor()
var startTimeBytes, endTimeBytes [8 + 33]byte
byteOrder.PutUint64(
startTimeBytes[:8], uint64(startTime.Unix()),
)
byteOrder.PutUint64(
endTimeBytes[:8], uint64(endTime.Unix()),
)
// With our start and end times constructed, we'll step through
// the index collecting info for each node within the time
// range.
//
//nolint:ll
for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil &&
bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() {
nodePub := indexKey[8:]
node, err := fetchLightningNode(nodes, nodePub)
if err != nil {
return err
}
nodesInHorizon = append(nodesInHorizon, node)
}
return nil
}, func() {
nodesInHorizon = nil
})
switch {
case errors.Is(err, ErrGraphNoEdgesFound):
fallthrough
case errors.Is(err, ErrGraphNodesNotFound):
break
case err != nil:
return nil, err
}
return nodesInHorizon, nil
}
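// NOTE: The node update index mirrors the edge update index, but keys its
// entries as updateTime (8 bytes, big-endian) || pubKey (33 bytes), which is
// why the endpoint buffers above are [8 + 33]byte with only the time portion
// populated. A hypothetical sketch of a fully populated key:
//
//	var key [8 + 33]byte
//	byteOrder.PutUint64(key[:8], uint64(lastUpdate.Unix()))
//	copy(key[8:], nodePub[:]) // 33-byte compressed public key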
// FilterKnownChanIDs takes a set of channel IDs and returns the subset of
// chan IDs from the passed set that we don't know of and that are not known
// zombies. In other words, we perform a set difference of our set of chan IDs
// and the ones passed in. This method can be used by callers to determine the
// set of channels another peer knows of that we don't. The ChannelUpdateInfos
// for the known zombies are also returned.
func (c *KVStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
[]ChannelUpdateInfo, error) {
var (
newChanIDs []uint64
knownZombies []ChannelUpdateInfo
)
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
// Fetch the zombie index; it may not exist if no edges have
// ever been marked as zombies. If the index has been
// initialized, we will use it later to skip known zombie edges.
zombieIndex := edges.NestedReadBucket(zombieBucket)
// We'll run through the set of chanIDs and collect only the
// set of channels that cannot be found within our db.
var cidBytes [8]byte
for _, info := range chansInfo {
scid := info.ShortChannelID.ToUint64()
byteOrder.PutUint64(cidBytes[:], scid)
// If the edge is already known, skip it.
if v := edgeIndex.Get(cidBytes[:]); v != nil {
continue
}
// If the edge is a known zombie, skip it.
if zombieIndex != nil {
isZombie, _, _ := isZombieEdge(
zombieIndex, scid,
)
if isZombie {
knownZombies = append(
knownZombies, info,
)
continue
}
}
newChanIDs = append(newChanIDs, scid)
}
return nil
}, func() {
newChanIDs = nil
knownZombies = nil
})
switch {
// If we don't know of any edges yet, then we'll return the entire set
// of chan IDs specified.
case errors.Is(err, ErrGraphNoEdgesFound):
ogChanIDs := make([]uint64, len(chansInfo))
for i, info := range chansInfo {
ogChanIDs[i] = info.ShortChannelID.ToUint64()
}
return ogChanIDs, nil, nil
case err != nil:
return nil, nil, err
}
return newChanIDs, knownZombies, nil
}
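// As a usage sketch (the identifiers below are hypothetical), a caller
// syncing with a peer might filter the peer's advertised SCIDs as follows:
//
//	newIDs, zombies, err := store.FilterKnownChanIDs(peerChans)
//	if err != nil {
//		return err
//	}
//	// newIDs are the channels we should query the peer for, while
//	// zombies are channels we've already deemed dead.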
// ChannelUpdateInfo couples the SCID of a channel with the timestamps of the
// latest received channel updates for the channel.
type ChannelUpdateInfo struct {
// ShortChannelID is the SCID identifier of the channel.
ShortChannelID lnwire.ShortChannelID
// Node1UpdateTimestamp is the timestamp of the latest received update
// from the node 1 channel peer. This will be set to zero time if no
// update has yet been received from this node.
Node1UpdateTimestamp time.Time
// Node2UpdateTimestamp is the timestamp of the latest received update
// from the node 2 channel peer. This will be set to zero time if no
// update has yet been received from this node.
Node2UpdateTimestamp time.Time
}
// NewChannelUpdateInfo is a constructor which makes sure we initialize the
// timestamps with the zero-seconds unix timestamp, which equals
// `January 1, 1970, 00:00:00 UTC`, in case the value is `time.Time{}`.
func NewChannelUpdateInfo(scid lnwire.ShortChannelID, node1Timestamp,
node2Timestamp time.Time) ChannelUpdateInfo {
chanInfo := ChannelUpdateInfo{
ShortChannelID: scid,
Node1UpdateTimestamp: node1Timestamp,
Node2UpdateTimestamp: node2Timestamp,
}
if node1Timestamp.IsZero() {
chanInfo.Node1UpdateTimestamp = time.Unix(0, 0)
}
if node2Timestamp.IsZero() {
chanInfo.Node2UpdateTimestamp = time.Unix(0, 0)
}
return chanInfo
}
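// For example, passing zero-value timestamps yields the unix epoch rather
// than the zero time.Time:
//
//	info := NewChannelUpdateInfo(scid, time.Time{}, time.Time{})
//	// info.Node1UpdateTimestamp.Unix() == 0
//	// info.Node2UpdateTimestamp.Unix() == 0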
// BlockChannelRange represents a range of channels for a given block height.
type BlockChannelRange struct {
// Height is the height of the block all of the channels below were
// included in.
Height uint32
// Channels is the list of channels known to us, identified by their
// short ID representation, that were included in the block height
// above. The list may include channel update timestamp information if
// requested.
Channels []ChannelUpdateInfo
}
// FilterChannelRange returns the channel IDs of all known channels which were
// mined in a block height within the passed range. The channel IDs are
// grouped by their common block height. This method can be used to quickly
// share with a peer the set of channels we know of within a particular range
// to catch them up after a period of time offline. If withTimestamps is true
// then the timestamp info of the latest received channel update messages of
// the channel will be included in the response.
func (c *KVStore) FilterChannelRange(startHeight,
endHeight uint32, withTimestamps bool) ([]BlockChannelRange, error) {
startChanID := &lnwire.ShortChannelID{
BlockHeight: startHeight,
}
endChanID := lnwire.ShortChannelID{
BlockHeight: endHeight,
TxIndex: math.MaxUint32 & 0x00ffffff,
TxPosition: math.MaxUint16,
}
// As we need to perform a range scan, we'll convert the starting and
// ending height to their corresponding values when encoded using short
// channel IDs.
var chanIDStart, chanIDEnd [8]byte
byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64())
byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64())
var channelsPerBlock map[uint32][]ChannelUpdateInfo
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
cursor := edgeIndex.ReadCursor()
// We'll now iterate through the database, and find each
// channel ID that resides within the specified range.
//
//nolint:ll
for k, v := cursor.Seek(chanIDStart[:]); k != nil &&
bytes.Compare(k, chanIDEnd[:]) <= 0; k, v = cursor.Next() {
// Don't send alias SCIDs during gossip sync.
edgeReader := bytes.NewReader(v)
edgeInfo, err := deserializeChanEdgeInfo(edgeReader)
if err != nil {
return err
}
if edgeInfo.AuthProof == nil {
continue
}
// This channel ID rests within the target range, so
// we'll add it to our returned set.
rawCid := byteOrder.Uint64(k)
cid := lnwire.NewShortChanIDFromInt(rawCid)
chanInfo := NewChannelUpdateInfo(
cid, time.Time{}, time.Time{},
)
if !withTimestamps {
channelsPerBlock[cid.BlockHeight] = append(
channelsPerBlock[cid.BlockHeight],
chanInfo,
)
continue
}
node1Key, node2Key := computeEdgePolicyKeys(&edgeInfo)
rawPolicy := edges.Get(node1Key)
if len(rawPolicy) != 0 {
r := bytes.NewReader(rawPolicy)
edge, err := deserializeChanEdgePolicyRaw(r)
if err != nil && !errors.Is(
err, ErrEdgePolicyOptionalFieldNotFound,
) {
return err
}
chanInfo.Node1UpdateTimestamp = edge.LastUpdate
}
rawPolicy = edges.Get(node2Key)
if len(rawPolicy) != 0 {
r := bytes.NewReader(rawPolicy)
edge, err := deserializeChanEdgePolicyRaw(r)
if err != nil && !errors.Is(
err, ErrEdgePolicyOptionalFieldNotFound,
) {
return err
}
chanInfo.Node2UpdateTimestamp = edge.LastUpdate
}
channelsPerBlock[cid.BlockHeight] = append(
channelsPerBlock[cid.BlockHeight], chanInfo,
)
}
return nil
}, func() {
channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
})
switch {
// If we don't know of any channels yet, then there's nothing to
// filter, so we'll return an empty slice.
case errors.Is(err, ErrGraphNoEdgesFound) || len(channelsPerBlock) == 0:
return nil, nil
case err != nil:
return nil, err
}
// Return the channel ranges in ascending block height order.
blocks := make([]uint32, 0, len(channelsPerBlock))
for block := range channelsPerBlock {
blocks = append(blocks, block)
}
sort.Slice(blocks, func(i, j int) bool {
return blocks[i] < blocks[j]
})
channelRanges := make([]BlockChannelRange, 0, len(channelsPerBlock))
for _, block := range blocks {
channelRanges = append(channelRanges, BlockChannelRange{
Height: block,
Channels: channelsPerBlock[block],
})
}
return channelRanges, nil
}
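// NOTE: An SCID packs into a uint64 as blockHeight (3 bytes) || txIndex
// (3 bytes) || txPosition (2 bytes), which is why the endpoints above cover
// every possible channel within the height range. A rough sketch:
//
//	scid := lnwire.ShortChannelID{
//		BlockHeight: 700_000,
//		TxIndex:     42,
//		TxPosition:  1,
//	}
//	cid := scid.ToUint64() // 700_000<<40 | 42<<16 | 1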
// FetchChanInfos returns the set of channel edges that correspond to the
// passed channel IDs. If an edge in the query is unknown to the database, it
// will be skipped and the result will contain only those edges that exist at
// the time of the query. This can be used to respond to peer queries that are
// seeking to fill in gaps in their view of the channel graph.
func (c *KVStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
return c.fetchChanInfos(nil, chanIDs)
}
// fetchChanInfos returns the set of channel edges that correspond to the
// passed channel IDs. If an edge in the query is unknown to the database, it
// will be skipped and the result will contain only those edges that exist at
// the time of the query. This can be used to respond to peer queries that are
// seeking to fill in gaps in their view of the channel graph.
//
// NOTE: An optional transaction may be provided. If none is provided, then a
// new one will be created.
func (c *KVStore) fetchChanInfos(tx kvdb.RTx, chanIDs []uint64) (
[]ChannelEdge, error) {
// TODO(roasbeef): sort cids?
var (
chanEdges []ChannelEdge
cidBytes [8]byte
)
fetchChanInfos := func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
for _, cid := range chanIDs {
byteOrder.PutUint64(cidBytes[:], cid)
// First, we'll fetch the static edge information. If
// the edge is unknown, we will skip the edge and
// continue gathering all known edges.
edgeInfo, err := fetchChanEdgeInfo(
edgeIndex, cidBytes[:],
)
switch {
case errors.Is(err, ErrEdgeNotFound):
continue
case err != nil:
return err
}
// With the static information obtained, we'll now
// fetch the dynamic policy info.
edge1, edge2, err := fetchChanEdgePolicies(
edgeIndex, edges, cidBytes[:],
)
if err != nil {
return err
}
node1, err := fetchLightningNode(
nodes, edgeInfo.NodeKey1Bytes[:],
)
if err != nil {
return err
}
node2, err := fetchLightningNode(
nodes, edgeInfo.NodeKey2Bytes[:],
)
if err != nil {
return err
}
chanEdges = append(chanEdges, ChannelEdge{
Info: &edgeInfo,
Policy1: edge1,
Policy2: edge2,
Node1: &node1,
Node2: &node2,
})
}
return nil
}
if tx == nil {
err := kvdb.View(c.db, fetchChanInfos, func() {
chanEdges = nil
})
if err != nil {
return nil, err
}
return chanEdges, nil
}
err := fetchChanInfos(tx)
if err != nil {
return nil, err
}
return chanEdges, nil
}
func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64,
edge1, edge2 *models.ChannelEdgePolicy) error {
// First, we'll fetch the edge update index bucket which currently
// stores an entry for the channel we're about to delete.
updateIndex := edgesBucket.NestedReadWriteBucket(edgeUpdateIndexBucket)
if updateIndex == nil {
// No edges in bucket, return early.
return nil
}
// Now that we have the bucket, we'll attempt to construct a template
// for the index key: updateTime || chanid.
var indexKey [8 + 8]byte
byteOrder.PutUint64(indexKey[8:], chanID)
// With the template constructed, we'll attempt to delete an entry that
// would have been created by both edges: we'll alternate the update
// times, as one may have overridden the other.
if edge1 != nil {
byteOrder.PutUint64(
indexKey[:8], uint64(edge1.LastUpdate.Unix()),
)
if err := updateIndex.Delete(indexKey[:]); err != nil {
return err
}
}
// We'll also attempt to delete the entry that may have been created by
// the second edge.
if edge2 != nil {
byteOrder.PutUint64(
indexKey[:8], uint64(edge2.LastUpdate.Unix()),
)
if err := updateIndex.Delete(indexKey[:]); err != nil {
return err
}
}
return nil
}
// delChannelEdgeUnsafe deletes the edge with the given chanID from the graph
// cache. It then goes on to delete any policy info and edge info for this
// channel from the DB and finally, if isZombie is true, it will add an entry
// for this channel in the zombie index.
//
// NOTE: this method MUST only be called if the cacheMu has already been
// acquired.
func (c *KVStore) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex,
zombieIndex kvdb.RwBucket, chanID []byte, isZombie,
strictZombie bool) (*models.ChannelEdgeInfo, error) {
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
if err != nil {
return nil, err
}
// We'll also remove the entry in the edge update index bucket before
// we delete the edges themselves so we can access their last update
// times.
cid := byteOrder.Uint64(chanID)
edge1, edge2, err := fetchChanEdgePolicies(edgeIndex, edges, chanID)
if err != nil {
return nil, err
}
err = delEdgeUpdateIndexEntry(edges, cid, edge1, edge2)
if err != nil {
return nil, err
}
// The edge key is of the format pubKey || chanID. First we construct
// the latter half, populating the channel ID.
var edgeKey [33 + 8]byte
copy(edgeKey[33:], chanID)
// With the latter half constructed, copy over the first public key to
// delete the edge in this direction, then the second to delete the
// edge in the opposite direction.
copy(edgeKey[:33], edgeInfo.NodeKey1Bytes[:])
if edges.Get(edgeKey[:]) != nil {
if err := edges.Delete(edgeKey[:]); err != nil {
return nil, err
}
}
copy(edgeKey[:33], edgeInfo.NodeKey2Bytes[:])
if edges.Get(edgeKey[:]) != nil {
if err := edges.Delete(edgeKey[:]); err != nil {
return nil, err
}
}
// As part of deleting the edge we also remove all disabled entries
// from the edgePolicyDisabledIndex bucket. We do that for both
// directions.
err = updateEdgePolicyDisabledIndex(edges, cid, false, false)
if err != nil {
return nil, err
}
err = updateEdgePolicyDisabledIndex(edges, cid, true, false)
if err != nil {
return nil, err
}
// With the edge data deleted, we can purge the information from the two
// edge indexes.
if err := edgeIndex.Delete(chanID); err != nil {
return nil, err
}
var b bytes.Buffer
if err := WriteOutpoint(&b, &edgeInfo.ChannelPoint); err != nil {
return nil, err
}
if err := chanIndex.Delete(b.Bytes()); err != nil {
return nil, err
}
// Finally, we'll mark the edge as a zombie within our index if it's
// being removed due to the channel becoming a zombie. We do this to
// ensure we don't store unnecessary data for spent channels.
if !isZombie {
return &edgeInfo, nil
}
nodeKey1, nodeKey2 := edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes
if strictZombie {
nodeKey1, nodeKey2 = makeZombiePubkeys(&edgeInfo, edge1, edge2)
}
return &edgeInfo, markEdgeZombie(
zombieIndex, byteOrder.Uint64(chanID), nodeKey1, nodeKey2,
)
}
// makeZombiePubkeys derives the node pubkeys to store in the zombie index for a
// particular pair of channel policies. The return values are one of:
// 1. (pubkey1, pubkey2)
// 2. (pubkey1, blank)
// 3. (blank, pubkey2)
//
// A blank pubkey means that the corresponding node will be unable to
// resurrect a channel on its own. For example, node1 may continue to publish
// recent updates, but node2 has fallen way behind. After marking an edge as a
// zombie, we don't want another fresh update from node1 to resurrect it, as
// the edge can only become live once node2 finally sends something recent.
//
// In the case where we have neither update, we allow either party to
// resurrect the channel. If the channel were to be marked zombie again, it
// would be marked with the correct lagging side, since we received an update
// from only one side.
func makeZombiePubkeys(info *models.ChannelEdgeInfo,
e1, e2 *models.ChannelEdgePolicy) ([33]byte, [33]byte) {
switch {
// If we don't have either edge policy, we'll return both pubkeys so
// that the channel can be resurrected by either party.
case e1 == nil && e2 == nil:
return info.NodeKey1Bytes, info.NodeKey2Bytes
// If we're missing edge1, or if both edges are present but edge1 is
// older, we'll return edge1's pubkey and a blank pubkey for edge2. This
// means that only an update from edge1 will be able to resurrect the
// channel.
case e1 == nil || (e2 != nil && e1.LastUpdate.Before(e2.LastUpdate)):
return info.NodeKey1Bytes, [33]byte{}
// Otherwise, we're missing edge2 or edge2 is the older side, so we
// return a blank pubkey for edge1. In this case, only an update from
// edge2 can resurrect the channel.
default:
return [33]byte{}, info.NodeKey2Bytes
}
}
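// As a worked example: if node1's policy was last updated at t=100 and
// node2's at t=200, then node1 is the lagging side, so only its pubkey is
// kept and only a fresh update from node1 can resurrect the channel:
//
//	k1, k2 := makeZombiePubkeys(info, e1, e2)
//	// k1 == info.NodeKey1Bytes, k2 == [33]byte{}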
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
// within the database for the referenced channel. The `flags` attribute
// within the ChannelEdgePolicy determines which of the directed edges are
// being updated. If the direction bit of the flag is 0, then the first node's
// information is being updated, otherwise it's the second node's information.
// The node ordering is determined by the lexicographical ordering of the
// identity public keys of the nodes on either side of the channel.
func (c *KVStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy,
op ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
var (
isUpdate1 bool
edgeNotFound bool
from, to route.Vertex
)
r := &batch.Request{
Reset: func() {
isUpdate1 = false
edgeNotFound = false
},
Update: func(tx kvdb.RwTx) error {
var err error
from, to, isUpdate1, err = updateEdgePolicy(tx, edge)
if err != nil {
log.Errorf("UpdateEdgePolicy faild: %v", err)
}
// Silence ErrEdgeNotFound so that the batch can
// succeed, but propagate the error via local state.
if errors.Is(err, ErrEdgeNotFound) {
edgeNotFound = true
return nil
}
return err
},
OnCommit: func(err error) error {
switch {
case err != nil:
return err
case edgeNotFound:
return ErrEdgeNotFound
default:
c.updateEdgeCache(edge, isUpdate1)
return nil
}
},
}
for _, f := range op {
f(r)
}
err := c.chanScheduler.Execute(r)
return from, to, err
}
func (c *KVStore) updateEdgeCache(e *models.ChannelEdgePolicy,
isUpdate1 bool) {
// If an entry for this channel is found in the reject cache, we'll
// modify the entry with the updated timestamp for the direction that
// was just written. If the edge doesn't exist, we'll load the cache
// entry lazily during the next query for this edge.
if entry, ok := c.rejectCache.get(e.ChannelID); ok {
if isUpdate1 {
entry.upd1Time = e.LastUpdate.Unix()
} else {
entry.upd2Time = e.LastUpdate.Unix()
}
c.rejectCache.insert(e.ChannelID, entry)
}
// If an entry for this channel is found in the channel cache, we'll
// modify the entry with the updated policy for the direction that was
// just written. If the edge doesn't exist, we'll defer loading the info
// and policies and lazily read from disk during the next query.
if channel, ok := c.chanCache.get(e.ChannelID); ok {
if isUpdate1 {
channel.Policy1 = e
} else {
channel.Policy2 = e
}
c.chanCache.insert(e.ChannelID, channel)
}
}
// updateEdgePolicy attempts to update an edge's policy within the relevant
// buckets using an existing database transaction. The returned boolean will be
// true if the updated policy belongs to node1, and false if the policy belonged
// to node2.
func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy) (
route.Vertex, route.Vertex, bool, error) {
var noVertex route.Vertex
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return noVertex, noVertex, false, ErrEdgeNotFound
}
edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket)
if edgeIndex == nil {
return noVertex, noVertex, false, ErrEdgeNotFound
}
// Create the channelID key by converting the channel ID
// integer into a byte slice.
var chanID [8]byte
byteOrder.PutUint64(chanID[:], edge.ChannelID)
// With the channel ID, we then fetch the value storing the two
// nodes which connect this channel edge.
nodeInfo := edgeIndex.Get(chanID[:])
if nodeInfo == nil {
return noVertex, noVertex, false, ErrEdgeNotFound
}
// Depending on the flags value passed above, either the first
// or second edge policy is being updated.
var fromNode, toNode []byte
var isUpdate1 bool
if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
fromNode = nodeInfo[:33]
toNode = nodeInfo[33:66]
isUpdate1 = true
} else {
fromNode = nodeInfo[33:66]
toNode = nodeInfo[:33]
isUpdate1 = false
}
// Finally, with the direction of the edge being updated
// identified, we update the on-disk edge representation.
err := putChanEdgePolicy(edges, edge, fromNode, toNode)
if err != nil {
return noVertex, noVertex, false, err
}
var (
fromNodePubKey route.Vertex
toNodePubKey route.Vertex
)
copy(fromNodePubKey[:], fromNode)
copy(toNodePubKey[:], toNode)
return fromNodePubKey, toNodePubKey, isUpdate1, nil
}
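// NOTE: The direction bit of ChannelFlags selects which of the two stored
// policies is overwritten, as shown above. A minimal sketch of the check:
//
//	if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
//		// The update was produced by node1 (the lexicographically
//		// smaller pubkey).
//	} else {
//		// The update was produced by node2.
//	}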
// isPublic determines whether the node is seen as public within the graph from
// the source node's point of view. An existing database transaction can also be
// specified.
func (c *KVStore) isPublic(tx kvdb.RTx, nodePub route.Vertex,
sourcePubKey []byte) (bool, error) {
// In order to determine whether this node is publicly advertised within
// the graph, we'll need to look at all of its edges and check whether
// they extend to any other node than the source node. errDone will be
// used to terminate the check early.
nodeIsPublic := false
errDone := errors.New("done")
err := c.ForEachNodeChannelTx(tx, nodePub, func(tx kvdb.RTx,
info *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy,
_ *models.ChannelEdgePolicy) error {
// If this edge doesn't extend to the source node, we'll
// terminate our search as we can now conclude that the node is
// publicly advertised within the graph due to the local node
// knowing of the current edge.
if !bytes.Equal(info.NodeKey1Bytes[:], sourcePubKey) &&
!bytes.Equal(info.NodeKey2Bytes[:], sourcePubKey) {
nodeIsPublic = true
return errDone
}
// Since the edge _does_ extend to the source node, we'll also
// need to ensure that this is a public edge.
if info.AuthProof != nil {
nodeIsPublic = true
return errDone
}
// Otherwise, we'll continue our search.
return nil
})
if err != nil && !errors.Is(err, errDone) {
return false, err
}
return nodeIsPublic, nil
}
// FetchLightningNodeTx attempts to look up a target node by its identity
// public key. If the node isn't found in the database, then
// ErrGraphNodeNotFound is returned. An optional transaction may be provided.
// If none is provided, then a new one will be created.
func (c *KVStore) FetchLightningNodeTx(tx kvdb.RTx, nodePub route.Vertex) (
*models.LightningNode, error) {
return c.fetchLightningNode(tx, nodePub)
}
// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
func (c *KVStore) FetchLightningNode(nodePub route.Vertex) (
*models.LightningNode, error) {
return c.fetchLightningNode(nil, nodePub)
}
// fetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned. An optional transaction may be provided. If none is provided, then
// a new one will be created.
func (c *KVStore) fetchLightningNode(tx kvdb.RTx,
nodePub route.Vertex) (*models.LightningNode, error) {
var node *models.LightningNode
fetch := func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
// If a key for this serialized public key isn't found, then
// the target node doesn't exist within the database.
nodeBytes := nodes.Get(nodePub[:])
if nodeBytes == nil {
return ErrGraphNodeNotFound
}
// If the node is found, then we can deserialize the node
// information to return to the user.
nodeReader := bytes.NewReader(nodeBytes)
n, err := deserializeLightningNode(nodeReader)
if err != nil {
return err
}
node = &n
return nil
}
if tx == nil {
err := kvdb.View(
c.db, fetch, func() {
node = nil
},
)
if err != nil {
return nil, err
}
return node, nil
}
err := fetch(tx)
if err != nil {
return nil, err
}
return node, nil
}
// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was last updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
func (c *KVStore) HasLightningNode(nodePub [33]byte) (time.Time, bool,
error) {
var (
updateTime time.Time
exists bool
)
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
// If a key for this serialized public key isn't found, we can
// exit early.
nodeBytes := nodes.Get(nodePub[:])
if nodeBytes == nil {
exists = false
return nil
}
// Otherwise we continue on to obtain the time stamp
// representing the last time the data for this node was
// updated.
nodeReader := bytes.NewReader(nodeBytes)
node, err := deserializeLightningNode(nodeReader)
if err != nil {
return err
}
exists = true
updateTime = node.LastUpdate
return nil
}, func() {
updateTime = time.Time{}
exists = false
})
if err != nil {
return time.Time{}, exists, err
}
return updateTime, exists, nil
}
// nodeTraversal is used to traverse all channels of a node given by its
// public key and passes channel information into the specified callback.
func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend,
cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy) error) error {
traversal := func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNotFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
// In order to reach all the edges for this node, we take
// advantage of the construction of the key-space within the
// edge bucket. The keys are stored in the form: pubKey ||
// chanID. Therefore, starting from a chanID of zero, we can
// scan forward in the bucket, grabbing all the edges for the
// node. Once the prefix no longer matches, then we know we're
// done.
var nodeStart [33 + 8]byte
copy(nodeStart[:], nodePub)
copy(nodeStart[33:], chanStart[:])
// Starting from the key pubKey || 0, we seek forward in the
// bucket until the retrieved key no longer has the public key
// as its prefix. This indicates that we've stepped over into
// another node's edges, so we can terminate our scan.
edgeCursor := edges.ReadCursor()
for nodeEdge, _ := edgeCursor.Seek(nodeStart[:]); bytes.HasPrefix(nodeEdge, nodePub); nodeEdge, _ = edgeCursor.Next() { //nolint:ll
// If the prefix still matches, the channel ID is
// returned in nodeEdge. The channel ID is used to look
// up the node at the other end of the channel and both
// edge policies.
chanID := nodeEdge[33:]
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
if err != nil {
return err
}
outgoingPolicy, err := fetchChanEdgePolicy(
edges, chanID, nodePub,
)
if err != nil {
return err
}
otherNode, err := edgeInfo.OtherNodeKeyBytes(nodePub)
if err != nil {
return err
}
incomingPolicy, err := fetchChanEdgePolicy(
edges, chanID, otherNode[:],
)
if err != nil {
return err
}
// Finally, we execute the callback.
err = cb(tx, &edgeInfo, outgoingPolicy, incomingPolicy)
if err != nil {
return err
}
}
return nil
}
// If no transaction was provided, then we'll create a new transaction
// to execute the traversal within.
if tx == nil {
return kvdb.View(db, traversal, func() {})
}
// Otherwise, we re-use the existing transaction to execute the graph
// traversal.
return traversal(tx)
}
// ForEachNodeChannel iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
func (c *KVStore) ForEachNodeChannel(nodePub route.Vertex,
cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy) error) error {
return nodeTraversal(nil, nodePub[:], c.db, cb)
}
// ForEachNodeChannelTx iterates through all channels of the given node,
// executing the passed callback with an edge info structure and the policies
// of each end of the channel. The first edge policy is the outgoing edge *to*
// the connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// If the caller wishes to re-use an existing boltdb transaction, then it
// should be passed as the first argument. Otherwise, the first argument should
// be nil and a fresh transaction will be created to execute the graph
// traversal.
func (c *KVStore) ForEachNodeChannelTx(tx kvdb.RTx,
nodePub route.Vertex, cb func(kvdb.RTx, *models.ChannelEdgeInfo,
*models.ChannelEdgePolicy,
*models.ChannelEdgePolicy) error) error {
return nodeTraversal(tx, nodePub[:], c.db, cb)
}
// FetchOtherNode attempts to fetch the full LightningNode that's opposite of
// the target node in the channel. This is useful when one knows the pubkey of
// one of the nodes, and wishes to obtain the full LightningNode for the other
// end of the channel.
func (c *KVStore) FetchOtherNode(tx kvdb.RTx,
channel *models.ChannelEdgeInfo, thisNodeKey []byte) (
*models.LightningNode, error) {
// Ensure that the node passed in is actually a member of the channel.
var targetNodeBytes [33]byte
switch {
case bytes.Equal(channel.NodeKey1Bytes[:], thisNodeKey):
targetNodeBytes = channel.NodeKey2Bytes
case bytes.Equal(channel.NodeKey2Bytes[:], thisNodeKey):
targetNodeBytes = channel.NodeKey1Bytes
default:
return nil, fmt.Errorf("node not participating in this channel")
}
var targetNode *models.LightningNode
fetchNodeFunc := func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
node, err := fetchLightningNode(nodes, targetNodeBytes[:])
if err != nil {
return err
}
targetNode = &node
return nil
}
// If the transaction is nil, then we'll need to create a new one,
// otherwise we can use the existing db transaction.
var err error
if tx == nil {
err = kvdb.View(c.db, fetchNodeFunc, func() {
targetNode = nil
})
} else {
err = fetchNodeFunc(tx)
}
return targetNode, err
}
// computeEdgePolicyKeys is a helper function that can be used to compute the
// keys used to index the channel edge policy info for the two nodes of the
// edge. The keys for node 1 and node 2 are returned respectively.
func computeEdgePolicyKeys(info *models.ChannelEdgeInfo) ([]byte, []byte) {
var (
node1Key [33 + 8]byte
node2Key [33 + 8]byte
)
copy(node1Key[:], info.NodeKey1Bytes[:])
copy(node2Key[:], info.NodeKey2Bytes[:])
byteOrder.PutUint64(node1Key[33:], info.ChannelID)
byteOrder.PutUint64(node2Key[33:], info.ChannelID)
return node1Key[:], node2Key[:]
}
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
// the channel identified by the funding outpoint. If the channel can't be
// found, then ErrEdgeNotFound is returned. A struct which houses the general
// information for the channel itself is returned as well as two structs that
// contain the routing policies for the channel in either direction.
func (c *KVStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
var (
edgeInfo *models.ChannelEdgeInfo
policy1 *models.ChannelEdgePolicy
policy2 *models.ChannelEdgePolicy
)
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
// First, grab the node bucket. This will be used to populate
// the Node pointers in each edge read from disk.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
// Next, grab the edge bucket which stores the edges, and also
// the index itself so we can group the directed edges together
// logically.
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
// If the channel's outpoint doesn't exist within the outpoint
// index, then the edge does not exist.
chanIndex := edges.NestedReadBucket(channelPointBucket)
if chanIndex == nil {
return ErrGraphNoEdgesFound
}
var b bytes.Buffer
if err := WriteOutpoint(&b, op); err != nil {
return err
}
chanID := chanIndex.Get(b.Bytes())
if chanID == nil {
return fmt.Errorf("%w: op=%v", ErrEdgeNotFound, op)
}
// If the channel is found to exist, then we'll first retrieve
// the general information for the channel.
edge, err := fetchChanEdgeInfo(edgeIndex, chanID)
if err != nil {
return fmt.Errorf("%w: chanID=%x", err, chanID)
}
edgeInfo = &edge
// Once we have the information about the channel's parameters,
// we'll fetch the routing policies for each of the directed
// edges.
e1, e2, err := fetchChanEdgePolicies(edgeIndex, edges, chanID)
if err != nil {
return fmt.Errorf("failed to find policy: %w", err)
}
policy1 = e1
policy2 = e2
return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
})
if err != nil {
return nil, nil, nil, err
}
return edgeInfo, policy1, policy2, nil
}
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
// channel identified by the channel ID. If the channel can't be found, then
// ErrEdgeNotFound is returned. A struct which houses the general information
// for the channel itself is returned as well as two structs that contain the
// routing policies for the channel in either direction.
//
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
// the ChannelEdgeInfo will only include the public keys of each node.
func (c *KVStore) FetchChannelEdgesByID(chanID uint64) (
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
var (
edgeInfo *models.ChannelEdgeInfo
policy1 *models.ChannelEdgePolicy
policy2 *models.ChannelEdgePolicy
channelID [8]byte
)
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
// First, grab the node bucket. This will be used to populate
// the Node pointers in each edge read from disk.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
// Next, grab the edge bucket which stores the edges, and also
// the index itself so we can group the directed edges together
// logically.
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
byteOrder.PutUint64(channelID[:], chanID)
// Now, attempt to fetch edge.
edge, err := fetchChanEdgeInfo(edgeIndex, channelID[:])
// If it doesn't exist, we'll quickly check our zombie index to
// see if we've previously marked it as such.
if errors.Is(err, ErrEdgeNotFound) {
// If the zombie index doesn't exist, or the edge is not
// marked as a zombie within it, then we'll return the
// original ErrEdgeNotFound error.
zombieIndex := edges.NestedReadBucket(zombieBucket)
if zombieIndex == nil {
return ErrEdgeNotFound
}
isZombie, pubKey1, pubKey2 := isZombieEdge(
zombieIndex, chanID,
)
if !isZombie {
return ErrEdgeNotFound
}
// Otherwise, the edge is marked as a zombie, so we'll
// populate the edge info with the public keys of each
// party, as this is the only information we have about
// it, and return an error signaling this.
edgeInfo = &models.ChannelEdgeInfo{
NodeKey1Bytes: pubKey1,
NodeKey2Bytes: pubKey2,
}
return ErrZombieEdge
}
// Otherwise, we'll just return the error if any.
if err != nil {
return err
}
edgeInfo = &edge
// Then we'll attempt to fetch the accompanying policies of this
// edge.
e1, e2, err := fetchChanEdgePolicies(
edgeIndex, edges, channelID[:],
)
if err != nil {
return err
}
policy1 = e1
policy2 = e2
return nil
}, func() {
edgeInfo = nil
policy1 = nil
policy2 = nil
})
if errors.Is(err, ErrZombieEdge) {
return edgeInfo, nil, nil, err
}
if err != nil {
return nil, nil, nil, err
}
return edgeInfo, policy1, policy2, nil
}
// IsPublicNode is a helper method that determines whether the node with the
// given public key is seen as a public node in the graph from the graph's
// source node's point of view.
func (c *KVStore) IsPublicNode(pubKey [33]byte) (bool, error) {
var nodeIsPublic bool
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNodesNotFound
}
ourPubKey := nodes.Get(sourceKey)
if ourPubKey == nil {
return ErrSourceNodeNotSet
}
node, err := fetchLightningNode(nodes, pubKey[:])
if err != nil {
return err
}
nodeIsPublic, err = c.isPublic(tx, node.PubKeyBytes, ourPubKey)
return err
}, func() {
nodeIsPublic = false
})
if err != nil {
return false, err
}
return nodeIsPublic, nil
}
// genMultiSigP2WSH generates the p2wsh'd multisig script for 2 of 2 pubkeys.
func genMultiSigP2WSH(aPub, bPub []byte) ([]byte, error) {
witnessScript, err := input.GenMultiSigScript(aPub, bPub)
if err != nil {
return nil, err
}
// With the witness script generated, we'll now turn it into a p2wsh
// script:
// * OP_0 <sha256(script)>
bldr := txscript.NewScriptBuilder(
txscript.WithScriptAllocSize(input.P2WSHSize),
)
bldr.AddOp(txscript.OP_0)
scriptHash := sha256.Sum256(witnessScript)
bldr.AddData(scriptHash[:])
return bldr.Script()
}
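// The resulting pkScript is the standard version-0 witness program
// OP_0 <sha256(witnessScript)>, 34 bytes in total. A usage sketch, assuming
// aPub and bPub are valid 33-byte compressed pubkeys:
//
//	pkScript, err := genMultiSigP2WSH(aPub, bPub)
//	// pkScript[0] == txscript.OP_0
//	// pkScript[1] == 32 (the push of the 32-byte script hash)
//	// len(pkScript) == 34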
// EdgePoint couples the outpoint of a channel with the funding script that it
// creates. The FilteredChainView will use this to watch for spends of this
// edge point on chain. We require both of these values as depending on the
// concrete implementation, either the pkScript, or the out point will be used.
type EdgePoint struct {
// FundingPkScript is the p2wsh multi-sig script of the target channel.
FundingPkScript []byte
// OutPoint is the outpoint of the target channel.
OutPoint wire.OutPoint
}
// String returns a human readable version of the target EdgePoint. We return
// the outpoint directly as it is enough to uniquely identify the edge point.
func (e *EdgePoint) String() string {
return e.OutPoint.String()
}
// ChannelView returns the verifiable edge information for each active channel
// within the known channel graph. The set of UTXO's (along with their scripts)
// returned are the ones that need to be watched on chain to detect channel
// closes on the resident blockchain.
func (c *KVStore) ChannelView() ([]EdgePoint, error) {
var edgePoints []EdgePoint
if err := kvdb.View(c.db, func(tx kvdb.RTx) error {
// We're going to iterate over the entire channel index, so
// we'll need to fetch the edgeBucket to get to the index as
// it's a sub-bucket.
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
chanIndex := edges.NestedReadBucket(channelPointBucket)
if chanIndex == nil {
return ErrGraphNoEdgesFound
}
edgeIndex := edges.NestedReadBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrGraphNoEdgesFound
}
// Once we have the proper bucket, we'll range over each key
// (which is the channel point for the channel) and decode it,
// accumulating each entry.
return chanIndex.ForEach(
func(chanPointBytes, chanID []byte) error {
chanPointReader := bytes.NewReader(
chanPointBytes,
)
var chanPoint wire.OutPoint
err := ReadOutpoint(chanPointReader, &chanPoint)
if err != nil {
return err
}
edgeInfo, err := fetchChanEdgeInfo(
edgeIndex, chanID,
)
if err != nil {
return err
}
pkScript, err := genMultiSigP2WSH(
edgeInfo.BitcoinKey1Bytes[:],
edgeInfo.BitcoinKey2Bytes[:],
)
if err != nil {
return err
}
edgePoints = append(edgePoints, EdgePoint{
FundingPkScript: pkScript,
OutPoint: chanPoint,
})
return nil
},
)
}, func() {
edgePoints = nil
}); err != nil {
return nil, err
}
return edgePoints, nil
}
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
// zombie. This method is used on an ad-hoc basis, when channels need to be
// marked as zombies outside the normal pruning cycle.
func (c *KVStore) MarkEdgeZombie(chanID uint64,
pubKey1, pubKey2 [33]byte) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
err := kvdb.Batch(c.db, func(tx kvdb.RwTx) error {
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket)
if err != nil {
return fmt.Errorf("unable to create zombie "+
"bucket: %w", err)
}
return markEdgeZombie(zombieIndex, chanID, pubKey1, pubKey2)
})
if err != nil {
return err
}
c.rejectCache.remove(chanID)
c.chanCache.remove(chanID)
return nil
}
// markEdgeZombie marks an edge as a zombie within our zombie index. The public
// keys should represent the node public keys of the two parties involved in the
// edge.
func markEdgeZombie(zombieIndex kvdb.RwBucket, chanID uint64, pubKey1,
pubKey2 [33]byte) error {
var k [8]byte
byteOrder.PutUint64(k[:], chanID)
var v [66]byte
copy(v[:33], pubKey1[:])
copy(v[33:], pubKey2[:])
return zombieIndex.Put(k[:], v[:])
}
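// NOTE: Zombie index entries are therefore keyed by the 8-byte big-endian
// channel ID with a 66-byte value of pubKey1 || pubKey2. A usage sketch
// (error handling elided):
//
//	_ = store.MarkEdgeZombie(chanID, pub1, pub2)
//	isZombie, k1, k2 := store.IsZombieEdge(chanID)
//	// isZombie == true, k1 == pub1, k2 == pub2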
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
func (c *KVStore) MarkEdgeLive(chanID uint64) error {
c.cacheMu.Lock()
defer c.cacheMu.Unlock()
return c.markEdgeLiveUnsafe(nil, chanID)
}
// markEdgeLiveUnsafe clears an edge from the zombie index. This method can be
// called with an existing kvdb.RwTx or the argument can be set to nil in which
// case a new transaction will be created.
//
// NOTE: this method MUST only be called if the cacheMu has already been
// acquired.
func (c *KVStore) markEdgeLiveUnsafe(tx kvdb.RwTx, chanID uint64) error {
dbFn := func(tx kvdb.RwTx) error {
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
zombieIndex := edges.NestedReadWriteBucket(zombieBucket)
if zombieIndex == nil {
return nil
}
var k [8]byte
byteOrder.PutUint64(k[:], chanID)
if len(zombieIndex.Get(k[:])) == 0 {
return ErrZombieEdgeNotFound
}
return zombieIndex.Delete(k[:])
}
// If the transaction is nil, we'll create a new one. Otherwise, we use
// the existing transaction.
var err error
if tx == nil {
err = kvdb.Update(c.db, dbFn, func() {})
} else {
err = dbFn(tx)
}
if err != nil {
return err
}
c.rejectCache.remove(chanID)
c.chanCache.remove(chanID)
return nil
}
// IsZombieEdge returns whether the edge is considered zombie. If it is a
// zombie, then the two node public keys corresponding to this edge are also
// returned.
func (c *KVStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
var (
isZombie bool
pubKey1, pubKey2 [33]byte
)
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return ErrGraphNoEdgesFound
}
zombieIndex := edges.NestedReadBucket(zombieBucket)
if zombieIndex == nil {
return nil
}
isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID)
return nil
}, func() {
isZombie = false
pubKey1 = [33]byte{}
pubKey2 = [33]byte{}
})
if err != nil {
return false, [33]byte{}, [33]byte{}
}
return isZombie, pubKey1, pubKey2
}
// isZombieEdge returns whether an entry exists for the given channel in the
// zombie index. If an entry exists, then the two node public keys corresponding
// to this edge are also returned.
func isZombieEdge(zombieIndex kvdb.RBucket,
chanID uint64) (bool, [33]byte, [33]byte) {
var k [8]byte
byteOrder.PutUint64(k[:], chanID)
v := zombieIndex.Get(k[:])
if v == nil {
return false, [33]byte{}, [33]byte{}
}
var pubKey1, pubKey2 [33]byte
copy(pubKey1[:], v[:33])
copy(pubKey2[:], v[33:])
return true, pubKey1, pubKey2
}
// NumZombies returns the current number of zombie channels in the graph.
func (c *KVStore) NumZombies() (uint64, error) {
var numZombies uint64
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket)
if edges == nil {
return nil
}
zombieIndex := edges.NestedReadBucket(zombieBucket)
if zombieIndex == nil {
return nil
}
return zombieIndex.ForEach(func(_, _ []byte) error {
numZombies++
return nil
})
}, func() {
numZombies = 0
})
if err != nil {
return 0, err
}
return numZombies, nil
}
// PutClosedScid stores a SCID for a closed channel in the database. This is so
// that we can ignore channel announcements that we know to be closed without
// having to validate them and fetch a block.
func (c *KVStore) PutClosedScid(scid lnwire.ShortChannelID) error {
return kvdb.Update(c.db, func(tx kvdb.RwTx) error {
closedScids, err := tx.CreateTopLevelBucket(closedScidBucket)
if err != nil {
return err
}
var k [8]byte
byteOrder.PutUint64(k[:], scid.ToUint64())
return closedScids.Put(k[:], []byte{})
}, func() {})
}
// IsClosedScid checks whether a channel identified by the passed in scid is
// closed. This helps avoid having to perform expensive validation checks.
// TODO: Add an LRU cache to cut down on disk reads.
func (c *KVStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
var isClosed bool
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
closedScids := tx.ReadBucket(closedScidBucket)
if closedScids == nil {
return ErrClosedScidsNotFound
}
var k [8]byte
byteOrder.PutUint64(k[:], scid.ToUint64())
if closedScids.Get(k[:]) != nil {
isClosed = true
return nil
}
return nil
}, func() {
isClosed = false
})
if err != nil {
return false, err
}
return isClosed, nil
}
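// A usage sketch pairing IsClosedScid with PutClosedScid (error handling
// elided):
//
//	_ = store.PutClosedScid(scid)
//	closed, _ := store.IsClosedScid(scid)
//	// closed == true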
// GraphSession will provide the callback with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph.
func (c *KVStore) GraphSession(cb func(graph NodeTraverser) error) error {
return c.db.View(func(tx walletdb.ReadTx) error {
return cb(&nodeTraverserSession{
db: c,
tx: tx,
})
}, func() {})
}
// nodeTraverserSession implements the NodeTraverser interface but with a
// backing read only transaction for a consistent view of the graph.
type nodeTraverserSession struct {
tx kvdb.RTx
db *KVStore
}
// ForEachNodeDirectedChannel calls the callback for every channel of the given
// node.
//
// NOTE: Part of the NodeTraverser interface.
func (c *nodeTraverserSession) ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *DirectedChannel) error) error {
return c.db.forEachNodeDirectedChannel(c.tx, nodePub, cb)
}
// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the NodeTraverser interface.
func (c *nodeTraverserSession) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
return c.db.fetchNodeFeatures(c.tx, nodePub)
}
func putLightningNode(nodeBucket, aliasBucket, updateIndex kvdb.RwBucket,
node *models.LightningNode) error {
var (
scratch [16]byte
b bytes.Buffer
)
pub, err := node.PubKey()
if err != nil {
return err
}
nodePub := pub.SerializeCompressed()
// If the node has the update time set, write it, else write 0.
updateUnix := uint64(0)
if node.LastUpdate.Unix() > 0 {
updateUnix = uint64(node.LastUpdate.Unix())
}
byteOrder.PutUint64(scratch[:8], updateUnix)
if _, err := b.Write(scratch[:8]); err != nil {
return err
}
if _, err := b.Write(nodePub); err != nil {
return err
}
// If we got a node announcement for this node, we will have the rest
// of the data available. If not, we don't have more data to write.
if !node.HaveNodeAnnouncement {
// Write HaveNodeAnnouncement=0.
byteOrder.PutUint16(scratch[:2], 0)
if _, err := b.Write(scratch[:2]); err != nil {
return err
}
return nodeBucket.Put(nodePub, b.Bytes())
}
// Write HaveNodeAnnouncement=1.
byteOrder.PutUint16(scratch[:2], 1)
if _, err := b.Write(scratch[:2]); err != nil {
return err
}
if err := binary.Write(&b, byteOrder, node.Color.R); err != nil {
return err
}
if err := binary.Write(&b, byteOrder, node.Color.G); err != nil {
return err
}
if err := binary.Write(&b, byteOrder, node.Color.B); err != nil {
return err
}
if err := wire.WriteVarString(&b, 0, node.Alias); err != nil {
return err
}
if err := node.Features.Encode(&b); err != nil {
return err
}
numAddresses := uint16(len(node.Addresses))
byteOrder.PutUint16(scratch[:2], numAddresses)
if _, err := b.Write(scratch[:2]); err != nil {
return err
}
for _, address := range node.Addresses {
if err := SerializeAddr(&b, address); err != nil {
return err
}
}
sigLen := len(node.AuthSigBytes)
if sigLen > 80 {
return fmt.Errorf("max sig len allowed is 80, had %v",
sigLen)
}
err = wire.WriteVarBytes(&b, 0, node.AuthSigBytes)
if err != nil {
return err
}
if len(node.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes {
return ErrTooManyExtraOpaqueBytes(len(node.ExtraOpaqueData))
}
err = wire.WriteVarBytes(&b, 0, node.ExtraOpaqueData)
if err != nil {
return err
}
if err := aliasBucket.Put(nodePub, []byte(node.Alias)); err != nil {
return err
}
// With the alias bucket updated, we'll now update the index that
// tracks the time series of node updates.
var indexKey [8 + 33]byte
byteOrder.PutUint64(indexKey[:8], updateUnix)
copy(indexKey[8:], nodePub)
// If there was already an old index entry for this node, then we'll
// delete the old one before we write the new entry.
if nodeBytes := nodeBucket.Get(nodePub); nodeBytes != nil {
// Extract out the old update time so we can reconstruct the
// prior index key to delete it from the index.
oldUpdateTime := nodeBytes[:8]
var oldIndexKey [8 + 33]byte
copy(oldIndexKey[:8], oldUpdateTime)
copy(oldIndexKey[8:], nodePub)
if err := updateIndex.Delete(oldIndexKey[:]); err != nil {
return err
}
}
if err := updateIndex.Put(indexKey[:], nil); err != nil {
return err
}
return nodeBucket.Put(nodePub, b.Bytes())
}
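// NOTE: For reference, the node record written above has the layout:
//
//	lastUpdate (8 bytes) || pubKey (33 bytes) || hasNodeAnn (2 bytes)
//	[ || color (3 bytes) || alias (varstring) || features ||
//	numAddresses (2 bytes) || addresses || sig (varbytes) ||
//	extraOpaqueData (varbytes) ]
//
// where the bracketed fields are only present when hasNodeAnn == 1. The
// update index entry written alongside it is keyed by
// lastUpdate (8 bytes) || pubKey (33 bytes) with an empty value.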
func fetchLightningNode(nodeBucket kvdb.RBucket,
nodePub []byte) (models.LightningNode, error) {
nodeBytes := nodeBucket.Get(nodePub)
if nodeBytes == nil {
return models.LightningNode{}, ErrGraphNodeNotFound
}
nodeReader := bytes.NewReader(nodeBytes)
return deserializeLightningNode(nodeReader)
}
func deserializeLightningNodeCacheable(r io.Reader) (route.Vertex,
*lnwire.FeatureVector, error) {
var (
pubKey route.Vertex
features = lnwire.EmptyFeatureVector()
nodeScratch [8]byte
)
// Skip ahead:
// - LastUpdate (8 bytes)
if _, err := r.Read(nodeScratch[:]); err != nil {
return pubKey, nil, err
}
if _, err := io.ReadFull(r, pubKey[:]); err != nil {
return pubKey, nil, err
}
// Read the node announcement flag.
if _, err := r.Read(nodeScratch[:2]); err != nil {
return pubKey, nil, err
}
hasNodeAnn := byteOrder.Uint16(nodeScratch[:2])
// The rest of the data is optional, and will only be there if we got a
// node announcement for this node.
if hasNodeAnn == 0 {
return pubKey, features, nil
}
// We did get a node announcement for this node, so we'll have the rest
// of the data available.
var rgb uint8
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return pubKey, nil, err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return pubKey, nil, err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return pubKey, nil, err
}
if _, err := wire.ReadVarString(r, 0); err != nil {
return pubKey, nil, err
}
if err := features.Decode(r); err != nil {
return pubKey, nil, err
}
return pubKey, features, nil
}
func deserializeLightningNode(r io.Reader) (models.LightningNode, error) {
var (
node models.LightningNode
scratch [8]byte
err error
)
// Always populate a feature vector, even if we don't have a node
// announcement and short circuit below.
node.Features = lnwire.EmptyFeatureVector()
if _, err := r.Read(scratch[:]); err != nil {
return models.LightningNode{}, err
}
unix := int64(byteOrder.Uint64(scratch[:]))
node.LastUpdate = time.Unix(unix, 0)
if _, err := io.ReadFull(r, node.PubKeyBytes[:]); err != nil {
return models.LightningNode{}, err
}
if _, err := r.Read(scratch[:2]); err != nil {
return models.LightningNode{}, err
}
hasNodeAnn := byteOrder.Uint16(scratch[:2])
if hasNodeAnn == 1 {
node.HaveNodeAnnouncement = true
} else {
node.HaveNodeAnnouncement = false
}
// The rest of the data is optional, and will only be there if we got a
// node announcement for this node.
if !node.HaveNodeAnnouncement {
return node, nil
}
// We did get a node announcement for this node, so we'll have the rest
// of the data available.
if err := binary.Read(r, byteOrder, &node.Color.R); err != nil {
return models.LightningNode{}, err
}
if err := binary.Read(r, byteOrder, &node.Color.G); err != nil {
return models.LightningNode{}, err
}
if err := binary.Read(r, byteOrder, &node.Color.B); err != nil {
return models.LightningNode{}, err
}
node.Alias, err = wire.ReadVarString(r, 0)
if err != nil {
return models.LightningNode{}, err
}
err = node.Features.Decode(r)
if err != nil {
return models.LightningNode{}, err
}
if _, err := r.Read(scratch[:2]); err != nil {
return models.LightningNode{}, err
}
numAddresses := int(byteOrder.Uint16(scratch[:2]))
var addresses []net.Addr
for i := 0; i < numAddresses; i++ {
address, err := DeserializeAddr(r)
if err != nil {
return models.LightningNode{}, err
}
addresses = append(addresses, address)
}
node.Addresses = addresses
node.AuthSigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig")
if err != nil {
return models.LightningNode{}, err
}
// We'll try and see if there are any opaque bytes left. If not, we'll
// ignore the EOF error and return the node as is.
node.ExtraOpaqueData, err = wire.ReadVarBytes(
r, 0, MaxAllowedExtraOpaqueBytes, "blob",
)
switch {
case errors.Is(err, io.ErrUnexpectedEOF):
case errors.Is(err, io.EOF):
case err != nil:
return models.LightningNode{}, err
}
return node, nil
}
func putChanEdgeInfo(edgeIndex kvdb.RwBucket,
edgeInfo *models.ChannelEdgeInfo, chanID [8]byte) error {
var b bytes.Buffer
if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil {
return err
}
if _, err := b.Write(edgeInfo.NodeKey2Bytes[:]); err != nil {
return err
}
if _, err := b.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil {
return err
}
if _, err := b.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil {
return err
}
if err := wire.WriteVarBytes(&b, 0, edgeInfo.Features); err != nil {
return err
}
authProof := edgeInfo.AuthProof
var nodeSig1, nodeSig2, bitcoinSig1, bitcoinSig2 []byte
if authProof != nil {
nodeSig1 = authProof.NodeSig1Bytes
nodeSig2 = authProof.NodeSig2Bytes
bitcoinSig1 = authProof.BitcoinSig1Bytes
bitcoinSig2 = authProof.BitcoinSig2Bytes
}
if err := wire.WriteVarBytes(&b, 0, nodeSig1); err != nil {
return err
}
if err := wire.WriteVarBytes(&b, 0, nodeSig2); err != nil {
return err
}
if err := wire.WriteVarBytes(&b, 0, bitcoinSig1); err != nil {
return err
}
if err := wire.WriteVarBytes(&b, 0, bitcoinSig2); err != nil {
return err
}
if err := WriteOutpoint(&b, &edgeInfo.ChannelPoint); err != nil {
return err
}
err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity))
if err != nil {
return err
}
if _, err := b.Write(chanID[:]); err != nil {
return err
}
if _, err := b.Write(edgeInfo.ChainHash[:]); err != nil {
return err
}
if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes {
return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData))
}
err = wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData)
if err != nil {
return err
}
return edgeIndex.Put(chanID[:], b.Bytes())
}
func fetchChanEdgeInfo(edgeIndex kvdb.RBucket,
chanID []byte) (models.ChannelEdgeInfo, error) {
edgeInfoBytes := edgeIndex.Get(chanID)
if edgeInfoBytes == nil {
return models.ChannelEdgeInfo{}, ErrEdgeNotFound
}
edgeInfoReader := bytes.NewReader(edgeInfoBytes)
return deserializeChanEdgeInfo(edgeInfoReader)
}
func deserializeChanEdgeInfo(r io.Reader) (models.ChannelEdgeInfo, error) {
var (
err error
edgeInfo models.ChannelEdgeInfo
)
if _, err := io.ReadFull(r, edgeInfo.NodeKey1Bytes[:]); err != nil {
return models.ChannelEdgeInfo{}, err
}
if _, err := io.ReadFull(r, edgeInfo.NodeKey2Bytes[:]); err != nil {
return models.ChannelEdgeInfo{}, err
}
if _, err := io.ReadFull(r, edgeInfo.BitcoinKey1Bytes[:]); err != nil {
return models.ChannelEdgeInfo{}, err
}
if _, err := io.ReadFull(r, edgeInfo.BitcoinKey2Bytes[:]); err != nil {
return models.ChannelEdgeInfo{}, err
}
edgeInfo.Features, err = wire.ReadVarBytes(r, 0, 900, "features")
if err != nil {
return models.ChannelEdgeInfo{}, err
}
proof := &models.ChannelAuthProof{}
proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs")
if err != nil {
return models.ChannelEdgeInfo{}, err
}
proof.NodeSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs")
if err != nil {
return models.ChannelEdgeInfo{}, err
}
proof.BitcoinSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs")
if err != nil {
return models.ChannelEdgeInfo{}, err
}
proof.BitcoinSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs")
if err != nil {
return models.ChannelEdgeInfo{}, err
}
if !proof.IsEmpty() {
edgeInfo.AuthProof = proof
}
edgeInfo.ChannelPoint = wire.OutPoint{}
if err := ReadOutpoint(r, &edgeInfo.ChannelPoint); err != nil {
return models.ChannelEdgeInfo{}, err
}
if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil {
return models.ChannelEdgeInfo{}, err
}
if err := binary.Read(r, byteOrder, &edgeInfo.ChannelID); err != nil {
return models.ChannelEdgeInfo{}, err
}
if _, err := io.ReadFull(r, edgeInfo.ChainHash[:]); err != nil {
return models.ChannelEdgeInfo{}, err
}
// We'll try and see if there are any opaque bytes left, if not, then
// we'll ignore the EOF error and return the edge as is.
edgeInfo.ExtraOpaqueData, err = wire.ReadVarBytes(
r, 0, MaxAllowedExtraOpaqueBytes, "blob",
)
switch {
case errors.Is(err, io.ErrUnexpectedEOF):
case errors.Is(err, io.EOF):
case err != nil:
return models.ChannelEdgeInfo{}, err
}
return edgeInfo, nil
}
func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy,
from, to []byte) error {
var edgeKey [33 + 8]byte
copy(edgeKey[:], from)
byteOrder.PutUint64(edgeKey[33:], edge.ChannelID)
var b bytes.Buffer
if err := serializeChanEdgePolicy(&b, edge, to); err != nil {
return err
}
// Before we write out the new edge, we'll create a new entry in the
// update index in order to keep it fresh.
updateUnix := uint64(edge.LastUpdate.Unix())
var indexKey [8 + 8]byte
byteOrder.PutUint64(indexKey[:8], updateUnix)
byteOrder.PutUint64(indexKey[8:], edge.ChannelID)
updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket)
if err != nil {
return err
}
// If there was already an entry for this edge, then we'll need to
// delete the old one to ensure we don't leave around any after-images.
	// An unknown policy value does not have an update time recorded, so
// it also does not need to be removed.
if edgeBytes := edges.Get(edgeKey[:]); edgeBytes != nil &&
!bytes.Equal(edgeBytes, unknownPolicy) {
// In order to delete the old entry, we'll need to obtain the
// *prior* update time in order to delete it. To do this, we'll
// need to deserialize the existing policy within the database
// (now outdated by the new one), and delete its corresponding
// entry within the update index. We'll ignore any
// ErrEdgePolicyOptionalFieldNotFound error, as we only need
// the channel ID and update time to delete the entry.
// TODO(halseth): get rid of these invalid policies in a
// migration.
oldEdgePolicy, err := deserializeChanEdgePolicy(
bytes.NewReader(edgeBytes),
)
if err != nil &&
!errors.Is(err, ErrEdgePolicyOptionalFieldNotFound) {
return err
}
oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix())
var oldIndexKey [8 + 8]byte
byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime)
byteOrder.PutUint64(oldIndexKey[8:], edge.ChannelID)
if err := updateIndex.Delete(oldIndexKey[:]); err != nil {
return err
}
}
if err := updateIndex.Put(indexKey[:], nil); err != nil {
return err
}
err = updateEdgePolicyDisabledIndex(
edges, edge.ChannelID,
edge.ChannelFlags&lnwire.ChanUpdateDirection > 0,
edge.IsDisabled(),
)
if err != nil {
return err
}
return edges.Put(edgeKey[:], b.Bytes())
}
// updateEdgePolicyDisabledIndex is used to update the disabledEdgePolicyIndex
// bucket by either adding a new disabled ChannelEdgePolicy or removing an
// existing one.
// The direction represents the direction of the edge and disabled is used for
// deciding whether to remove or add an entry to the bucket.
// In general a channel is disabled if two entries for the same chanID exist
// in this bucket.
// Maintaining the bucket this way allows fast retrieval of disabled channels,
// for example when pruning is needed.
func updateEdgePolicyDisabledIndex(edges kvdb.RwBucket, chanID uint64,
direction bool, disabled bool) error {
var disabledEdgeKey [8 + 1]byte
byteOrder.PutUint64(disabledEdgeKey[0:], chanID)
if direction {
disabledEdgeKey[8] = 1
}
disabledEdgePolicyIndex, err := edges.CreateBucketIfNotExists(
disabledEdgePolicyBucket,
)
if err != nil {
return err
}
if disabled {
return disabledEdgePolicyIndex.Put(disabledEdgeKey[:], []byte{})
}
return disabledEdgePolicyIndex.Delete(disabledEdgeKey[:])
}
// putChanEdgePolicyUnknown marks the edge policy as unknown
// in the edges bucket.
func putChanEdgePolicyUnknown(edges kvdb.RwBucket, channelID uint64,
from []byte) error {
var edgeKey [33 + 8]byte
copy(edgeKey[:], from)
byteOrder.PutUint64(edgeKey[33:], channelID)
if edges.Get(edgeKey[:]) != nil {
return fmt.Errorf("cannot write unknown policy for channel %v "+
" when there is already a policy present", channelID)
}
return edges.Put(edgeKey[:], unknownPolicy)
}
func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte,
nodePub []byte) (*models.ChannelEdgePolicy, error) {
var edgeKey [33 + 8]byte
copy(edgeKey[:], nodePub)
copy(edgeKey[33:], chanID)
edgeBytes := edges.Get(edgeKey[:])
if edgeBytes == nil {
return nil, ErrEdgeNotFound
}
// No need to deserialize unknown policy.
if bytes.Equal(edgeBytes, unknownPolicy) {
return nil, nil
}
edgeReader := bytes.NewReader(edgeBytes)
ep, err := deserializeChanEdgePolicy(edgeReader)
switch {
// If the db policy was missing an expected optional field, we return
// nil as if the policy was unknown.
case errors.Is(err, ErrEdgePolicyOptionalFieldNotFound):
return nil, nil
case err != nil:
return nil, err
}
return ep, nil
}
func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket,
chanID []byte) (*models.ChannelEdgePolicy, *models.ChannelEdgePolicy,
error) {
edgeInfo := edgeIndex.Get(chanID)
if edgeInfo == nil {
return nil, nil, fmt.Errorf("%w: chanID=%x", ErrEdgeNotFound,
chanID)
}
// The first node is contained within the first half of the edge
// information. We only propagate the error here and below if it's
// something other than edge non-existence.
node1Pub := edgeInfo[:33]
edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub)
if err != nil {
return nil, nil, fmt.Errorf("%w: node1Pub=%x", ErrEdgeNotFound,
node1Pub)
}
// Similarly, the second node is contained within the latter
// half of the edge information.
node2Pub := edgeInfo[33:66]
edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub)
if err != nil {
return nil, nil, fmt.Errorf("%w: node2Pub=%x", ErrEdgeNotFound,
node2Pub)
}
return edge1, edge2, nil
}
func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy,
to []byte) error {
err := wire.WriteVarBytes(w, 0, edge.SigBytes)
if err != nil {
return err
}
if err := binary.Write(w, byteOrder, edge.ChannelID); err != nil {
return err
}
var scratch [8]byte
updateUnix := uint64(edge.LastUpdate.Unix())
byteOrder.PutUint64(scratch[:], updateUnix)
if _, err := w.Write(scratch[:]); err != nil {
return err
}
if err := binary.Write(w, byteOrder, edge.MessageFlags); err != nil {
return err
}
if err := binary.Write(w, byteOrder, edge.ChannelFlags); err != nil {
return err
}
if err := binary.Write(w, byteOrder, edge.TimeLockDelta); err != nil {
return err
}
if err := binary.Write(w, byteOrder, uint64(edge.MinHTLC)); err != nil {
return err
}
err = binary.Write(w, byteOrder, uint64(edge.FeeBaseMSat))
if err != nil {
return err
}
err = binary.Write(
w, byteOrder, uint64(edge.FeeProportionalMillionths),
)
if err != nil {
return err
}
if _, err := w.Write(to); err != nil {
return err
}
// If the max_htlc field is present, we write it. To be compatible with
	// older versions that weren't aware of this field, we write it as part
// of the opaque data.
// TODO(halseth): clean up when moving to TLV.
var opaqueBuf bytes.Buffer
if edge.MessageFlags.HasMaxHtlc() {
err := binary.Write(&opaqueBuf, byteOrder, uint64(edge.MaxHTLC))
if err != nil {
return err
}
}
if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes {
return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData))
}
if _, err := opaqueBuf.Write(edge.ExtraOpaqueData); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, opaqueBuf.Bytes()); err != nil {
return err
}
return nil
}
func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy, error) {
// Deserialize the policy. Note that in case an optional field is not
// found, both an error and a populated policy object are returned.
edge, deserializeErr := deserializeChanEdgePolicyRaw(r)
if deserializeErr != nil &&
!errors.Is(deserializeErr, ErrEdgePolicyOptionalFieldNotFound) {
return nil, deserializeErr
}
return edge, deserializeErr
}
func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy,
error) {
edge := &models.ChannelEdgePolicy{}
var err error
edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig")
if err != nil {
return nil, err
}
if err := binary.Read(r, byteOrder, &edge.ChannelID); err != nil {
return nil, err
}
var scratch [8]byte
if _, err := r.Read(scratch[:]); err != nil {
return nil, err
}
unix := int64(byteOrder.Uint64(scratch[:]))
edge.LastUpdate = time.Unix(unix, 0)
if err := binary.Read(r, byteOrder, &edge.MessageFlags); err != nil {
return nil, err
}
if err := binary.Read(r, byteOrder, &edge.ChannelFlags); err != nil {
return nil, err
}
if err := binary.Read(r, byteOrder, &edge.TimeLockDelta); err != nil {
return nil, err
}
var n uint64
if err := binary.Read(r, byteOrder, &n); err != nil {
return nil, err
}
edge.MinHTLC = lnwire.MilliSatoshi(n)
if err := binary.Read(r, byteOrder, &n); err != nil {
return nil, err
}
edge.FeeBaseMSat = lnwire.MilliSatoshi(n)
if err := binary.Read(r, byteOrder, &n); err != nil {
return nil, err
}
edge.FeeProportionalMillionths = lnwire.MilliSatoshi(n)
if _, err := r.Read(edge.ToNode[:]); err != nil {
return nil, err
}
// We'll try and see if there are any opaque bytes left, if not, then
// we'll ignore the EOF error and return the edge as is.
edge.ExtraOpaqueData, err = wire.ReadVarBytes(
r, 0, MaxAllowedExtraOpaqueBytes, "blob",
)
switch {
case errors.Is(err, io.ErrUnexpectedEOF):
case errors.Is(err, io.EOF):
case err != nil:
return nil, err
}
// See if optional fields are present.
if edge.MessageFlags.HasMaxHtlc() {
// The max_htlc field should be at the beginning of the opaque
// bytes.
opq := edge.ExtraOpaqueData
// If the max_htlc field is not present, it might be old data
// stored before this field was validated. We'll return the
// edge along with an error.
if len(opq) < 8 {
return edge, ErrEdgePolicyOptionalFieldNotFound
}
maxHtlc := byteOrder.Uint64(opq[:8])
edge.MaxHTLC = lnwire.MilliSatoshi(maxHtlc)
// Exclude the parsed field from the rest of the opaque data.
edge.ExtraOpaqueData = opq[8:]
}
return edge, nil
}
// chanGraphNodeTx is an implementation of the NodeRTx interface backed by the
// KVStore and a kvdb.RTx.
type chanGraphNodeTx struct {
tx kvdb.RTx
db *KVStore
node *models.LightningNode
}
// A compile-time constraint to ensure chanGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*chanGraphNodeTx)(nil)
func newChanGraphNodeTx(tx kvdb.RTx, db *KVStore,
node *models.LightningNode) *chanGraphNodeTx {
return &chanGraphNodeTx{
tx: tx,
db: db,
node: node,
}
}
// Node returns the raw information of the node.
//
// NOTE: This is a part of the NodeRTx interface.
func (c *chanGraphNodeTx) Node() *models.LightningNode {
return c.node
}
// FetchNode fetches the node with the given pub key under the same transaction
// used to fetch the current node. The returned node is also a NodeRTx and any
// operations on that NodeRTx will also be done under the same transaction.
//
// NOTE: This is a part of the NodeRTx interface.
func (c *chanGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
node, err := c.db.FetchLightningNodeTx(c.tx, nodePub)
if err != nil {
return nil, err
}
return newChanGraphNodeTx(c.tx, c.db, node), nil
}
// ForEachChannel can be used to iterate over the node's channels under
// the same transaction used to fetch the node.
//
// NOTE: This is a part of the NodeRTx interface.
func (c *chanGraphNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
return c.db.ForEachNodeChannelTx(c.tx, c.node.PubKeyBytes,
func(_ kvdb.RTx, info *models.ChannelEdgeInfo, policy1,
policy2 *models.ChannelEdgePolicy) error {
return f(info, policy1, policy2)
},
)
}
// MakeTestGraph creates a new instance of the ChannelGraph for testing
// purposes.
func MakeTestGraph(t testing.TB, modifiers ...KVStoreOptionModifier) (
*ChannelGraph, error) {
opts := DefaultOptions()
for _, modifier := range modifiers {
modifier(opts)
}
// Next, create KVStore for the first time.
backend, backendCleanup, err := kvdb.GetTestBackend(t.TempDir(), "cgr")
if err != nil {
backendCleanup()
return nil, err
}
graph, err := NewChannelGraph(&Config{
KVDB: backend,
KVStoreOpts: modifiers,
})
if err != nil {
backendCleanup()
return nil, err
}
require.NoError(t, graph.Start())
t.Cleanup(func() {
_ = backend.Close()
backendCleanup()
require.NoError(t, graph.Stop())
})
return graph, nil
}
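// The snippet below is an illustrative sketch, not part of the original file:
// it shows how a test might use MakeTestGraph together with a functional
// option. The test name is hypothetical.
//
//	func TestGraphSmoke(t *testing.T) {
//		graph, err := MakeTestGraph(t, WithChannelCacheSize(100))
//		require.NoError(t, err)
//		require.NotNil(t, graph)
//	}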
package graphdb
import (
"github.com/btcsuite/btclog"
"github.com/lightningnetwork/lnd/build"
)
// Subsystem defines the logging code for this subsystem.
const Subsystem = "GRDB"
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
func init() {
UseLogger(build.NewSubLogger(Subsystem, nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
package models
import (
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
const (
// feeRateParts is the total number of parts used to express fee rates.
feeRateParts = 1e6
)
// CachedEdgePolicy is a struct that only caches the information of a
// ChannelEdgePolicy that we actually use for pathfinding and therefore need to
// store in the cache.
type CachedEdgePolicy struct {
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel.
ChannelID uint64
// MessageFlags is a bitfield which indicates the presence of optional
// fields (like max_htlc) in the policy.
MessageFlags lnwire.ChanUpdateMsgFlags
// ChannelFlags is a bitfield which signals the capabilities of the
// channel as well as the directed edge this update applies to.
ChannelFlags lnwire.ChanUpdateChanFlags
// TimeLockDelta is the number of blocks this node will subtract from
// the expiry of an incoming HTLC. This value expresses the time buffer
	// the node would like to apply to HTLC exchanges.
TimeLockDelta uint16
// MinHTLC is the smallest value HTLC this node will forward, expressed
// in millisatoshi.
MinHTLC lnwire.MilliSatoshi
// MaxHTLC is the largest value HTLC this node will forward, expressed
// in millisatoshi.
MaxHTLC lnwire.MilliSatoshi
// FeeBaseMSat is the base HTLC fee that will be charged for forwarding
// ANY HTLC, expressed in mSAT's.
FeeBaseMSat lnwire.MilliSatoshi
// FeeProportionalMillionths is the rate that the node will charge for
// HTLCs for each millionth of a satoshi forwarded.
FeeProportionalMillionths lnwire.MilliSatoshi
// ToNodePubKey is a function that returns the to node of a policy.
// Since we only ever store the inbound policy, this is always the node
	// that we query the channels for in ForEachChannel(). Therefore, we
	// can save a lot of space by not storing this information in memory
	// and instead just setting this function when we copy the policy from
	// the cache in ForEachChannel().
ToNodePubKey func() route.Vertex
// ToNodeFeatures are the to node's features. They are never set while
// the edge is in the cache, only on the copy that is returned in
// ForEachChannel().
ToNodeFeatures *lnwire.FeatureVector
}
// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over
// the passed active payment channel. This value is currently computed as
// specified in BOLT07, but will likely change in the near future.
func (c *CachedEdgePolicy) ComputeFee(
amt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts
}
// NewCachedPolicy turns a full policy into a minimal one that can be cached.
func NewCachedPolicy(policy *ChannelEdgePolicy) *CachedEdgePolicy {
return &CachedEdgePolicy{
ChannelID: policy.ChannelID,
MessageFlags: policy.MessageFlags,
ChannelFlags: policy.ChannelFlags,
TimeLockDelta: policy.TimeLockDelta,
MinHTLC: policy.MinHTLC,
MaxHTLC: policy.MaxHTLC,
FeeBaseMSat: policy.FeeBaseMSat,
FeeProportionalMillionths: policy.FeeProportionalMillionths,
}
}
package models
import (
"encoding/binary"
"fmt"
"io"
"github.com/lightningnetwork/lnd/lnwire"
)
const (
// serializedCircuitKeyLen is the exact length needed to encode a
// serialized CircuitKey.
serializedCircuitKeyLen = 16
)
var (
// ErrInvalidCircuitKeyLen signals that a circuit key could not be
// decoded because the byte slice is of an invalid length.
ErrInvalidCircuitKeyLen = fmt.Errorf("length of serialized circuit " +
"key must be 16 bytes")
)
// CircuitKey is used by a channel to uniquely identify the HTLCs it receives
// from the switch, and is used to purge our in-memory state of HTLCs that have
// already been processed by a link. Two lists of CircuitKeys are included in
// each CommitDiff to allow a link to determine which in-memory htlcs directed
// the opening and closing of circuits in the switch's circuit map.
type CircuitKey struct {
// ChanID is the short chanid indicating the HTLC's origin.
//
// NOTE: It is fine for this value to be blank, as this indicates a
// locally-sourced payment.
ChanID lnwire.ShortChannelID
	// HtlcID is the unique htlc index predominantly assigned by links,
	// though it can also be assigned by the switch in the case of
	// locally-sourced payments.
HtlcID uint64
}
// SetBytes deserializes the given bytes into this CircuitKey.
func (k *CircuitKey) SetBytes(bs []byte) error {
if len(bs) != serializedCircuitKeyLen {
return ErrInvalidCircuitKeyLen
}
k.ChanID = lnwire.NewShortChanIDFromInt(
binary.BigEndian.Uint64(bs[:8]))
k.HtlcID = binary.BigEndian.Uint64(bs[8:])
return nil
}
// Bytes returns the serialized bytes for this circuit key.
func (k CircuitKey) Bytes() []byte {
bs := make([]byte, serializedCircuitKeyLen)
binary.BigEndian.PutUint64(bs[:8], k.ChanID.ToUint64())
binary.BigEndian.PutUint64(bs[8:], k.HtlcID)
return bs
}
// Encode writes a CircuitKey to the provided io.Writer.
func (k *CircuitKey) Encode(w io.Writer) error {
var scratch [serializedCircuitKeyLen]byte
binary.BigEndian.PutUint64(scratch[:8], k.ChanID.ToUint64())
binary.BigEndian.PutUint64(scratch[8:], k.HtlcID)
_, err := w.Write(scratch[:])
return err
}
// Decode reads a CircuitKey from the provided io.Reader.
func (k *CircuitKey) Decode(r io.Reader) error {
var scratch [serializedCircuitKeyLen]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return err
}
k.ChanID = lnwire.NewShortChanIDFromInt(
binary.BigEndian.Uint64(scratch[:8]))
k.HtlcID = binary.BigEndian.Uint64(scratch[8:])
return nil
}
// String returns a string representation of the CircuitKey.
func (k CircuitKey) String() string {
return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.ChanID, k.HtlcID)
}
// ForwardingPolicy describes the set of constraints that a given ChannelLink
// is to adhere to when forwarding HTLC's. For each incoming HTLC, this set of
// constraints will be consulted in order to ensure that adequate fees are
// paid, and our time-lock parameters are respected. In the event that an
// incoming HTLC violates any of these constraints, it is to be _rejected_ with
// the error possibly carrying along a ChannelUpdate message that includes the
// latest policy.
type ForwardingPolicy struct {
// MinHTLCOut is the smallest HTLC that is to be forwarded.
MinHTLCOut lnwire.MilliSatoshi
// MaxHTLC is the largest HTLC that is to be forwarded.
MaxHTLC lnwire.MilliSatoshi
// BaseFee is the base fee, expressed in milli-satoshi that must be
// paid for each incoming HTLC. This field, combined with FeeRate is
// used to compute the required fee for a given HTLC.
BaseFee lnwire.MilliSatoshi
// FeeRate is the fee rate, expressed in milli-satoshi that must be
// paid for each incoming HTLC. This field combined with BaseFee is
// used to compute the required fee for a given HTLC.
FeeRate lnwire.MilliSatoshi
// InboundFee is the fee that must be paid for incoming HTLCs.
InboundFee InboundFee
// TimeLockDelta is the absolute time-lock value, expressed in blocks,
// that will be subtracted from an incoming HTLC's timelock value to
// create the time-lock value for the forwarded outgoing HTLC. The
// following constraint MUST hold for an HTLC to be forwarded:
//
// * incomingHtlc.timeLock - timeLockDelta = fwdInfo.OutgoingCTLV
//
// where fwdInfo is the forwarding information extracted from the
// per-hop payload of the incoming HTLC's onion packet.
TimeLockDelta uint32
// TODO(roasbeef): add fee module inside of switch
}
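// satisfiesTimeLockDelta is an illustrative sketch, not part of the original
// file: it checks the forwarding constraint documented above, with all values
// expressed in blocks. The helper name is hypothetical.
func satisfiesTimeLockDelta(policy *ForwardingPolicy, incomingTimeLock,
	outgoingCTLV uint32) bool {
	return incomingTimeLock-policy.TimeLockDelta == outgoingCTLV
}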
package models
import "github.com/btcsuite/btcd/btcec/v2/ecdsa"
// ChannelAuthProof is the authentication proof (the signature portion) for a
// channel. Using the four signatures contained in the struct, and some
// auxiliary knowledge (the funding script, node identities, and outpoint) nodes
// on the network are able to validate the authenticity and existence of a
// channel. Each of these signatures signs the following digest: chanID ||
// nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len ||
// features.
type ChannelAuthProof struct {
// nodeSig1 is a cached instance of the first node signature.
nodeSig1 *ecdsa.Signature
// NodeSig1Bytes are the raw bytes of the first node signature encoded
// in DER format.
NodeSig1Bytes []byte
// nodeSig2 is a cached instance of the second node signature.
nodeSig2 *ecdsa.Signature
// NodeSig2Bytes are the raw bytes of the second node signature
// encoded in DER format.
NodeSig2Bytes []byte
// bitcoinSig1 is a cached instance of the first bitcoin signature.
bitcoinSig1 *ecdsa.Signature
// BitcoinSig1Bytes are the raw bytes of the first bitcoin signature
// encoded in DER format.
BitcoinSig1Bytes []byte
// bitcoinSig2 is a cached instance of the second bitcoin signature.
bitcoinSig2 *ecdsa.Signature
// BitcoinSig2Bytes are the raw bytes of the second bitcoin signature
// encoded in DER format.
BitcoinSig2Bytes []byte
}
// Node1Sig is the signature using the identity key of the node that is first
// in a lexicographical ordering of the serialized public keys of the two nodes
// that created the channel.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the signature if absolutely necessary.
func (c *ChannelAuthProof) Node1Sig() (*ecdsa.Signature, error) {
if c.nodeSig1 != nil {
return c.nodeSig1, nil
}
sig, err := ecdsa.ParseSignature(c.NodeSig1Bytes)
if err != nil {
return nil, err
}
c.nodeSig1 = sig
return sig, nil
}
// Node2Sig is the signature using the identity key of the node that is second
// in a lexicographical ordering of the serialized public keys of the two nodes
// that created the channel.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the signature if absolutely necessary.
func (c *ChannelAuthProof) Node2Sig() (*ecdsa.Signature, error) {
if c.nodeSig2 != nil {
return c.nodeSig2, nil
}
sig, err := ecdsa.ParseSignature(c.NodeSig2Bytes)
if err != nil {
return nil, err
}
c.nodeSig2 = sig
return sig, nil
}
// BitcoinSig1 is the signature using the public key of the first node that was
// used in the channel's multi-sig output.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the signature if absolutely necessary.
func (c *ChannelAuthProof) BitcoinSig1() (*ecdsa.Signature, error) {
if c.bitcoinSig1 != nil {
return c.bitcoinSig1, nil
}
sig, err := ecdsa.ParseSignature(c.BitcoinSig1Bytes)
if err != nil {
return nil, err
}
c.bitcoinSig1 = sig
return sig, nil
}
// BitcoinSig2 is the signature using the public key of the second node that
// was used in the channel's multi-sig output.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the signature if absolutely necessary.
func (c *ChannelAuthProof) BitcoinSig2() (*ecdsa.Signature, error) {
if c.bitcoinSig2 != nil {
return c.bitcoinSig2, nil
}
sig, err := ecdsa.ParseSignature(c.BitcoinSig2Bytes)
if err != nil {
return nil, err
}
c.bitcoinSig2 = sig
return sig, nil
}
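// parseAllSigs is an illustrative sketch, not part of the original file: it
// exercises the lazy-parsing accessors above by materializing all four
// signatures, as a verifier would before checking them against the
// announcement digest. The helper name is hypothetical.
func parseAllSigs(c *ChannelAuthProof) error {
	parsers := []func() (*ecdsa.Signature, error){
		c.Node1Sig, c.Node2Sig, c.BitcoinSig1, c.BitcoinSig2,
	}
	for _, parse := range parsers {
		if _, err := parse(); err != nil {
			return err
		}
	}
	return nil
}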
// IsEmpty checks whether the authentication proof is empty. The proof is
// considered empty if at least one of the raw signatures is not set.
func (c *ChannelAuthProof) IsEmpty() bool {
return len(c.NodeSig1Bytes) == 0 ||
len(c.NodeSig2Bytes) == 0 ||
len(c.BitcoinSig1Bytes) == 0 ||
len(c.BitcoinSig2Bytes) == 0
}
package models
import (
"bytes"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
)
// ChannelEdgeInfo represents a fully authenticated channel along with all its
// unique attributes. Once an authenticated channel announcement has been
// processed on the network, then an instance of ChannelEdgeInfo encapsulating
// the channels attributes is stored. The other portions relevant to routing
// policy of a channel are stored within a ChannelEdgePolicy for each direction
// of the channel.
type ChannelEdgeInfo struct {
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel.
ChannelID uint64
// ChainHash is the hash that uniquely identifies the chain that this
// channel was opened within.
//
// TODO(roasbeef): need to modify db keying for multi-chain
// * must add chain hash to prefix as well
ChainHash chainhash.Hash
// NodeKey1Bytes is the raw public key of the first node.
NodeKey1Bytes [33]byte
nodeKey1 *btcec.PublicKey
	// NodeKey2Bytes is the raw public key of the second node.
NodeKey2Bytes [33]byte
nodeKey2 *btcec.PublicKey
// BitcoinKey1Bytes is the raw public key of the first node.
BitcoinKey1Bytes [33]byte
bitcoinKey1 *btcec.PublicKey
	// BitcoinKey2Bytes is the raw public key of the second node.
BitcoinKey2Bytes [33]byte
bitcoinKey2 *btcec.PublicKey
// Features is an opaque byte slice that encodes the set of channel
// specific features that this channel edge supports.
Features []byte
// AuthProof is the authentication proof for this channel. This proof
// contains a set of signatures binding four identities, which attests
// to the legitimacy of the advertised channel.
AuthProof *ChannelAuthProof
// ChannelPoint is the funding outpoint of the channel. This can be
// used to uniquely identify the channel within the channel graph.
ChannelPoint wire.OutPoint
// Capacity is the total capacity of the channel, this is determined by
// the value output in the outpoint that created this channel.
Capacity btcutil.Amount
// FundingScript holds the script of the channel's funding transaction.
//
// NOTE: this is not currently persisted and so will not be present if
// the edge object is loaded from the database.
FundingScript fn.Option[[]byte]
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
}
// AddNodeKeys is a setter-like method that can be used to replace the set of
// keys for the target ChannelEdgeInfo.
func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1,
bitcoinKey2 *btcec.PublicKey) {
c.nodeKey1 = nodeKey1
copy(c.NodeKey1Bytes[:], c.nodeKey1.SerializeCompressed())
c.nodeKey2 = nodeKey2
copy(c.NodeKey2Bytes[:], nodeKey2.SerializeCompressed())
c.bitcoinKey1 = bitcoinKey1
copy(c.BitcoinKey1Bytes[:], c.bitcoinKey1.SerializeCompressed())
c.bitcoinKey2 = bitcoinKey2
copy(c.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
}
// NodeKey1 is the identity public key of the "first" node that was involved in
// the creation of this channel. A node is considered "first" if the
// lexicographical ordering the its serialized public key is "smaller" than
// that of the other node involved in channel creation.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the pubkey if absolutely necessary.
func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) {
if c.nodeKey1 != nil {
return c.nodeKey1, nil
}
key, err := btcec.ParsePubKey(c.NodeKey1Bytes[:])
if err != nil {
return nil, err
}
c.nodeKey1 = key
return key, nil
}
// NodeKey2 is the identity public key of the "second" node that was involved in
// the creation of this channel. A node is considered "second" if the
// lexicographical ordering the its serialized public key is "larger" than that
// of the other node involved in channel creation.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the pubkey if absolutely necessary.
func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) {
if c.nodeKey2 != nil {
return c.nodeKey2, nil
}
key, err := btcec.ParsePubKey(c.NodeKey2Bytes[:])
if err != nil {
return nil, err
}
c.nodeKey2 = key
return key, nil
}
// BitcoinKey1 is the Bitcoin multi-sig key belonging to the first node, that
// was involved in the funding transaction that originally created the channel
// that this struct represents.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the pubkey if absolutely necessary.
func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) {
if c.bitcoinKey1 != nil {
return c.bitcoinKey1, nil
}
key, err := btcec.ParsePubKey(c.BitcoinKey1Bytes[:])
if err != nil {
return nil, err
}
c.bitcoinKey1 = key
return key, nil
}
// BitcoinKey2 is the Bitcoin multi-sig key belonging to the second node, that
// was involved in the funding transaction that originally created the channel
// that this struct represents.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the pubkey if absolutely necessary.
func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) {
if c.bitcoinKey2 != nil {
return c.bitcoinKey2, nil
}
key, err := btcec.ParsePubKey(c.BitcoinKey2Bytes[:])
if err != nil {
return nil, err
}
c.bitcoinKey2 = key
return key, nil
}
// OtherNodeKeyBytes returns the node key bytes of the other end of the channel.
func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) (
[33]byte, error) {
switch {
case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey):
return c.NodeKey2Bytes, nil
case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey):
return c.NodeKey1Bytes, nil
default:
return [33]byte{}, fmt.Errorf("node not participating in " +
"this channel")
}
}
package models
import (
"fmt"
"time"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/lightningnetwork/lnd/lnwire"
)
// ChannelEdgePolicy represents a *directed* edge within the channel graph. For
// each channel in the database, there are two distinct edges: one for each
// possible direction of travel along the channel. The edges themselves hold
// information concerning fees, and minimum time-lock information which is
// utilized during path finding.
type ChannelEdgePolicy struct {
// SigBytes is the raw bytes of the signature of the channel edge
// policy. We'll only parse these if the caller needs to access the
// signature for validation purposes. Do not set SigBytes directly, but
// use SetSigBytes instead to make sure that the cache is invalidated.
SigBytes []byte
// sig is a cached fully parsed signature.
sig *ecdsa.Signature
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel.
ChannelID uint64
// LastUpdate is the last time an authenticated edge for this channel
// was received.
LastUpdate time.Time
// MessageFlags is a bitfield which indicates the presence of optional
// fields (like max_htlc) in the policy.
MessageFlags lnwire.ChanUpdateMsgFlags
// ChannelFlags is a bitfield which signals the capabilities of the
// channel as well as the directed edge this update applies to.
ChannelFlags lnwire.ChanUpdateChanFlags
// TimeLockDelta is the number of blocks this node will subtract from
// the expiry of an incoming HTLC. This value expresses the time buffer
	// the node would like to apply to HTLC exchanges.
TimeLockDelta uint16
// MinHTLC is the smallest value HTLC this node will forward, expressed
// in millisatoshi.
MinHTLC lnwire.MilliSatoshi
// MaxHTLC is the largest value HTLC this node will forward, expressed
// in millisatoshi.
MaxHTLC lnwire.MilliSatoshi
// FeeBaseMSat is the base HTLC fee that will be charged for forwarding
// ANY HTLC, expressed in mSAT's.
FeeBaseMSat lnwire.MilliSatoshi
// FeeProportionalMillionths is the rate that the node will charge for
// HTLCs for each millionth of a satoshi forwarded.
FeeProportionalMillionths lnwire.MilliSatoshi
// ToNode is the public key of the node that this directed edge leads
// to. Using this pub key, the channel graph can further be traversed.
ToNode [33]byte
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData lnwire.ExtraOpaqueData
}
// Signature is a channel announcement signature, which is needed for proper
// edge policy announcement.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the signature if absolutely necessary.
func (c *ChannelEdgePolicy) Signature() (*ecdsa.Signature, error) {
if c.sig != nil {
return c.sig, nil
}
sig, err := ecdsa.ParseSignature(c.SigBytes)
if err != nil {
return nil, err
}
c.sig = sig
return sig, nil
}
// SetSigBytes updates the signature and invalidates the cached parsed
// signature.
func (c *ChannelEdgePolicy) SetSigBytes(sig []byte) {
c.SigBytes = sig
c.sig = nil
}
// IsDisabled determines whether the edge has the disabled bit set.
func (c *ChannelEdgePolicy) IsDisabled() bool {
return c.ChannelFlags.IsDisabled()
}
// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over
// the passed active payment channel. This value is currently computed as
// specified in BOLT07, but will likely change in the near future.
func (c *ChannelEdgePolicy) ComputeFee(
amt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts
}
// String returns a human-readable version of the channel edge policy.
func (c *ChannelEdgePolicy) String() string {
return fmt.Sprintf("ChannelID=%v, MessageFlags=%v, ChannelFlags=%v, "+
"LastUpdate=%v", c.ChannelID, c.MessageFlags, c.ChannelFlags,
c.LastUpdate)
}
package models
import "github.com/lightningnetwork/lnd/lnwire"
const (
// maxFeeRate is the maximum fee rate that we allow. It is set to allow
// a variable fee component of up to 10x the payment amount.
maxFeeRate = 10 * feeRateParts
)
// InboundFee describes the fee schedule for the inbound direction of a
// channel. Both components may be negative, expressing a discount.
type InboundFee struct {
	// Base is the inbound base fee charged per HTLC, in milli-satoshi.
	Base int32
	// Rate is the inbound fee rate, in parts per million.
	Rate int32
}
// NewInboundFeeFromWire constructs an inbound fee structure from a wire fee.
func NewInboundFeeFromWire(fee lnwire.Fee) InboundFee {
return InboundFee{
Base: fee.BaseFee,
Rate: fee.FeeRate,
}
}
// ToWire converts the inbound fee to a wire fee structure.
func (i *InboundFee) ToWire() lnwire.Fee {
return lnwire.Fee{
BaseFee: i.Base,
FeeRate: i.Rate,
}
}
// CalcFee calculates what the inbound fee should minimally be for forwarding
// the given amount. This amount is the total of the outgoing amount plus the
// outbound fee, which is what the inbound fee is based on.
func (i *InboundFee) CalcFee(amt lnwire.MilliSatoshi) int64 {
fee := int64(i.Base)
rate := int64(i.Rate)
// Cap the rate to prevent overflows.
switch {
case rate > maxFeeRate:
rate = maxFeeRate
case rate < -maxFeeRate:
rate = -maxFeeRate
}
// Calculate proportional component. To keep the integer math simple,
// positive fees are rounded down while negative fees are rounded up.
fee += rate * int64(amt) / feeRateParts
return fee
}
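// calcInboundFeeExample is an illustrative sketch, not part of the original
// file: it shows that a negative inbound fee acts as a discount. With
// Base=-1000 msat and Rate=-100 ppm, forwarding 1,000,000 msat yields
// -1000 + (-100*1000000)/1000000 = -1100 msat.
func calcInboundFeeExample() int64 {
	fee := InboundFee{Base: -1000, Rate: -100}
	return fee.CalcFee(1000000) // -1100 msat.
}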
package models
import (
"fmt"
"image/color"
"net"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/lightningnetwork/lnd/lnwire"
)
// LightningNode represents an individual vertex/node within the channel graph.
// A node is connected to other nodes by one or more channel edges emanating
// from it. As the graph is directed, a node will also have an incoming edge
// attached to it for each outgoing edge.
type LightningNode struct {
// PubKeyBytes is the raw bytes of the public key of the target node.
PubKeyBytes [33]byte
pubKey *btcec.PublicKey
// HaveNodeAnnouncement indicates whether we received a node
// announcement for this particular node. If true, the remaining fields
	// will be set; if false, only the PubKey is known for this node.
HaveNodeAnnouncement bool
// LastUpdate is the last time the vertex information for this node has
// been updated.
LastUpdate time.Time
	// Addresses is the list of addresses the node is reachable over.
Addresses []net.Addr
// Color is the selected color for the node.
Color color.RGBA
// Alias is a nick-name for the node. The alias can be used to confirm
// a node's identity or to serve as a short ID for an address book.
Alias string
// AuthSigBytes is the raw signature under the advertised public key
// which serves to authenticate the attributes announced by this node.
AuthSigBytes []byte
// Features is the list of protocol features supported by this node.
Features *lnwire.FeatureVector
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
	// TODO(roasbeef): discovery will need storage to keep its last IP
// address and re-announce if interface changes?
// TODO(roasbeef): add update method and fetch?
}
// PubKey is the node's long-term identity public key. This key will be used
// to authenticate any advertisements/updates sent by the node.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the pubkey if absolutely necessary.
func (l *LightningNode) PubKey() (*btcec.PublicKey, error) {
if l.pubKey != nil {
return l.pubKey, nil
}
key, err := btcec.ParsePubKey(l.PubKeyBytes[:])
if err != nil {
return nil, err
}
l.pubKey = key
return key, nil
}
// AuthSig is a signature under the advertised public key which serves to
// authenticate the attributes announced by this node.
//
// NOTE: By having this method to access an attribute, we ensure we only need
// to fully deserialize the signature if absolutely necessary.
func (l *LightningNode) AuthSig() (*ecdsa.Signature, error) {
return ecdsa.ParseSignature(l.AuthSigBytes)
}
// AddPubKey is a setter-like method that can be used to swap out the public
// key for a node.
func (l *LightningNode) AddPubKey(key *btcec.PublicKey) {
l.pubKey = key
copy(l.PubKeyBytes[:], key.SerializeCompressed())
}
// NodeAnnouncement retrieves the latest node announcement of the node.
func (l *LightningNode) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement,
error) {
if !l.HaveNodeAnnouncement {
return nil, fmt.Errorf("node does not have node announcement")
}
alias, err := lnwire.NewNodeAlias(l.Alias)
if err != nil {
return nil, err
}
nodeAnn := &lnwire.NodeAnnouncement{
Features: l.Features.RawFeatureVector,
NodeID: l.PubKeyBytes,
RGBColor: l.Color,
Alias: alias,
Addresses: l.Addresses,
Timestamp: uint32(l.LastUpdate.Unix()),
ExtraOpaqueData: l.ExtraOpaqueData,
}
if !signed {
return nodeAnn, nil
}
sig, err := lnwire.NewSigFromECDSARawSignature(l.AuthSigBytes)
if err != nil {
return nil, err
}
nodeAnn.Signature = sig
return nodeAnn, nil
}
// NodeFromWireAnnouncement creates a LightningNode instance from an
// lnwire.NodeAnnouncement message.
func NodeFromWireAnnouncement(msg *lnwire.NodeAnnouncement) *LightningNode {
timestamp := time.Unix(int64(msg.Timestamp), 0)
features := lnwire.NewFeatureVector(msg.Features, lnwire.Features)
return &LightningNode{
HaveNodeAnnouncement: true,
LastUpdate: timestamp,
Addresses: msg.Addresses,
PubKeyBytes: msg.NodeID,
Alias: msg.Alias.String(),
AuthSigBytes: msg.Signature.ToSignatureBytes(),
Features: features,
Color: msg.RGBColor,
ExtraOpaqueData: msg.ExtraOpaqueData,
}
}
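// nodeAnnouncementRoundTrip is an illustrative sketch, not part of the
// original file: an announcement received from the wire can be converted into
// a LightningNode for storage and later re-assembled, signature included, via
// the two helpers above. The function name is hypothetical.
func nodeAnnouncementRoundTrip(
	msg *lnwire.NodeAnnouncement) (*lnwire.NodeAnnouncement, error) {
	node := NodeFromWireAnnouncement(msg)
	return node.NodeAnnouncement(true)
}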
package graphdb
import (
"fmt"
"image/color"
"net"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire"
)
// topologyManager holds all the fields required to manage the network topology
// subscriptions and notifications.
type topologyManager struct {
// ntfnClientCounter is an atomic counter that's used to assign unique
// notification client IDs to new clients.
ntfnClientCounter atomic.Uint64
// topologyUpdate is a channel that carries new topology updates
// messages from outside the ChannelGraph to be processed by the
// networkHandler.
topologyUpdate chan any
// topologyClients maps a client's unique notification ID to a
// topologyClient client that contains its notification dispatch
// channel.
topologyClients *lnutils.SyncMap[uint64, *topologyClient]
// ntfnClientUpdates is a channel that's used to send new updates to
// topology notification clients to the ChannelGraph. Updates either
// add a new notification client, or cancel notifications for an
// existing client.
ntfnClientUpdates chan *topologyClientUpdate
}
// newTopologyManager creates a new instance of the topologyManager.
func newTopologyManager() *topologyManager {
return &topologyManager{
topologyUpdate: make(chan any),
topologyClients: &lnutils.SyncMap[uint64, *topologyClient]{},
ntfnClientUpdates: make(chan *topologyClientUpdate),
}
}
// TopologyClient represents an intent to receive notifications from the
// channel router regarding changes to the topology of the channel graph. The
// TopologyChanges channel will be sent upon with new updates to the channel
// graph in real-time as they're encountered.
type TopologyClient struct {
// TopologyChanges is a receive only channel that new channel graph
// updates will be sent over.
//
// TODO(roasbeef): chan for each update type instead?
TopologyChanges <-chan *TopologyChange
// Cancel is a function closure that should be executed when the client
// wishes to cancel their notification intent. Doing so allows the
// ChannelRouter to free up resources.
Cancel func()
}
// topologyClientUpdate is a message sent to the channel router to either
// register a new topology client or re-register an existing client.
type topologyClientUpdate struct {
// cancel indicates if the update to the client is cancelling an
// existing client's notifications. If not then this update will be to
// register a new set of notifications.
cancel bool
// clientID is the unique identifier for this client. Any further
// updates (deleting or adding) to this notification client will be
// dispatched according to the target clientID.
clientID uint64
// ntfnChan is a *send-only* channel in which notifications should be
// sent over from router -> client.
ntfnChan chan<- *TopologyChange
}
// SubscribeTopology returns a new topology client which can be used by the
// caller to receive notifications whenever a change in the channel graph
// topology occurs. Changes that will be sent as notifications include: new
// nodes appearing, nodes updating their attributes, new channels, channels
// closing, and updates in the routing policies of a channel's directed edges.
func (c *ChannelGraph) SubscribeTopology() (*TopologyClient, error) {
// If the router is not yet started, return an error to avoid a
// deadlock waiting for it to handle the subscription request.
if !c.started.Load() {
return nil, fmt.Errorf("router not started")
}
// We'll first atomically obtain the next ID for this client from the
// incrementing client ID counter.
clientID := c.ntfnClientCounter.Add(1)
log.Debugf("New graph topology client subscription, client %v",
clientID)
ntfnChan := make(chan *TopologyChange, 10)
select {
case c.ntfnClientUpdates <- &topologyClientUpdate{
cancel: false,
clientID: clientID,
ntfnChan: ntfnChan,
}:
case <-c.quit:
return nil, errors.New("ChannelRouter shutting down")
}
return &TopologyClient{
TopologyChanges: ntfnChan,
Cancel: func() {
select {
case c.ntfnClientUpdates <- &topologyClientUpdate{
cancel: true,
clientID: clientID,
}:
case <-c.quit:
return
}
},
}, nil
}
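// The snippet below is an illustrative sketch, not part of the original file:
// it shows how a caller might consume notifications from a started
// ChannelGraph (the graph variable is assumed).
//
//	client, err := graph.SubscribeTopology()
//	if err != nil {
//		return err
//	}
//	defer client.Cancel()
//	for change := range client.TopologyChanges {
//		fmt.Printf("nodes=%d edges=%d closed=%d\n",
//			len(change.NodeUpdates),
//			len(change.ChannelEdgeUpdates),
//			len(change.ClosedChannels))
//	}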
// topologyClient is a data-structure used by the channel router to couple the
// client's notification channel along with a special "exit" channel that can
// be used to cancel all lingering goroutines blocked on a send to the
// notification channel.
type topologyClient struct {
// ntfnChan is a send-only channel that's used to propagate
	// notifications from the channel router to an instance of a
// topologyClient client.
ntfnChan chan<- *TopologyChange
// exit is a channel that is used internally by the channel router to
// cancel any active un-consumed goroutine notifications.
exit chan struct{}
wg sync.WaitGroup
}
// notifyTopologyChange notifies all registered clients of a new change in
// graph topology in a non-blocking manner.
func (c *ChannelGraph) notifyTopologyChange(topologyDiff *TopologyChange) {
// notifyClient is a helper closure that will send topology updates to
// the given client.
notifyClient := func(clientID uint64, client *topologyClient) bool {
client.wg.Add(1)
log.Tracef("Sending topology notification to client=%v, "+
"NodeUpdates=%v, ChannelEdgeUpdates=%v, "+
"ClosedChannels=%v", clientID,
len(topologyDiff.NodeUpdates),
len(topologyDiff.ChannelEdgeUpdates),
len(topologyDiff.ClosedChannels))
go func(t *topologyClient) {
defer t.wg.Done()
select {
// In this case we'll try to send the notification
// directly to the upstream client consumer.
case t.ntfnChan <- topologyDiff:
// If the client cancels the notifications, then we'll
// exit early.
case <-t.exit:
			// Similarly, if the ChannelRouter itself exits early,
// then we'll also exit ourselves.
case <-c.quit:
}
}(client)
// Always return true here so the following Range will iterate
// all clients.
return true
}
// Range over the set of active clients, and attempt to send the
// topology updates.
c.topologyClients.Range(notifyClient)
}
// handleTopologyUpdate is responsible for sending any topology change
// notifications to registered clients.
//
// NOTE: must be run inside goroutine.
func (c *ChannelGraph) handleTopologyUpdate(update any) {
defer c.wg.Done()
topChange := &TopologyChange{}
err := c.addToTopologyChange(topChange, update)
if err != nil {
log.Errorf("unable to update topology change notification: %v",
err)
return
}
if topChange.isEmpty() {
return
}
c.notifyTopologyChange(topChange)
}
// TopologyChange represents a new set of modifications to the channel graph.
// Topology changes will be dispatched in real-time as the ChannelGraph
// validates and processes modifications to the authenticated channel graph.
type TopologyChange struct {
// NodeUpdates is a slice of nodes which are either new to the channel
// graph, or have had their attributes updated in an authenticated
// manner.
NodeUpdates []*NetworkNodeUpdate
	// ChannelEdgeUpdates is a slice of channel edges which are either newly
// opened and authenticated, or have had their routing policies
// updated.
ChannelEdgeUpdates []*ChannelEdgeUpdate
	// ClosedChannels contains a slice of closed channel summaries which
	// describe the block at which a channel was closed, and also carry
// supplemental information such as the capacity of the former channel.
ClosedChannels []*ClosedChanSummary
}
// isEmpty returns true if the TopologyChange is empty. A TopologyChange is
// considered empty if it contains no *new* updates of any type.
func (t *TopologyChange) isEmpty() bool {
return len(t.NodeUpdates) == 0 && len(t.ChannelEdgeUpdates) == 0 &&
len(t.ClosedChannels) == 0
}
// ClosedChanSummary is a summary of a channel that was detected as being
// closed by monitoring the blockchain. Once a channel's funding point has been
// spent, the channel will automatically be marked as closed by the
// ChainNotifier.
//
// TODO(roasbeef): add nodes involved?
type ClosedChanSummary struct {
// ChanID is the short-channel ID which uniquely identifies the
// channel.
ChanID uint64
// Capacity was the total capacity of the channel before it was closed.
Capacity btcutil.Amount
// ClosedHeight is the height in the chain that the channel was closed
// at.
ClosedHeight uint32
// ChanPoint is the funding point, or the multi-sig utxo which
// previously represented the channel.
ChanPoint wire.OutPoint
}
// createCloseSummaries takes in a slice of channels closed at the target block
// height and creates a slice of summaries, one for each channel closure.
func createCloseSummaries(blockHeight uint32,
closedChans ...*models.ChannelEdgeInfo) []*ClosedChanSummary {
closeSummaries := make([]*ClosedChanSummary, len(closedChans))
for i, closedChan := range closedChans {
closeSummaries[i] = &ClosedChanSummary{
ChanID: closedChan.ChannelID,
Capacity: closedChan.Capacity,
ClosedHeight: blockHeight,
ChanPoint: closedChan.ChannelPoint,
}
}
return closeSummaries
}
// NetworkNodeUpdate is an update for a node within the Lightning Network. A
// NetworkNodeUpdate is sent out either when a new node joins the network, or a
// node broadcasts a new update with a newer time stamp that supersedes its
// old update. All updates are properly authenticated.
type NetworkNodeUpdate struct {
// Addresses is a slice of all the node's known addresses.
Addresses []net.Addr
// IdentityKey is the identity public key of the target node. This is
// used to encrypt onion blobs as well as to authenticate any new
// updates.
IdentityKey *btcec.PublicKey
// Alias is the alias or nick name of the node.
Alias string
// Color is the node's color in hex code format.
Color string
// Features holds the set of features the node supports.
Features *lnwire.FeatureVector
}
// ChannelEdgeUpdate is an update for a new channel within the ChannelGraph.
// This update is sent out once a new authenticated channel edge is discovered
// within the network. These updates are directional, so if a channel is fully
// public, then there will be two updates sent out: one for each direction
// within the channel. Each update will carry that particular routing edge
// policy for the channel direction.
//
// An edge is a channel in the direction of AdvertisingNode -> ConnectingNode.
type ChannelEdgeUpdate struct {
// ChanID is the unique short channel ID for the channel. This encodes
// where in the blockchain the channel's funding transaction was
// originally confirmed.
ChanID uint64
// ChanPoint is the outpoint which represents the multi-sig funding
// output for the channel.
ChanPoint wire.OutPoint
// Capacity is the capacity of the newly created channel.
Capacity btcutil.Amount
// MinHTLC is the minimum HTLC amount that this channel will forward.
MinHTLC lnwire.MilliSatoshi
// MaxHTLC is the maximum HTLC amount that this channel will forward.
MaxHTLC lnwire.MilliSatoshi
	// BaseFee is the base fee that will be charged for all HTLC's
	// forwarded across this channel direction.
BaseFee lnwire.MilliSatoshi
	// FeeRate is the fee rate that will be charged for all HTLC's
	// forwarded across this channel direction.
FeeRate lnwire.MilliSatoshi
	// TimeLockDelta is the time-lock, expressed in blocks, by which an
	// outgoing HTLC's time-lock differs from that of the incoming HTLC
	// routed through this hop.
TimeLockDelta uint16
// AdvertisingNode is the node that's advertising this edge.
AdvertisingNode *btcec.PublicKey
// ConnectingNode is the node that the advertising node connects to.
ConnectingNode *btcec.PublicKey
// Disabled, if true, signals that the channel is unavailable to relay
// payments.
Disabled bool
// ExtraOpaqueData is the set of data that was appended to this message
// to fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraOpaqueData lnwire.ExtraOpaqueData
}
// addToTopologyChange appends the passed update message to the passed
// TopologyChange, properly identifying which type of update the message
// constitutes. This function will also fetch any auxiliary information
// required to create the topology change update from the graph database.
func (c *ChannelGraph) addToTopologyChange(update *TopologyChange,
msg any) error {
switch m := msg.(type) {
// Any node announcement maps directly to a NetworkNodeUpdate struct.
// No further data munging or db queries are required.
case *models.LightningNode:
pubKey, err := m.PubKey()
if err != nil {
return err
}
nodeUpdate := &NetworkNodeUpdate{
Addresses: m.Addresses,
IdentityKey: pubKey,
Alias: m.Alias,
Color: EncodeHexColor(m.Color),
Features: m.Features.Clone(),
}
update.NodeUpdates = append(update.NodeUpdates, nodeUpdate)
return nil
// We ignore initial channel announcements as we'll only send out
// updates once the individual edges themselves have been updated.
case *models.ChannelEdgeInfo:
return nil
// Any new ChannelUpdateAnnouncements will generate a corresponding
// ChannelEdgeUpdate notification.
case *models.ChannelEdgePolicy:
// We'll need to fetch the edge's information from the database
// in order to get the information concerning which nodes are
// being connected.
edgeInfo, _, _, err := c.FetchChannelEdgesByID(m.ChannelID)
if err != nil {
return errors.Errorf("unable fetch channel edge: %v",
err)
}
// If the flag is one, then the advertising node is actually
// the second node.
sourceNode := edgeInfo.NodeKey1
connectingNode := edgeInfo.NodeKey2
if m.ChannelFlags&lnwire.ChanUpdateDirection == 1 {
sourceNode = edgeInfo.NodeKey2
connectingNode = edgeInfo.NodeKey1
}
aNode, err := sourceNode()
if err != nil {
return err
}
cNode, err := connectingNode()
if err != nil {
return err
}
edgeUpdate := &ChannelEdgeUpdate{
ChanID: m.ChannelID,
ChanPoint: edgeInfo.ChannelPoint,
TimeLockDelta: m.TimeLockDelta,
Capacity: edgeInfo.Capacity,
MinHTLC: m.MinHTLC,
MaxHTLC: m.MaxHTLC,
BaseFee: m.FeeBaseMSat,
FeeRate: m.FeeProportionalMillionths,
AdvertisingNode: aNode,
ConnectingNode: cNode,
Disabled: m.ChannelFlags&lnwire.ChanUpdateDisabled != 0,
ExtraOpaqueData: m.ExtraOpaqueData,
}
// TODO(roasbeef): add bit to toggle
update.ChannelEdgeUpdates = append(update.ChannelEdgeUpdates,
edgeUpdate)
return nil
default:
return fmt.Errorf("unable to add to topology change, "+
"unknown message type %T", msg)
}
}
// EncodeHexColor takes a color and returns it in hex code format.
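// For example, color.RGBA{R: 0xff} encodes to "#ff0000".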
func EncodeHexColor(color color.RGBA) string {
return fmt.Sprintf("#%02x%02x%02x", color.R, color.G, color.B)
}
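// Illustrative sketch (not part of the original code): EncodeHexColor
// renders each RGBA channel as two lowercase hex digits, so pure red encodes
// to "#ff0000". The function name is hypothetical.
func exampleEncodeHexColor() string {
	return EncodeHexColor(color.RGBA{R: 0xff, G: 0x00, B: 0x00})
}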
package graphdb
import "time"
const (
// DefaultRejectCacheSize is the default number of rejectCacheEntries to
// cache for use in the rejection cache of incoming gossip traffic. This
// produces a cache size of around 1MB.
DefaultRejectCacheSize = 50000
// DefaultChannelCacheSize is the default number of ChannelEdges cached
// in order to reply to gossip queries. This produces a cache size of
// around 40MB.
DefaultChannelCacheSize = 20000
// DefaultPreAllocCacheNumNodes is the default number of nodes we
// assume for mainnet for pre-allocating the graph cache. As of
// September 2021, there are currently 14k nodes in a strictly pruned
// graph, so we choose a number that is slightly higher.
DefaultPreAllocCacheNumNodes = 15000
)
// chanGraphOptions holds parameters for tuning and customizing the
// ChannelGraph.
type chanGraphOptions struct {
// useGraphCache denotes whether the in-memory graph cache should be
// used or a fallback version that uses the underlying database for
// path finding.
useGraphCache bool
// preAllocCacheNumNodes is the number of nodes we expect to be in the
// graph cache, so we can pre-allocate the map accordingly.
preAllocCacheNumNodes int
}
// defaultChanGraphOptions returns a new chanGraphOptions instance populated
// with default values.
func defaultChanGraphOptions() *chanGraphOptions {
return &chanGraphOptions{
useGraphCache: true,
preAllocCacheNumNodes: DefaultPreAllocCacheNumNodes,
}
}
// ChanGraphOption describes the signature of a functional option that can be
// used to customize a ChannelGraph instance.
type ChanGraphOption func(*chanGraphOptions)
// WithUseGraphCache sets whether the in-memory graph cache should be used.
func WithUseGraphCache(use bool) ChanGraphOption {
return func(o *chanGraphOptions) {
o.useGraphCache = use
}
}
// WithPreAllocCacheNumNodes sets the number of nodes we expect to be in the
// graph cache, so we can pre-allocate the map accordingly.
func WithPreAllocCacheNumNodes(n int) ChanGraphOption {
return func(o *chanGraphOptions) {
o.preAllocCacheNumNodes = n
}
}
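// Illustrative sketch (not part of the original code): functional options
// like the ones above are applied by starting from the defaults and invoking
// each modifier in turn. The function name is hypothetical.
func applyChanGraphOptions(options ...ChanGraphOption) *chanGraphOptions {
	opts := defaultChanGraphOptions()
	for _, o := range options {
		o(opts)
	}
	return opts
}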
// KVStoreOptions holds parameters for tuning and customizing a graph.DB.
type KVStoreOptions struct {
// RejectCacheSize is the maximum number of rejectCacheEntries to hold
// in the rejection cache.
RejectCacheSize int
// ChannelCacheSize is the maximum number of ChannelEdges to hold in the
// channel cache.
ChannelCacheSize int
// BatchCommitInterval is the maximum duration the batch schedulers will
// wait before attempting to commit a pending set of updates.
BatchCommitInterval time.Duration
// NoMigration specifies that the underlying backend was opened in read-only
// mode and migrations shouldn't be performed. This can be useful for
// applications that use the channeldb package as a library.
NoMigration bool
}
// DefaultOptions returns a KVStoreOptions populated with default values.
func DefaultOptions() *KVStoreOptions {
return &KVStoreOptions{
RejectCacheSize: DefaultRejectCacheSize,
ChannelCacheSize: DefaultChannelCacheSize,
NoMigration: false,
}
}
// KVStoreOptionModifier is a function signature for modifying the default
// KVStoreOptions.
type KVStoreOptionModifier func(*KVStoreOptions)
// WithRejectCacheSize sets the RejectCacheSize to n.
func WithRejectCacheSize(n int) KVStoreOptionModifier {
return func(o *KVStoreOptions) {
o.RejectCacheSize = n
}
}
// WithChannelCacheSize sets the ChannelCacheSize to n.
func WithChannelCacheSize(n int) KVStoreOptionModifier {
return func(o *KVStoreOptions) {
o.ChannelCacheSize = n
}
}
// WithBatchCommitInterval sets the commit interval for the batch
// schedulers.
func WithBatchCommitInterval(interval time.Duration) KVStoreOptionModifier {
return func(o *KVStoreOptions) {
o.BatchCommitInterval = interval
}
}
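// Illustrative sketch (not part of the original code): building a
// KVStoreOptions with non-default values by applying modifiers to the
// defaults. The function name and the chosen sizes are hypothetical.
func exampleKVStoreOptions() *KVStoreOptions {
	opts := DefaultOptions()
	modifiers := []KVStoreOptionModifier{
		WithRejectCacheSize(100_000),
		WithChannelCacheSize(30_000),
		WithBatchCommitInterval(500 * time.Millisecond),
	}
	for _, modifier := range modifiers {
		modifier(opts)
	}
	return opts
}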
package graphdb
// rejectFlags is a compact representation of various metadata stored by the
// reject cache about a particular channel.
type rejectFlags uint8
const (
// rejectFlagExists is a flag indicating whether the channel exists,
// i.e. the channel is open and has a recent channel update. If this
// flag is not set, the channel is either a zombie or unknown.
rejectFlagExists rejectFlags = 1 << iota
// rejectFlagZombie is a flag indicating whether the channel is a
// zombie, i.e. the channel is open but has no recent channel updates.
rejectFlagZombie
)
// packRejectFlags computes the rejectFlags corresponding to the passed boolean
// values indicating whether the edge exists or is a zombie.
func packRejectFlags(exists, isZombie bool) rejectFlags {
var flags rejectFlags
if exists {
flags |= rejectFlagExists
}
if isZombie {
flags |= rejectFlagZombie
}
return flags
}
// unpack returns the booleans packed into the rejectFlags. The first indicates
// if the edge exists in our graph, the second indicates if the edge is a
// zombie.
func (f rejectFlags) unpack() (bool, bool) {
return f&rejectFlagExists == rejectFlagExists,
f&rejectFlagZombie == rejectFlagZombie
}
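// Illustrative sketch (not part of the original code): the two booleans
// round-trip losslessly through packRejectFlags and unpack. The function
// name is hypothetical.
func exampleRejectFlagsRoundTrip() (exists, isZombie bool) {
	flags := packRejectFlags(true, false)
	return flags.unpack() // (true, false)
}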
// rejectCacheEntry caches frequently accessed information about a channel,
// including the timestamps of its latest edge policies and whether or not the
// channel exists in the graph.
type rejectCacheEntry struct {
upd1Time int64
upd2Time int64
flags rejectFlags
}
// rejectCache is an in-memory cache used to improve the performance of
// HasChannelEdge. It caches information about whether or not the channel
// exists, as well as the most recent timestamps for each policy (if they
// exist).
type rejectCache struct {
n int
edges map[uint64]rejectCacheEntry
}
// newRejectCache creates a new rejectCache with maximum capacity of n entries.
func newRejectCache(n int) *rejectCache {
return &rejectCache{
n: n,
edges: make(map[uint64]rejectCacheEntry, n),
}
}
// get returns the entry from the cache for chanid, if it exists.
func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) {
entry, ok := c.edges[chanid]
return entry, ok
}
// insert adds the entry to the reject cache. If an entry for chanid already
// exists, it will be replaced with the new entry. If the entry doesn't exist,
// it will be inserted into the cache, performing a random eviction if the
// cache is at capacity.
func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) {
// If entry exists, replace it.
if _, ok := c.edges[chanid]; ok {
c.edges[chanid] = entry
return
}
// Otherwise, evict an entry at random and insert.
if len(c.edges) == c.n {
for id := range c.edges {
delete(c.edges, id)
break
}
}
c.edges[chanid] = entry
}
// remove deletes an entry for chanid from the cache, if it exists.
func (c *rejectCache) remove(chanid uint64) {
delete(c.edges, chanid)
}
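// Illustrative sketch (not part of the original code): a typical
// check-then-update flow against the reject cache. The function name and the
// channel id are hypothetical.
func exampleRejectCacheUsage() {
	cache := newRejectCache(2)
	// Cache a channel that exists and is not a zombie.
	cache.insert(42, rejectCacheEntry{
		upd1Time: 100,
		upd2Time: 200,
		flags:    packRejectFlags(true, false),
	})
	// A later HasChannelEdge-style lookup can now skip the database.
	if entry, ok := cache.get(42); ok {
		exists, isZombie := entry.flags.unpack()
		_, _ = exists, isZombie
	}
	cache.remove(42)
}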
package graph
import "github.com/go-errors/errors"
// ErrorCode is used to represent the various errors that can occur within this
// package.
type ErrorCode uint8
const (
// ErrOutdated is returned when the routing update has already
// been applied, or a newer update is already known.
ErrOutdated ErrorCode = iota
// ErrIgnored is returned when the update has been ignored because
// it doesn't provide any new information, or because a node
// announcement was given for a node not found in any channel.
ErrIgnored
)
// Error is a structure that represents an error inside the graph package.
// It carries additional information about the error code so that errors can
// be distinguished outside of the current package.
type Error struct {
err *errors.Error
code ErrorCode
}
// Error returns the error as a string.
// NOTE: Part of the error interface.
func (e *Error) Error() string {
return e.err.Error()
}
// A compile time check to ensure Error implements the error interface.
var _ error = (*Error)(nil)
// NewErrf creates an Error from the given formatted description and
// its corresponding error code.
func NewErrf(code ErrorCode, format string, a ...interface{}) *Error {
return &Error{
code: code,
err: errors.Errorf(format, a...),
}
}
// IsError is a helper function that checks whether the returned error
// matches one of the specified error codes.
func IsError(e interface{}, codes ...ErrorCode) bool {
err, ok := e.(*Error)
if !ok {
return false
}
for _, code := range codes {
if err.code == code {
return true
}
}
return false
}
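// Illustrative sketch (not part of the original code): callers wrap failures
// with NewErrf and later branch on the code with IsError. The function name
// and message are hypothetical.
func exampleIsError() bool {
	err := NewErrf(ErrIgnored, "stale update for channel %v", 42)
	return IsError(err, ErrOutdated, ErrIgnored) // true
}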
package graph
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
// log is a logger that is initialized with no output filters. This means the
// package will not perform any logging by default until the caller requests
// it.
var log btclog.Logger
const Subsystem = "GRPH"
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger(Subsystem, nil))
}
// DisableLog disables all library log output. Logging output is disabled by
// default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info. This
// should be used in preference to SetLogWriter if the caller is also using
// btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
package graph
import (
"fmt"
"sync"
"time"
)
// builderStats is a struct that tracks various updates to the graph and
// facilitates aggregate logging of the statistics.
type builderStats struct {
numChannels uint32
numUpdates uint32
numNodes uint32
lastReset time.Time
mu sync.RWMutex
}
// incNumEdgesDiscovered increments the number of discovered edges.
func (g *builderStats) incNumEdgesDiscovered() {
g.mu.Lock()
g.numChannels++
g.mu.Unlock()
}
// incNumChannelUpdates increments the number of channel updates processed.
func (g *builderStats) incNumChannelUpdates() {
g.mu.Lock()
g.numUpdates++
g.mu.Unlock()
}
// incNumNodeUpdates increments the number of node updates processed.
func (g *builderStats) incNumNodeUpdates() {
g.mu.Lock()
g.numNodes++
g.mu.Unlock()
}
// Empty returns true if all stats are zero.
func (g *builderStats) Empty() bool {
g.mu.RLock()
isEmpty := g.numChannels == 0 &&
g.numUpdates == 0 &&
g.numNodes == 0
g.mu.RUnlock()
return isEmpty
}
// Reset clears any stats and sets the lastReset field to now.
func (g *builderStats) Reset() {
g.mu.Lock()
g.numChannels = 0
g.numUpdates = 0
g.numNodes = 0
g.lastReset = time.Now()
g.mu.Unlock()
}
// String returns a human-readable description of the stats.
func (g *builderStats) String() string {
g.mu.RLock()
str := fmt.Sprintf("Processed channels=%d updates=%d nodes=%d in "+
"last %v", g.numChannels, g.numUpdates, g.numNodes,
time.Since(g.lastReset))
g.mu.RUnlock()
return str
}
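// Illustrative sketch (not part of the original code): a caller accumulates
// stats as it processes gossip messages, then periodically summarizes and
// resets them. The function name is hypothetical.
func exampleBuilderStatsUsage() string {
	var stats builderStats
	stats.Reset()
	stats.incNumEdgesDiscovered()
	stats.incNumChannelUpdates()
	stats.incNumNodeUpdates()
	if stats.Empty() {
		return ""
	}
	// e.g. "Processed channels=1 updates=1 nodes=1 in last 1ms".
	summary := stats.String()
	stats.Reset()
	return summary
}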
package htlcswitch
import (
"encoding/binary"
"io"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnwire"
)
// EmptyCircuitKey is a default value for an outgoing circuit key returned when
// a circuit's keystone has not been set. Note that this value is invalid for
// use as a keystone, since the outgoing channel id can never be equal to
// sourceHop.
var EmptyCircuitKey CircuitKey
// CircuitKey is a tuple of channel ID and HTLC ID, used to uniquely identify
// HTLCs in a circuit. Circuits are identified primarily by the circuit key of
// the incoming HTLC. However, a circuit may also be referenced by its outgoing
// circuit key after the HTLC has been forwarded via the outgoing link.
type CircuitKey = models.CircuitKey
// PaymentCircuit is used by the switch as placeholder between when the
// switch makes a forwarding decision and the outgoing link determines the
// proper HTLC ID for the local log. After the outgoing HTLC ID has been
// determined, the half circuit will be converted into a full PaymentCircuit.
type PaymentCircuit struct {
// AddRef is the forward reference of the Add update in the incoming
// link's forwarding package. This value is set on the htlcPacket of the
// returned settle/fail so that it can be removed from disk.
AddRef channeldb.AddRef
// Incoming is the circuit key identifying the incoming channel and htlc
// index from which this ADD originates.
Incoming CircuitKey
// Outgoing is the circuit key identifying the outgoing channel, and the
// HTLC index that was used to forward the ADD. It will be nil if this
// circuit's keystone has not been set.
Outgoing *CircuitKey
// PaymentHash used as unique identifier of payment.
PaymentHash [32]byte
// IncomingAmount is the value of the HTLC from the incoming link.
IncomingAmount lnwire.MilliSatoshi
// OutgoingAmount specifies the value of the HTLC leaving the switch,
// either as a payment or forwarded amount.
OutgoingAmount lnwire.MilliSatoshi
// ErrorEncrypter is used to re-encrypt the onion failure before
// sending it back to the originator of the payment.
ErrorEncrypter hop.ErrorEncrypter
// LoadedFromDisk is set true for any circuits loaded after the circuit
// map is reloaded from disk.
//
// NOTE: This value is determined implicitly during a restart. It is not
// persisted, and should never be set outside the circuit map.
LoadedFromDisk bool
}
// HasKeystone returns true if an outgoing link has assigned this circuit's
// outgoing circuit key.
func (c *PaymentCircuit) HasKeystone() bool {
return c.Outgoing != nil
}
// newPaymentCircuit initializes a payment circuit on the heap using the payment
// hash and an in-memory htlc packet.
func newPaymentCircuit(hash *[32]byte, pkt *htlcPacket) *PaymentCircuit {
var addRef channeldb.AddRef
if pkt.sourceRef != nil {
addRef = *pkt.sourceRef
}
return &PaymentCircuit{
AddRef: addRef,
Incoming: CircuitKey{
ChanID: pkt.incomingChanID,
HtlcID: pkt.incomingHTLCID,
},
PaymentHash: *hash,
IncomingAmount: pkt.incomingAmount,
OutgoingAmount: pkt.amount,
ErrorEncrypter: pkt.obfuscator,
}
}
// makePaymentCircuit initializes a payment circuit on the stack using the
// payment hash and an in-memory htlc packet.
func makePaymentCircuit(hash *[32]byte, pkt *htlcPacket) PaymentCircuit {
var addRef channeldb.AddRef
if pkt.sourceRef != nil {
addRef = *pkt.sourceRef
}
return PaymentCircuit{
AddRef: addRef,
Incoming: CircuitKey{
ChanID: pkt.incomingChanID,
HtlcID: pkt.incomingHTLCID,
},
PaymentHash: *hash,
IncomingAmount: pkt.incomingAmount,
OutgoingAmount: pkt.amount,
ErrorEncrypter: pkt.obfuscator,
}
}
// Encode writes a PaymentCircuit to the provided io.Writer.
func (c *PaymentCircuit) Encode(w io.Writer) error {
if err := c.AddRef.Encode(w); err != nil {
return err
}
if err := c.Incoming.Encode(w); err != nil {
return err
}
if _, err := w.Write(c.PaymentHash[:]); err != nil {
return err
}
var scratch [8]byte
binary.BigEndian.PutUint64(scratch[:], uint64(c.IncomingAmount))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
binary.BigEndian.PutUint64(scratch[:], uint64(c.OutgoingAmount))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
// Defaults to EncrypterTypeNone.
var encrypterType hop.EncrypterType
if c.ErrorEncrypter != nil {
encrypterType = c.ErrorEncrypter.Type()
}
err := binary.Write(w, binary.BigEndian, encrypterType)
if err != nil {
return err
}
// Skip encoding of error encrypter if this half add does not have one.
if encrypterType == hop.EncrypterTypeNone {
return nil
}
return c.ErrorEncrypter.Encode(w)
}
// Decode reads a PaymentCircuit from the provided io.Reader.
func (c *PaymentCircuit) Decode(r io.Reader) error {
if err := c.AddRef.Decode(r); err != nil {
return err
}
if err := c.Incoming.Decode(r); err != nil {
return err
}
if _, err := io.ReadFull(r, c.PaymentHash[:]); err != nil {
return err
}
var scratch [8]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return err
}
c.IncomingAmount = lnwire.MilliSatoshi(
binary.BigEndian.Uint64(scratch[:]))
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return err
}
c.OutgoingAmount = lnwire.MilliSatoshi(
binary.BigEndian.Uint64(scratch[:]))
// Read the encrypter type used for this circuit.
var encrypterType hop.EncrypterType
err := binary.Read(r, binary.BigEndian, &encrypterType)
if err != nil {
return err
}
switch encrypterType {
case hop.EncrypterTypeNone:
// No encrypter was provided, such as when the payment is
// locally initiated.
return nil
case hop.EncrypterTypeSphinx:
// Sphinx encrypter was used as this is a forwarded HTLC.
c.ErrorEncrypter = hop.NewSphinxErrorEncrypter()
case hop.EncrypterTypeMock:
// Test encrypter.
c.ErrorEncrypter = NewMockObfuscator()
case hop.EncrypterTypeIntroduction:
c.ErrorEncrypter = hop.NewIntroductionErrorEncrypter()
case hop.EncrypterTypeRelaying:
c.ErrorEncrypter = hop.NewRelayingErrorEncrypter()
default:
return UnknownEncrypterType(encrypterType)
}
return c.ErrorEncrypter.Decode(r)
}
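// Illustrative sketch (not part of the original code): since Encode and
// Decode are symmetric, a PaymentCircuit can be round-tripped through any
// io.ReadWriter, such as a bytes.Buffer. The function name is hypothetical.
func exampleCircuitRoundTrip(rw io.ReadWriter,
	c *PaymentCircuit) (*PaymentCircuit, error) {

	// Serialize the circuit, then immediately read it back.
	if err := c.Encode(rw); err != nil {
		return nil, err
	}
	decoded := &PaymentCircuit{}
	if err := decoded.Decode(rw); err != nil {
		return nil, err
	}
	return decoded, nil
}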
// InKey returns the primary identifier for the circuit corresponding to the
// incoming HTLC.
func (c *PaymentCircuit) InKey() CircuitKey {
return c.Incoming
}
// OutKey returns the keystone identifying the outgoing link and HTLC ID. If
// the circuit hasn't been completed, this method returns EmptyCircuitKey,
// which is an invalid outgoing circuit key. Only call this method if
// HasKeystone returns true.
func (c *PaymentCircuit) OutKey() CircuitKey {
if c.Outgoing != nil {
return *c.Outgoing
}
return EmptyCircuitKey
}
package htlcswitch
import (
"bytes"
"fmt"
"sync"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrCorruptedCircuitMap indicates that the on-disk bucketing structure
// has been altered since the circuit map instance was initialized.
ErrCorruptedCircuitMap = errors.New("circuit map has been corrupted")
// ErrCircuitNotInHashIndex indicates that a particular circuit did not
// appear in the in-memory hash index.
ErrCircuitNotInHashIndex = errors.New("payment circuit not found in " +
"hash index")
// ErrUnknownCircuit signals that circuit could not be removed from the
// map because it was not found.
ErrUnknownCircuit = errors.New("unknown payment circuit")
// ErrCircuitClosing signals that an htlc has already closed this
// circuit in-memory.
ErrCircuitClosing = errors.New("circuit has already been closed")
// ErrDuplicateCircuit signals that this circuit was previously
// added.
ErrDuplicateCircuit = errors.New("duplicate circuit add")
// ErrUnknownKeystone signals that no circuit was found using the
// outgoing circuit key.
ErrUnknownKeystone = errors.New("unknown circuit keystone")
// ErrDuplicateKeystone signals that this circuit was previously
// assigned a keystone.
ErrDuplicateKeystone = errors.New("cannot add duplicate keystone")
)
// CircuitModifier is a common interface used by channel links to modify the
// contents of the circuit map maintained by the switch.
type CircuitModifier interface {
// OpenCircuits preemptively records a batch of keystones that will mark
// currently pending circuits as open. These changes can be rolled back
// on restart if the outgoing Adds do not make it into a commitment
// txn.
OpenCircuits(...Keystone) error
// TrimOpenCircuits removes a channel's open circuits with htlc indexes
// at or above `start`.
TrimOpenCircuits(chanID lnwire.ShortChannelID, start uint64) error
// DeleteCircuits uses the incoming circuit keys to remove all
// persistent references to a circuit. Returns an ErrUnknownCircuit if
// any of the incoming keys are not known.
DeleteCircuits(inKeys ...CircuitKey) error
}
// CircuitLookup is a common interface used to lookup information that is stored
// in the circuit map.
type CircuitLookup interface {
// LookupCircuit queries the circuit map for the circuit identified by
// inKey.
LookupCircuit(inKey CircuitKey) *PaymentCircuit
// LookupOpenCircuit queries the circuit map for a circuit identified
// by its outgoing circuit key.
LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit
}
// CircuitFwdActions represents the forwarding decision made by the circuit
// map, and is returned from CommitCircuits. The sequence of circuits provided
// to CommitCircuits is split into three sub-sequences, allowing the caller to
// do an in-order scan, comparing the head of each subsequence, to determine
// the decision made by the circuit map.
type CircuitFwdActions struct {
// Adds is the subsequence of circuits that were successfully committed
// in the circuit map.
Adds []*PaymentCircuit
// Drops is the subsequence of circuits for which no action should be
// done.
Drops []*PaymentCircuit
// Fails is the subsequence of circuits that should be failed back by
// the calling link.
Fails []*PaymentCircuit
}
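// Illustrative sketch (not part of the original code): how a link might do
// the in-order scan over the three subsequences returned by CommitCircuits.
// The function name is hypothetical.
func exampleHandleFwdActions(actions *CircuitFwdActions) {
	for range actions.Adds {
		// Newly committed: forward the corresponding HTLC to the
		// switch.
	}
	for range actions.Drops {
		// Already pending: take no action, the packet is in flight.
	}
	for range actions.Fails {
		// Lost across a restart: fail the HTLC back on the incoming
		// link.
	}
}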
// CircuitMap is an interface for managing the construction and teardown of
// payment circuits used by the switch.
type CircuitMap interface {
CircuitModifier
CircuitLookup
// CommitCircuits attempts to add the given circuits to the circuit
// map. The list of circuits is split into three distinct
// sub-sequences, corresponding to adds, drops, and fails. Adds should
// be forwarded to the switch, while fails should be failed back
// locally within the calling link.
CommitCircuits(circuit ...*PaymentCircuit) (*CircuitFwdActions, error)
// CloseCircuit marks the circuit identified by `outKey` as closing
// in-memory, which prevents duplicate settles/fails from completing an
// open circuit twice.
CloseCircuit(outKey CircuitKey) (*PaymentCircuit, error)
// FailCircuit is used by locally failed HTLCs to mark the circuit
// identified by `inKey` as closing in-memory, which prevents duplicate
// settles/fails from being accepted for the same circuit.
FailCircuit(inKey CircuitKey) (*PaymentCircuit, error)
// LookupByPaymentHash queries the circuit map and returns all open
// circuits that use the given payment hash.
LookupByPaymentHash(hash [32]byte) []*PaymentCircuit
// NumPending returns the total number of active circuits added by
// CommitCircuits.
NumPending() int
// NumOpen returns the number of circuits with HTLCs that have been
// forwarded via an outgoing link.
NumOpen() int
}
var (
// circuitAddKey is the key used to retrieve the bucket containing
// payment circuits. A circuit records information about how to return
// a packet to the source link, potentially including an error
// encrypter for applying this hop's encryption to the payload in the
// reverse direction.
//
// Bucket hierarchy:
//
// circuitAddKey(root-bucket)
// |
// |-- <incoming-circuit-key>: <encoded bytes of PaymentCircuit>
// |-- <incoming-circuit-key>: <encoded bytes of PaymentCircuit>
// |
// ...
//
circuitAddKey = []byte("circuit-adds")
// circuitKeystoneKey is used to retrieve the bucket containing circuit
// keystones, which are set in place once a forwarded packet is
// assigned an index on an outgoing commitment txn.
//
// Bucket hierarchy:
//
// circuitKeystoneKey(root-bucket)
// |
// |-- <outgoing-circuit-key>: <incoming-circuit-key>
// |-- <outgoing-circuit-key>: <incoming-circuit-key>
// |
// ...
//
circuitKeystoneKey = []byte("circuit-keystones")
)
// circuitMap is a data structure that implements thread safe, persistent
// storage of circuit routing information. The switch consults a circuit map to
// determine where to forward returning HTLC update messages. Circuits are
// always identifiable by their incoming CircuitKey, in addition to their
// outgoing CircuitKey if the circuit is fully-opened.
type circuitMap struct {
cfg *CircuitMapConfig
mtx sync.RWMutex
// pending is an in-memory mapping of all half payment circuits, and is
// kept in sync with the on-disk contents of the circuit map.
pending map[CircuitKey]*PaymentCircuit
// opened is an in-memory mapping of all full payment circuits, which
// is also synchronized with the persistent state of the circuit map.
opened map[CircuitKey]*PaymentCircuit
// closed is an in-memory set of circuits for which the switch has
// received a settle or fail. This precedes the actual deletion of a
// circuit from disk.
closed map[CircuitKey]struct{}
// hashIndex is a volatile index that facilitates fast queries by
// payment hash against the contents of circuits. This index can be
// reconstructed entirely from the set of persisted full circuits on
// startup.
hashIndex map[[32]byte]map[CircuitKey]struct{}
}
// CircuitMapConfig houses the critical interfaces and references necessary to
// parameterize an instance of circuitMap.
type CircuitMapConfig struct {
// DB provides the persistent storage engine for the circuit map.
DB kvdb.Backend
// FetchAllOpenChannels is a function that fetches all currently open
// channels from the channel database.
FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error)
// FetchClosedChannels is a function that fetches all closed channels
// from the channel database.
FetchClosedChannels func(
pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error)
// ExtractErrorEncrypter derives the shared secret used to encrypt
// errors from the obfuscator's ephemeral public key.
ExtractErrorEncrypter hop.ErrorEncrypterExtracter
// CheckResolutionMsg checks whether a given resolution message exists
// for the passed CircuitKey.
CheckResolutionMsg func(outKey *CircuitKey) error
}
// NewCircuitMap creates a new instance of the circuitMap.
func NewCircuitMap(cfg *CircuitMapConfig) (CircuitMap, error) {
cm := &circuitMap{
cfg: cfg,
}
// Initialize the on-disk buckets used by the circuit map.
if err := cm.initBuckets(); err != nil {
return nil, err
}
// Delete old circuits and keystones of closed channels.
if err := cm.cleanClosedChannels(); err != nil {
return nil, err
}
// Load any previously persisted circuits back into memory.
if err := cm.restoreMemState(); err != nil {
return nil, err
}
// Trim any keystones that were not committed in an outgoing commit txn.
//
// NOTE: This operation will be applied to the persistent state of all
// active channels. Therefore, it must be called before any links are
// created to avoid interfering with normal operation.
if err := cm.trimAllOpenCircuits(); err != nil {
return nil, err
}
return cm, nil
}
// initBuckets ensures that the primary buckets used by the circuit map are
// initialized so that we can assume their existence after startup.
func (cm *circuitMap) initBuckets() error {
return kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(circuitKeystoneKey)
if err != nil {
return err
}
_, err = tx.CreateTopLevelBucket(circuitAddKey)
return err
}, func() {})
}
// cleanClosedChannels deletes all circuits and keystones related to closed
// channels. It first reads all the closed channels and caches the ShortChanIDs
// into a map for fast lookup. Then it iterates the circuit bucket and keystone
// bucket and deletes items whose ChanID matches the ShortChanID.
//
// NOTE: this operation could also be built into restoreMemState since the
// latter already opens and iterates the two root buckets, circuitAddKey and
// circuitKeystoneKey. Depending on the size of the buckets, this marginal
// gain may be worth investigating. At the moment, for clarity, this
// operation is wrapped into its own function.
func (cm *circuitMap) cleanClosedChannels() error {
log.Infof("Cleaning circuits from disk for closed channels")
// closedChanIDSet stores the short channel IDs for closed channels.
closedChanIDSet := make(map[lnwire.ShortChannelID]struct{})
// circuitKeySet stores the incoming circuit keys of the payment
// circuits that need to be deleted.
circuitKeySet := make(map[CircuitKey]struct{})
// keystoneKeySet stores the outgoing keys of the keystones that need
// to be deleted.
keystoneKeySet := make(map[CircuitKey]struct{})
// isClosedChannel is a helper closure that returns a bool indicating
// whether the chanID belongs to a closed channel.
isClosedChannel := func(chanID lnwire.ShortChannelID) bool {
// Skip if the channel ID is zero value. This has the effect
// that a zero value incoming or outgoing key will never be
// matched and its corresponding circuits or keystones are not
// deleted.
if chanID.ToUint64() == 0 {
return false
}
_, ok := closedChanIDSet[chanID]
return ok
}
// Find closed channels and cache their ShortChannelIDs into a map.
// This map will be used for looking up relative circuits and keystones.
closedChannels, err := cm.cfg.FetchClosedChannels(false)
if err != nil {
return err
}
for _, closedChannel := range closedChannels {
// Skip if the channel close is pending.
if closedChannel.IsPending {
continue
}
closedChanIDSet[closedChannel.ShortChanID] = struct{}{}
}
log.Debugf("Found %v closed channels", len(closedChanIDSet))
// Exit early if there are no closed channels.
if len(closedChanIDSet) == 0 {
log.Infof("Finished cleaning: no closed channels found, " +
"no actions taken.",
)
return nil
}
// Find the payment circuits and keystones that need to be deleted.
if err := kvdb.View(cm.cfg.DB, func(tx kvdb.RTx) error {
circuitBkt := tx.ReadBucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
keystoneBkt := tx.ReadBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
// If a circuit's incoming/outgoing key prefix matches the
// ShortChanID, it will be deleted. However, if the ShortChanID
// of the incoming key is zero, the circuit will be kept as it
// indicates a locally initiated payment.
if err := circuitBkt.ForEach(func(_, v []byte) error {
circuit, err := cm.decodeCircuit(v)
if err != nil {
return err
}
// Check if the incoming channel ID can be found in the
// closed channel ID map.
if !isClosedChannel(circuit.Incoming.ChanID) {
return nil
}
circuitKeySet[circuit.Incoming] = struct{}{}
return nil
}); err != nil {
return err
}
// If a keystone's InKey or OutKey matches the short channel id
// in the closed channel ID map, it will be deleted.
err := keystoneBkt.ForEach(func(k, v []byte) error {
var (
inKey CircuitKey
outKey CircuitKey
)
// Decode the incoming and outgoing circuit keys.
if err := inKey.SetBytes(v); err != nil {
return err
}
if err := outKey.SetBytes(k); err != nil {
return err
}
// Check if the incoming channel ID can be found in the
// closed channel ID map.
if isClosedChannel(inKey.ChanID) {
// If the incoming channel is closed, we can
// skip checking on outgoing channel ID because
// this keystone will be deleted.
keystoneKeySet[outKey] = struct{}{}
// Technically the incoming keys found in the
// keystone bucket should be a subset of those in
// the circuit bucket, so a previous loop should
// have already put this inKey into circuitKeySet.
// We add it again to be sure the circuits are
// properly cleaned. Even if this inKey doesn't
// exist in the circuit bucket, we are fine, as db
// deletion is a no-op.
circuitKeySet[inKey] = struct{}{}
return nil
}
// Check if the outgoing channel ID can be found in the
// closed channel ID map. Notice that we need to store
// the outgoing key because it's used for db query.
//
// NOTE: We skip this if a resolution message can be
// found under the outKey. This means there are
// existing resolution messages that need to reach
// the incoming links.
if isClosedChannel(outKey.ChanID) {
// Check the resolution message store. A return
// value of nil means we need to skip deleting
// these circuits.
if cm.cfg.CheckResolutionMsg(&outKey) == nil {
return nil
}
keystoneKeySet[outKey] = struct{}{}
// Also update circuitKeySet to mark the
// payment circuit needs to be deleted.
circuitKeySet[inKey] = struct{}{}
}
return nil
})
return err
}, func() {
// Reset the sets.
circuitKeySet = make(map[CircuitKey]struct{})
keystoneKeySet = make(map[CircuitKey]struct{})
}); err != nil {
return err
}
log.Debugf("To be deleted: num_circuits=%v, num_keystones=%v",
len(circuitKeySet), len(keystoneKeySet),
)
numCircuitsDeleted := 0
numKeystonesDeleted := 0
// Delete all the circuits and keystones for closed channels.
if err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
circuitBkt := tx.ReadWriteBucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
// Delete the circuit.
for inKey := range circuitKeySet {
if err := circuitBkt.Delete(inKey.Bytes()); err != nil {
return err
}
numCircuitsDeleted++
}
// Delete the keystone using the outgoing key.
for outKey := range keystoneKeySet {
err := keystoneBkt.Delete(outKey.Bytes())
if err != nil {
return err
}
numKeystonesDeleted++
}
return nil
}, func() {}); err != nil {
numCircuitsDeleted = 0
numKeystonesDeleted = 0
return err
}
log.Infof("Finished cleaning: num_closed_channel=%v, "+
"num_circuits=%v, num_keystone=%v",
len(closedChannels), numCircuitsDeleted, numKeystonesDeleted,
)
return nil
}
// restoreMemState loads the contents of the half circuit and full circuit
// buckets from disk and reconstructs the in-memory representation of the
// circuit map. Afterwards, the state of the hash index is reconstructed using
// the recovered set of full circuits. This method will also remove any stray
// keystones, which are those that appear fully-opened, but have no pending
// circuit related to the intended incoming link.
func (cm *circuitMap) restoreMemState() error {
log.Infof("Restoring in-memory circuit state from disk")
var (
opened map[CircuitKey]*PaymentCircuit
pending map[CircuitKey]*PaymentCircuit
)
if err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
// Restore any of the circuits persisted in the circuit bucket
// back into memory.
circuitBkt := tx.ReadWriteBucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
if err := circuitBkt.ForEach(func(_, v []byte) error {
circuit, err := cm.decodeCircuit(v)
if err != nil {
return err
}
circuit.LoadedFromDisk = true
pending[circuit.Incoming] = circuit
return nil
}); err != nil {
return err
}
// Furthermore, load the keystone bucket and resurrect the
// keystones used in any open circuits.
keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
var strayKeystones []Keystone
if err := keystoneBkt.ForEach(func(k, v []byte) error {
var (
inKey CircuitKey
outKey = &CircuitKey{}
)
// Decode the incoming and outgoing circuit keys.
if err := inKey.SetBytes(v); err != nil {
return err
}
if err := outKey.SetBytes(k); err != nil {
return err
}
// Retrieve the pending circuit, set its keystone, then
// add it to the opened map.
circuit, ok := pending[inKey]
if ok {
circuit.Outgoing = outKey
opened[*outKey] = circuit
} else {
strayKeystones = append(strayKeystones, Keystone{
InKey: inKey,
OutKey: *outKey,
})
}
return nil
}); err != nil {
return err
}
// If any stray keystones were found, we'll proceed to prune
// them from the circuit map's persistent storage. This may
// manifest on older nodes that had updated channels before
// their short channel id was set properly. We believe this
// issue has been fixed, though this will allow older nodes to
// recover without additional intervention.
for _, strayKeystone := range strayKeystones {
// As a precaution, we will only cleanup keystones
// related to locally-initiated payments. If a
// documented case of stray keystones emerges for
// forwarded payments, this check should be removed, but
// with extreme caution.
if strayKeystone.OutKey.ChanID != hop.Source {
continue
}
log.Infof("Removing stray keystone: %v", strayKeystone)
err := keystoneBkt.Delete(strayKeystone.OutKey.Bytes())
if err != nil {
return err
}
}
return nil
}, func() {
opened = make(map[CircuitKey]*PaymentCircuit)
pending = make(map[CircuitKey]*PaymentCircuit)
}); err != nil {
return err
}
cm.pending = pending
cm.opened = opened
cm.closed = make(map[CircuitKey]struct{})
log.Infof("Payment circuits loaded: num_pending=%v, num_open=%v",
len(pending), len(opened))
// Finally, reconstruct the hash index by running through our set of
// open circuits.
cm.hashIndex = make(map[[32]byte]map[CircuitKey]struct{})
for _, circuit := range opened {
cm.addCircuitToHashIndex(circuit)
}
return nil
}
// decodeCircuit reconstructs an in-memory payment circuit from a byte slice.
// The byte slice is assumed to have been generated by the circuit's Encode
// method. If the decoding is successful, the onion obfuscator will be
// reextracted, since it is not stored in plaintext on disk.
func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) {
var circuit = &PaymentCircuit{}
circuitReader := bytes.NewReader(v)
if err := circuit.Decode(circuitReader); err != nil {
return nil, err
}
// If the error encrypter is nil, this is a locally-sourced payment, so
// there is no encrypter.
if circuit.ErrorEncrypter == nil {
return circuit, nil
}
// Otherwise, we need to reextract the encrypter, so that the shared
// secret is rederived from what was decoded.
err := circuit.ErrorEncrypter.Reextract(
cm.cfg.ExtractErrorEncrypter,
)
if err != nil {
return nil, err
}
return circuit, nil
}
// trimAllOpenCircuits reads the set of active channels from disk and trims
// keystones for any non-pending channels using the next unallocated htlc index.
// This method is intended to be called on startup. Each link will also trim
// its own circuits upon startup.
//
// NOTE: This operation will be applied to the persistent state of all active
// channels. Therefore, it must be called before any links are created to avoid
// interfering with normal operation.
func (cm *circuitMap) trimAllOpenCircuits() error {
activeChannels, err := cm.cfg.FetchAllOpenChannels()
if err != nil {
return err
}
for _, activeChannel := range activeChannels {
if activeChannel.IsPending {
continue
}
// First, skip any channels that have not been assigned their
// final channel identifier, otherwise we would try to trim
// htlcs belonging to the all-zero, hop.Source ID.
chanID := activeChannel.ShortChanID()
if chanID == hop.Source {
continue
}
// Next, retrieve the next unallocated htlc index, which bounds
// the cutoff of confirmed htlc indexes.
start, err := activeChannel.NextLocalHtlcIndex()
if err != nil {
return err
}
// Finally, remove all pending circuits at or above the next
// unallocated local htlc index. This has the effect of
// reverting any circuits that have either not been locked in,
// or had not been included in a pending commitment.
err = cm.TrimOpenCircuits(chanID, start)
if err != nil {
return err
}
}
return nil
}
// TrimOpenCircuits removes a channel's keystones above the short chan id's
// highest committed htlc index. This has the effect of returning those
// circuits to a half-open state. Since opening of circuits is done in advance
// of actually committing the Add htlcs into a commitment txn, this allows
// circuits to be opened preemptively, since we can roll them back after any
// failures.
func (cm *circuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
start uint64) error {
log.Infof("Trimming open circuits for chan_id=%v, start_htlc_id=%v",
chanID, start)
var trimmedOutKeys []CircuitKey
// Scan forward from the last unacked htlc id, stopping as soon as we
// don't find any more. Outgoing htlc id's must be assigned in order,
// so there should never be disjoint segments of keystones to trim.
cm.mtx.Lock()
for i := start; ; i++ {
outKey := CircuitKey{
ChanID: chanID,
HtlcID: i,
}
circuit, ok := cm.opened[outKey]
if !ok {
break
}
circuit.Outgoing = nil
delete(cm.opened, outKey)
trimmedOutKeys = append(trimmedOutKeys, outKey)
cm.removeCircuitFromHashIndex(circuit)
}
cm.mtx.Unlock()
if len(trimmedOutKeys) == 0 {
return nil
}
return kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
for _, outKey := range trimmedOutKeys {
err := keystoneBkt.Delete(outKey.Bytes())
if err != nil {
return err
}
}
return nil
}, func() {})
}
// LookupCircuit queries the circuit map for the circuit identified by its
// incoming circuit key. Returns nil if there is no such circuit.
func (cm *circuitMap) LookupCircuit(inKey CircuitKey) *PaymentCircuit {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
return cm.pending[inKey]
}
// LookupOpenCircuit searches for the circuit identified by its outgoing circuit
// key.
func (cm *circuitMap) LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
return cm.opened[outKey]
}
// LookupByPaymentHash looks up and returns any payment circuits with a given
// payment hash.
func (cm *circuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
var circuits []*PaymentCircuit
if circuitSet, ok := cm.hashIndex[hash]; ok {
// Iterate over the outgoing circuit keys found with this hash,
// and retrieve the circuit from the opened map.
circuits = make([]*PaymentCircuit, 0, len(circuitSet))
for key := range circuitSet {
if circuit, ok := cm.opened[key]; ok {
circuits = append(circuits, circuit)
}
}
}
return circuits
}
// CommitCircuits accepts any number of circuits and persistently adds them to
// the switch's circuit map. The method returns a list of circuits that had
// not previously been seen by the switch. A link should only forward HTLCs
// corresponding to the returned circuits to the switch.
//
// NOTE: This method uses batched writes to improve performance; gains will
// only be realized if it is called concurrently from separate goroutines.
func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) (
*CircuitFwdActions, error) {
inKeys := make([]CircuitKey, 0, len(circuits))
for _, circuit := range circuits {
inKeys = append(inKeys, circuit.Incoming)
}
log.Tracef("Committing fresh circuits: %v", lnutils.SpewLogClosure(
inKeys))
actions := &CircuitFwdActions{}
// If an empty list was passed, return early to avoid grabbing the lock.
if len(circuits) == 0 {
return actions, nil
}
// First, we reconcile the provided circuits with our set of pending
// circuits to construct a set of new circuits that need to be written
// to disk. The circuit's pointer is stored so that we only permit this
// exact circuit to be forwarded through the switch. If a circuit is
// already pending, the htlc will be reforwarded by the switch.
//
// NOTE: We track an additional addFails subsequence, which permits us
// to fail back all packets that weren't dropped if we encounter an
// error when committing the circuits.
cm.mtx.Lock()
var adds, drops, fails, addFails []*PaymentCircuit
for _, circuit := range circuits {
inKey := circuit.InKey()
if foundCircuit, ok := cm.pending[inKey]; ok {
switch {
// This circuit has a keystone, it's waiting for a
// response from the remote peer on the outgoing link.
// Drop it like it's hot, ensure duplicates get caught.
case foundCircuit.HasKeystone():
drops = append(drops, circuit)
// If no keystone is set and the switch has not been
// restarted, the corresponding packet should still be
// in the outgoing link's mailbox. It will be delivered
// if it comes online before the switch goes down.
//
// NOTE: Dropping here prevents a flapping, incoming
// link from failing a duplicate add while it is still
// in the server's memory mailboxes.
case !foundCircuit.LoadedFromDisk:
drops = append(drops, circuit)
// Otherwise, the in-mem packet has been lost due to a
// restart. It is now safe to send back a failure along
// the incoming link. The incoming link should be able to
// detect and ignore duplicate packets of this type.
default:
fails = append(fails, circuit)
addFails = append(addFails, circuit)
}
continue
}
cm.pending[inKey] = circuit
adds = append(adds, circuit)
addFails = append(addFails, circuit)
}
cm.mtx.Unlock()
// If all circuits are dropped or failed, we are done.
if len(adds) == 0 {
actions.Drops = drops
actions.Fails = fails
return actions, nil
}
// Now, optimistically serialize the circuits to add.
var bs = make([]bytes.Buffer, len(adds))
for i, circuit := range adds {
if err := circuit.Encode(&bs[i]); err != nil {
actions.Drops = drops
actions.Fails = addFails
return actions, err
}
}
// Write the entire batch of circuits to the persistent circuit bucket
// using bolt's Batch write. This method must be called from multiple,
// distinct goroutines to have any impact on performance.
err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error {
circuitBkt := tx.ReadWriteBucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
for i, circuit := range adds {
inKeyBytes := circuit.InKey().Bytes()
circuitBytes := bs[i].Bytes()
err := circuitBkt.Put(inKeyBytes, circuitBytes)
if err != nil {
return err
}
}
return nil
})
// Return if the write succeeded.
if err == nil {
actions.Adds = adds
actions.Drops = drops
actions.Fails = fails
return actions, nil
}
// Otherwise, rollback the circuits added to the pending set if the
// write failed.
cm.mtx.Lock()
for _, circuit := range adds {
delete(cm.pending, circuit.InKey())
}
cm.mtx.Unlock()
// Since our write failed, we will return the dropped packets and mark
// all other circuits as failed.
actions.Drops = drops
actions.Fails = addFails
return actions, err
}
// Keystone is a tuple binding an incoming and outgoing CircuitKey. Keystones
// are preemptively written by an outgoing link before signing a new commitment
// state, and cements which HTLCs we are awaiting a response from a remote
// peer.
type Keystone struct {
InKey CircuitKey
OutKey CircuitKey
}
// String returns a human readable description of the Keystone.
func (k Keystone) String() string {
return fmt.Sprintf("%s --> %s", k.InKey, k.OutKey)
}
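// Illustrative sketch (not part of the original code): a keystone binds the
// incoming circuit key of an ADD to the outgoing slot it was forwarded on.
// The function name and the HTLC indexes are hypothetical.
func exampleKeystone(inChan, outChan lnwire.ShortChannelID) Keystone {
	return Keystone{
		InKey:  CircuitKey{ChanID: inChan, HtlcID: 3},
		OutKey: CircuitKey{ChanID: outChan, HtlcID: 7},
	}
}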
// OpenCircuits sets the outgoing circuit key for the circuit identified by
// inKey, persistently marking the circuit as opened. After the changes have
// been persisted, the circuit map's in-memory indexes are updated so that this
// circuit can be queried using LookupByKeystone or LookupByPaymentHash.
func (cm *circuitMap) OpenCircuits(keystones ...Keystone) error {
if len(keystones) == 0 {
return nil
}
log.Tracef("Opening finalized circuits: %v", lnutils.SpewLogClosure(
keystones))
// Check that all keystones correspond to committed-but-unopened
// circuits.
cm.mtx.RLock()
openedCircuits := make([]*PaymentCircuit, 0, len(keystones))
for _, ks := range keystones {
if _, ok := cm.opened[ks.OutKey]; ok {
cm.mtx.RUnlock()
return ErrDuplicateKeystone
}
circuit, ok := cm.pending[ks.InKey]
if !ok {
cm.mtx.RUnlock()
return ErrUnknownCircuit
}
openedCircuits = append(openedCircuits, circuit)
}
cm.mtx.RUnlock()
err := kvdb.Update(cm.cfg.DB, func(tx kvdb.RwTx) error {
// Now, load the circuit bucket to which we will write the
// already serialized circuit.
keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
for _, ks := range keystones {
outBytes := ks.OutKey.Bytes()
inBytes := ks.InKey.Bytes()
err := keystoneBkt.Put(outBytes, inBytes)
if err != nil {
return err
}
}
return nil
}, func() {})
if err != nil {
return err
}
cm.mtx.Lock()
for i, circuit := range openedCircuits {
ks := keystones[i]
// Since our persistent operation was successful, we can now
// modify the in memory representations. Set the outgoing
// circuit key on our pending circuit, add the same circuit to
// set of opened circuits, and add this circuit to the hash
// index.
circuit.Outgoing = &CircuitKey{}
*circuit.Outgoing = ks.OutKey
cm.opened[ks.OutKey] = circuit
cm.addCircuitToHashIndex(circuit)
}
cm.mtx.Unlock()
return nil
}
// addCircuitToHashIndex inserts a circuit into the circuit map's hash index, so
// that it can be queried using LookupByPaymentHash.
func (cm *circuitMap) addCircuitToHashIndex(c *PaymentCircuit) {
if _, ok := cm.hashIndex[c.PaymentHash]; !ok {
cm.hashIndex[c.PaymentHash] = make(map[CircuitKey]struct{})
}
cm.hashIndex[c.PaymentHash][c.OutKey()] = struct{}{}
}
// FailCircuit marks the circuit identified by `inKey` as closing in-memory,
// which prevents duplicate settles/fails from completing an open circuit twice.
func (cm *circuitMap) FailCircuit(inKey CircuitKey) (*PaymentCircuit, error) {
cm.mtx.Lock()
defer cm.mtx.Unlock()
circuit, ok := cm.pending[inKey]
if !ok {
return nil, ErrUnknownCircuit
}
_, ok = cm.closed[inKey]
if ok {
return nil, ErrCircuitClosing
}
cm.closed[inKey] = struct{}{}
return circuit, nil
}
// CloseCircuit marks the circuit identified by `outKey` as closing in-memory,
// which prevents duplicate settles/fails from completing an open
// circuit twice.
func (cm *circuitMap) CloseCircuit(outKey CircuitKey) (*PaymentCircuit, error) {
cm.mtx.Lock()
defer cm.mtx.Unlock()
circuit, ok := cm.opened[outKey]
if !ok {
return nil, ErrUnknownCircuit
}
_, ok = cm.closed[circuit.Incoming]
if ok {
return nil, ErrCircuitClosing
}
cm.closed[circuit.Incoming] = struct{}{}
return circuit, nil
}
// DeleteCircuits destroys the target circuits by removing them from the
// circuit map, additionally removing the circuits' keystones if any HTLCs
// were forwarded through an outgoing link. The circuits should be identified
// by their incoming circuit keys. If a given circuit is not found in the
// circuit map, it will be ignored. This would typically indicate that the
// circuit was already cleaned up at a different point in time.
func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
log.Tracef("Deleting resolved circuits: %v", lnutils.SpewLogClosure(
inKeys))
var (
closingCircuits = make(map[CircuitKey]struct{})
removedCircuits = make(map[CircuitKey]*PaymentCircuit)
)
cm.mtx.Lock()
// Remove any references to the circuits from memory, keeping track of
// which circuits were removed, and which ones had been marked closed.
// This can be used to restore these entries later if the persistent
// removal fails.
for _, inKey := range inKeys {
circuit, ok := cm.pending[inKey]
if !ok {
continue
}
delete(cm.pending, inKey)
if _, ok := cm.closed[inKey]; ok {
closingCircuits[inKey] = struct{}{}
delete(cm.closed, inKey)
}
if circuit.HasKeystone() {
delete(cm.opened, circuit.OutKey())
cm.removeCircuitFromHashIndex(circuit)
}
removedCircuits[inKey] = circuit
}
cm.mtx.Unlock()
err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error {
for _, circuit := range removedCircuits {
// If this htlc made it to an outgoing link, load the
// keystone bucket from which we will remove the
// outgoing circuit key.
if circuit.HasKeystone() {
keystoneBkt := tx.ReadWriteBucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
outKey := circuit.OutKey()
err := keystoneBkt.Delete(outKey.Bytes())
if err != nil {
return err
}
}
// Remove the circuit itself based on the incoming
// circuit key.
circuitBkt := tx.ReadWriteBucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
inKey := circuit.InKey()
if err := circuitBkt.Delete(inKey.Bytes()); err != nil {
return err
}
}
return nil
})
// Return if the write succeeded.
if err == nil {
return nil
}
// If the persistent changes failed, restore the circuit map to its
// previous state.
cm.mtx.Lock()
for inKey, circuit := range removedCircuits {
cm.pending[inKey] = circuit
if _, ok := closingCircuits[inKey]; ok {
cm.closed[inKey] = struct{}{}
}
if circuit.HasKeystone() {
cm.opened[circuit.OutKey()] = circuit
cm.addCircuitToHashIndex(circuit)
}
}
cm.mtx.Unlock()
return err
}
// removeCircuitFromHashIndex removes the given circuit from the hash index,
// pruning any unnecessary memory optimistically.
func (cm *circuitMap) removeCircuitFromHashIndex(c *PaymentCircuit) {
// Locate bucket containing this circuit's payment hashes.
circuitsWithHash, ok := cm.hashIndex[c.PaymentHash]
if !ok {
return
}
outKey := c.OutKey()
// Remove this circuit from the set of circuitsWithHash.
delete(circuitsWithHash, outKey)
// Prune the payment hash bucket if no other entries remain.
if len(circuitsWithHash) == 0 {
delete(cm.hashIndex, c.PaymentHash)
}
}
// NumPending returns the number of active circuits added to the circuit map.
func (cm *circuitMap) NumPending() int {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
return len(cm.pending)
}
// NumOpen returns the number of circuits that have been opened by way of
// setting their keystones. This is the number of HTLCs that are waiting for a
// settle/fail response from a remote peer.
func (cm *circuitMap) NumOpen() int {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
return len(cm.opened)
}
package htlcswitch
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"sync"
"sync/atomic"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/kvdb"
)
const (
// defaultDbDirectory is the default directory where our decayed log
// will store our (sharedHash, CLTV) key-value pairs.
defaultDbDirectory = "sharedhashes"
)
var (
// sharedHashBucket is a bucket which houses the first HashPrefixSize
// bytes of a received HTLC's hashed shared secret as the key and the HTLC's
// CLTV expiry as the value.
sharedHashBucket = []byte("shared-hash")
// batchReplayBucket is a bucket that maps batch identifiers to
// serialized ReplaySets. This is used to give idempotency in the event
// that a batch is processed more than once.
batchReplayBucket = []byte("batch-replay")
)
var (
// ErrDecayedLogInit is used to indicate a decayed log failed to create
// the proper bucketing structure on startup.
ErrDecayedLogInit = errors.New("unable to initialize decayed log")
// ErrDecayedLogCorrupted signals that the anticipated bucketing
// structure has diverged since initialization.
ErrDecayedLogCorrupted = errors.New("decayed log structure corrupted")
)
// NewBoltBackendCreator returns a function that creates a new bbolt backend for
// the decayed logs database.
func NewBoltBackendCreator(dbPath,
dbFileName string) func(boltCfg *kvdb.BoltConfig) (kvdb.Backend, error) {
return func(boltCfg *kvdb.BoltConfig) (kvdb.Backend, error) {
cfg := &kvdb.BoltBackendConfig{
DBPath: dbPath,
DBFileName: dbFileName,
NoFreelistSync: boltCfg.NoFreelistSync,
AutoCompact: boltCfg.AutoCompact,
AutoCompactMinAge: boltCfg.AutoCompactMinAge,
DBTimeout: boltCfg.DBTimeout,
}
// Use default path for log database.
if dbPath == "" {
cfg.DBPath = defaultDbDirectory
}
db, err := kvdb.GetBoltBackend(cfg)
if err != nil {
return nil, fmt.Errorf("could not open boltdb: %w", err)
}
return db, nil
}
}
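// Illustrative sketch (not part of the original code): obtaining a backend
// for the decayed log by invoking the creator with a bbolt config. The path
// and file name are hypothetical.
func exampleOpenDecayedLogDB() (kvdb.Backend, error) {
	create := NewBoltBackendCreator("/tmp/lnd-test", "sharedhashes.db")
	return create(&kvdb.BoltConfig{})
}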
// DecayedLog implements the PersistLog interface. It stores the first
// HashPrefixSize bytes of a sha256-hashed shared secret along with a node's
// CLTV value. It is a decaying log meaning there will be a garbage collector
// to collect entries which are expired according to their stored CLTV value
// and the current block height. DecayedLog wraps boltdb for simplicity and
// batches writes to the database to decrease write contention.
type DecayedLog struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.
db kvdb.Backend
notifier chainntnfs.ChainNotifier
wg sync.WaitGroup
quit chan struct{}
}
// NewDecayedLog creates a new DecayedLog, which caches recently seen hashed
// shared secrets. Entries are evicted as their cltv expires using block epochs
// from the given notifier.
func NewDecayedLog(db kvdb.Backend,
notifier chainntnfs.ChainNotifier) *DecayedLog {
return &DecayedLog{
db: db,
notifier: notifier,
quit: make(chan struct{}),
}
}
// Start opens the database we will be using to store hashed shared secrets.
// It also starts the garbage collector in a goroutine to remove stale
// database entries.
func (d *DecayedLog) Start() error {
if !atomic.CompareAndSwapInt32(&d.started, 0, 1) {
return nil
}
// Initialize the primary buckets used by the decayed log.
if err := d.initBuckets(); err != nil {
return err
}
// Start garbage collector.
if d.notifier != nil {
epochClient, err := d.notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return fmt.Errorf("unable to register for epoch "+
"notifications: %v", err)
}
d.wg.Add(1)
go d.garbageCollector(epochClient)
}
return nil
}
// initBuckets initializes the primary buckets used by the decayed log, namely
// the shared hash bucket and the batch replay bucket.
func (d *DecayedLog) initBuckets() error {
return kvdb.Update(d.db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(sharedHashBucket)
if err != nil {
return ErrDecayedLogInit
}
_, err = tx.CreateTopLevelBucket(batchReplayBucket)
if err != nil {
return ErrDecayedLogInit
}
return nil
}, func() {})
}
// Stop halts the garbage collector and closes boltdb.
func (d *DecayedLog) Stop() error {
log.Debugf("DecayedLog shutting down...")
defer log.Debugf("DecayedLog shutdown complete")
if !atomic.CompareAndSwapInt32(&d.stopped, 0, 1) {
return nil
}
// Stop garbage collector.
close(d.quit)
d.wg.Wait()
return nil
}
// garbageCollector deletes entries from sharedHashBucket whose expiry height
// has already passed. This function MUST be run as a goroutine.
func (d *DecayedLog) garbageCollector(epochClient *chainntnfs.BlockEpochEvent) {
defer d.wg.Done()
defer epochClient.Cancel()
for {
select {
case epoch, ok := <-epochClient.Epochs:
if !ok {
// Block epoch was canceled, shutting down.
log.Infof("Block epoch canceled, " +
"decaying hash log shutting down")
return
}
// Perform a bout of garbage collection using the
// epoch's block height.
height := uint32(epoch.Height)
numExpired, err := d.gcExpiredHashes(height)
if err != nil {
log.Errorf("unable to expire hashes at "+
"height=%d", height)
}
if numExpired > 0 {
log.Infof("Garbage collected %v shared "+
"secret hashes at height=%v",
numExpired, height)
}
case <-d.quit:
// Received shutdown request.
log.Infof("Decaying hash log received " +
"shutdown request")
return
}
}
}
// gcExpiredHashes purges the decaying log of all entries whose CLTV expires
// below the provided height.
func (d *DecayedLog) gcExpiredHashes(height uint32) (uint32, error) {
var numExpiredHashes uint32
err := kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
numExpiredHashes = 0
// Grab the shared hash bucket
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return fmt.Errorf("sharedHashBucket " +
"is nil")
}
var expiredCltv [][]byte
if err := sharedHashes.ForEach(func(k, v []byte) error {
// Deserialize the CLTV value for this entry.
cltv := binary.BigEndian.Uint32(v)
if cltv < height {
// This CLTV is expired. Collect its key so that
// we can delete every expired hash from the db
// once the iteration completes.
expiredCltv = append(expiredCltv, k)
numExpiredHashes++
}
return nil
}); err != nil {
return err
}
// Delete every item in the array. This must
// be done explicitly outside of the ForEach
// function for safety reasons.
for _, hash := range expiredCltv {
err := sharedHashes.Delete(hash)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return 0, err
}
return numExpiredHashes, nil
}
// Delete removes a <shared secret hash, CLTV> key-pair from the
// sharedHashBucket.
func (d *DecayedLog) Delete(hash *sphinx.HashPrefix) error {
return kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return ErrDecayedLogCorrupted
}
return sharedHashes.Delete(hash[:])
})
}
// Get retrieves the CLTV of a processed HTLC given the first 20 bytes of the
// SHA-256 hash of the shared secret.
func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) {
var value uint32
err := kvdb.View(d.db, func(tx kvdb.RTx) error {
// Grab the shared hash bucket which stores the mapping from
// truncated sha-256 hashes of shared secrets to CLTV's.
sharedHashes := tx.ReadBucket(sharedHashBucket)
if sharedHashes == nil {
return fmt.Errorf("sharedHashes is nil, could " +
"not retrieve CLTV value")
}
// Retrieve the bytes which represents the CLTV
valueBytes := sharedHashes.Get(hash[:])
if valueBytes == nil {
return sphinx.ErrLogEntryNotFound
}
// The first 4 bytes represent the CLTV, store it in value.
value = binary.BigEndian.Uint32(valueBytes)
return nil
}, func() {
value = 0
})
if err != nil {
return value, err
}
return value, nil
}
// Put stores a shared secret hash as the key and the CLTV as the value.
func (d *DecayedLog) Put(hash *sphinx.HashPrefix, cltv uint32) error {
// Optimistically serialize the cltv value into the scratch buffer.
var scratch [4]byte
binary.BigEndian.PutUint32(scratch[:], cltv)
return kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return ErrDecayedLogCorrupted
}
// Check to see if this hash prefix has been recorded before. If
// a value is found, this packet is being replayed.
valueBytes := sharedHashes.Get(hash[:])
if valueBytes != nil {
return sphinx.ErrReplayedPacket
}
return sharedHashes.Put(hash[:], scratch[:])
})
}
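// As a usage sketch (d is a started *DecayedLog; hashPrefix and cltv are
// hypothetical values), a second Put of the same hash prefix surfaces the
// replay detection above:
//
//	_ = d.Put(hashPrefix, cltv)        // first insert succeeds
//	err := d.Put(hashPrefix, cltv)     // err == sphinx.ErrReplayedPacket
//	storedCltv, _ := d.Get(hashPrefix) // returns the stored CLTV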
// PutBatch accepts a pending batch of hashed secret entries to write to disk.
// Each hashed secret is inserted with a corresponding time value, dictating
// when the entry will be evicted from the log.
// NOTE: This method enforces idempotency by writing the replay set obtained
// on the first attempt for a particular batch ID to disk, and returning the
// decoded value on subsequent calls. For the indices of the replay set to be
// aligned properly, the batch MUST be constructed identically to the first
// attempt; pruning the batch will cause the indices to become invalid.
func (d *DecayedLog) PutBatch(b *sphinx.Batch) (*sphinx.ReplaySet, error) {
// Since batched boltdb txns may be executed multiple times before
// succeeding, we will create a new replay set for each invocation to
// avoid any side-effects. If the txn is successful, this replay set
// will be merged with the replay set computed during batch construction
// to generate the complete replay set. If this batch was previously
// processed, the replay set will be deserialized from disk.
var replays *sphinx.ReplaySet
if err := kvdb.Batch(d.db, func(tx kvdb.RwTx) error {
sharedHashes := tx.ReadWriteBucket(sharedHashBucket)
if sharedHashes == nil {
return ErrDecayedLogCorrupted
}
// Load the batch replay bucket, which will be used to either
// retrieve the result of previously processing this batch, or
// to write the result of this operation.
batchReplayBkt := tx.ReadWriteBucket(batchReplayBucket)
if batchReplayBkt == nil {
return ErrDecayedLogCorrupted
}
// Check for the existence of this batch's id in the replay
// bucket. If a non-nil value is found, this indicates that we
// have already processed this batch before. We deserialize the
// resulting replay set and return it to ensure calls to
// PutBatch are idempotent.
replayBytes := batchReplayBkt.Get(b.ID)
if replayBytes != nil {
replays = sphinx.NewReplaySet()
return replays.Decode(bytes.NewReader(replayBytes))
}
// The CLTV will be stored into scratch and then stored into the
// sharedHashBucket.
var scratch [4]byte
replays = sphinx.NewReplaySet()
err := b.ForEach(func(seqNum uint16, hashPrefix *sphinx.HashPrefix, cltv uint32) error {
// Retrieve the bytes which represents the CLTV
valueBytes := sharedHashes.Get(hashPrefix[:])
if valueBytes != nil {
replays.Add(seqNum)
return nil
}
// Serialize the cltv value and write an entry keyed by
// the hash prefix.
binary.BigEndian.PutUint32(scratch[:], cltv)
return sharedHashes.Put(hashPrefix[:], scratch[:])
})
if err != nil {
return err
}
// Merge the replay set computed from checking the on-disk
// entries with the in-batch replays computed during this
// batch's construction.
replays.Merge(b.ReplaySet)
// Write the replay set under the batch identifier to the batch
// replays bucket. This can be used during recovery to test (1)
// that a particular batch was successfully processed and (2)
// recover the indexes of the adds that were rejected as
// replays.
var replayBuf bytes.Buffer
if err := replays.Encode(&replayBuf); err != nil {
return err
}
return batchReplayBkt.Put(b.ID, replayBuf.Bytes())
}); err != nil {
return nil, err
}
b.ReplaySet = replays
b.IsCommitted = true
return replays, nil
}
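// An illustrative re-invocation sketch (assuming batch is constructed
// identically both times, as the NOTE above requires):
//
//	rs1, _ := d.PutBatch(batch)
//	rs2, _ := d.PutBatch(batch)
//	// rs2 is decoded from the entry written under batch.ID by the first
//	// call, so rs1 and rs2 describe the same replay set.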
// A compile time check to see if DecayedLog adheres to the sphinx.ReplayLog
// interface.
var _ sphinx.ReplayLog = (*DecayedLog)(nil)
package htlcswitch
import (
"bytes"
"fmt"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnwire"
)
// ClearTextError is an interface which is implemented by errors that occur
// when we know the underlying wire failure message. These errors are the
// opposite to opaque errors which are onion-encrypted blobs only understandable
// to the initiating node. ClearTextErrors are used when we fail a htlc at our
// node, or one of our initiated payments failed and we can decrypt the onion
// encrypted error fully.
type ClearTextError interface {
error
// WireMessage extracts a valid wire failure message from an internal
// error which may contain additional metadata (which should not be
// exposed to the network). This value may be nil in the case where
// an unknown wire error is returned by one of our peers.
WireMessage() lnwire.FailureMessage
}
// LinkError is an implementation of the ClearTextError interface which
// represents failures that occur on our incoming or outgoing link.
type LinkError struct {
// msg returns the wire failure associated with the error.
// This value should *not* be nil, because we should always
// know the failure type for failures which occur at our own
// node.
msg lnwire.FailureMessage
// FailureDetail enriches the wire error with additional information.
FailureDetail
}
// NewLinkError returns a LinkError with the failure message provided.
// The failure message provided should *not* be nil, because we should
// always know the failure type for failures which occur at our own node.
func NewLinkError(msg lnwire.FailureMessage) *LinkError {
return &LinkError{msg: msg}
}
// NewDetailedLinkError returns a link error that enriches a wire message with
// a failure detail.
func NewDetailedLinkError(msg lnwire.FailureMessage,
detail FailureDetail) *LinkError {
return &LinkError{
msg: msg,
FailureDetail: detail,
}
}
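// For example, a link with insufficient outgoing bandwidth might construct (a
// sketch of one plausible call, not a prescription of call sites):
//
//	linkErr := NewDetailedLinkError(
//		lnwire.NewTemporaryChannelFailure(nil),
//		OutgoingFailureInsufficientBalance,
//	)
//	// linkErr.Error() == "insufficient bandwidth to route htlc"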
// WireMessage extracts a valid wire failure message from an internal
// error which may contain additional metadata (which should not be
// exposed to the network). This value should never be nil for LinkErrors,
// because we are the ones failing the htlc.
//
// Note this is part of the ClearTextError interface.
func (l *LinkError) WireMessage() lnwire.FailureMessage {
return l.msg
}
// Error returns the string representation of a link error.
//
// Note this is part of the ClearTextError interface.
func (l *LinkError) Error() string {
// If the link error has no failure detail, return the wire message's
// error.
if l.FailureDetail == nil {
return l.msg.Error()
}
return l.FailureDetail.FailureString()
}
// ForwardingError wraps an lnwire.FailureMessage in a struct that also
// includes the source of the error.
type ForwardingError struct {
// FailureSourceIdx is the index of the node that sent the failure. With
// this information, the dispatcher of a payment can modify their set of
// candidate routes in response to the type of failure extracted. Index
// zero is the self node.
FailureSourceIdx int
// msg is the wire message associated with the error. This value may
// be nil in the case where we fail to decode the failure message sent
// by a peer.
msg lnwire.FailureMessage
}
// WireMessage extracts a valid wire failure message from an internal
// error which may contain additional metadata (which should not be
// exposed to the network). This value may be nil in the case where
// an unknown wire error is returned by one of our peers.
//
// Note this is part of the ClearTextError interface.
func (f *ForwardingError) WireMessage() lnwire.FailureMessage {
return f.msg
}
// Error implements the built-in error interface. We use this method to allow
// the switch or any callers to insert additional context to the error message
// returned.
func (f *ForwardingError) Error() string {
return fmt.Sprintf(
"%v@%v", f.msg, f.FailureSourceIdx,
)
}
// NewForwardingError creates a new payment error which wraps a wire error
// with additional metadata.
func NewForwardingError(failure lnwire.FailureMessage,
index int) *ForwardingError {
return &ForwardingError{
FailureSourceIdx: index,
msg: failure,
}
}
// NewUnknownForwardingError returns a forwarding error which has a nil failure
// message. This constructor should only be used in the case where we cannot
// decode the failure we have received from a peer.
func NewUnknownForwardingError(index int) *ForwardingError {
return &ForwardingError{
FailureSourceIdx: index,
}
}
// ErrorDecrypter is an interface that is used to decrypt the onion encrypted
// failure reason and extract a well-formed error.
type ErrorDecrypter interface {
// DecryptError peels off each layer of onion encryption from the first
// hop, to the source of the error. A fully populated
// lnwire.FailureMessage is returned along with the source of the
// error.
DecryptError(lnwire.OpaqueReason) (*ForwardingError, error)
}
// UnknownEncrypterType is an error message used to signal that an unexpected
// EncrypterType was encountered during decoding.
type UnknownEncrypterType hop.EncrypterType
// Error returns a formatted error indicating the invalid EncrypterType.
func (e UnknownEncrypterType) Error() string {
return fmt.Sprintf("unknown error encrypter type: %d", e)
}
// OnionErrorDecrypter is the interface that provides onion level error
// decryption.
type OnionErrorDecrypter interface {
// DecryptError attempts to decrypt the passed encrypted error response.
// The onion failure is encrypted in a backward manner, starting from
// the node where the error occurred. As a result, in order to decrypt
// the error we need to get all the shared secrets and apply decryption
// in reverse order.
DecryptError(encryptedData []byte) (*sphinx.DecryptedError, error)
}
// SphinxErrorDecrypter wraps the sphinx OnionErrorDecrypter and maps the
// returned errors to concrete lnwire.FailureMessage instances.
type SphinxErrorDecrypter struct {
OnionErrorDecrypter
}
// DecryptError peels off each layer of onion encryption from the first hop, to
// the source of the error. A fully populated lnwire.FailureMessage is returned
// along with the source of the error.
//
// NOTE: Part of the ErrorDecrypter interface.
func (s *SphinxErrorDecrypter) DecryptError(reason lnwire.OpaqueReason) (
*ForwardingError, error) {
failure, err := s.OnionErrorDecrypter.DecryptError(reason)
if err != nil {
return nil, err
}
// Decode the failure. If an error occurs, we leave the failure message
// field nil.
r := bytes.NewReader(failure.Message)
failureMsg, err := lnwire.DecodeFailure(r, 0)
if err != nil {
return NewUnknownForwardingError(failure.SenderIdx), nil
}
return NewForwardingError(failureMsg, failure.SenderIdx), nil
}
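// A caller-side sketch (decrypter and reason are hypothetical placeholders):
//
//	fwdErr, err := decrypter.DecryptError(reason)
//	if err != nil {
//		return err // failure could not be attributed to any hop
//	}
//	// fwdErr.FailureSourceIdx identifies the failing hop (0 = self), and
//	// fwdErr.WireMessage() is nil if that hop's failure was undecodable.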
// A compile time check to ensure SphinxErrorDecrypter implements the
// ErrorDecrypter interface.
var _ ErrorDecrypter = (*SphinxErrorDecrypter)(nil)
package htlcswitch
// FailureDetail is an interface implemented by failures that occur on
// our incoming or outgoing link, or within the switch itself.
type FailureDetail interface {
// FailureString returns the string representation of a failure
// detail.
FailureString() string
}
// OutgoingFailure is an enum which is used to enrich failures which occur in
// the switch or on our outgoing link with additional metadata.
type OutgoingFailure int
const (
// OutgoingFailureNone is returned when the wire message contains
// sufficient information.
OutgoingFailureNone OutgoingFailure = iota
// OutgoingFailureDecodeError indicates that we could not decode the
// failure reason provided for a failed payment.
OutgoingFailureDecodeError
// OutgoingFailureLinkNotEligible indicates that a routing attempt was
// made over a link that is not eligible for routing.
OutgoingFailureLinkNotEligible
// OutgoingFailureOnChainTimeout indicates that a payment had to be
// timed out on chain before it got past the first hop by us or the
// remote party.
OutgoingFailureOnChainTimeout
// OutgoingFailureHTLCExceedsMax is returned when a htlc exceeds our
// policy's maximum htlc amount.
OutgoingFailureHTLCExceedsMax
// OutgoingFailureInsufficientBalance is returned when we cannot route a
// htlc due to insufficient outgoing capacity.
OutgoingFailureInsufficientBalance
// OutgoingFailureCircularRoute is returned when an attempt is made
// to forward a htlc through our node which arrives and leaves on the
// same channel.
OutgoingFailureCircularRoute
// OutgoingFailureIncompleteForward is returned when we cancel an incomplete
// forward.
OutgoingFailureIncompleteForward
// OutgoingFailureDownstreamHtlcAdd is returned when we fail to add a
// downstream htlc to our outgoing link.
OutgoingFailureDownstreamHtlcAdd
// OutgoingFailureForwardsDisabled is returned when the switch is
// configured to disallow forwards.
OutgoingFailureForwardsDisabled
)
// FailureString returns the string representation of a failure detail.
//
// Note: it is part of the FailureDetail interface.
func (fd OutgoingFailure) FailureString() string {
switch fd {
case OutgoingFailureNone:
return "no failure detail"
case OutgoingFailureDecodeError:
return "could not decode wire failure"
case OutgoingFailureLinkNotEligible:
return "link not eligible"
case OutgoingFailureOnChainTimeout:
return "payment was resolved on-chain, then canceled back"
case OutgoingFailureHTLCExceedsMax:
return "htlc exceeds maximum policy amount"
case OutgoingFailureInsufficientBalance:
return "insufficient bandwidth to route htlc"
case OutgoingFailureCircularRoute:
return "same incoming and outgoing channel"
case OutgoingFailureIncompleteForward:
return "failed after detecting incomplete forward"
case OutgoingFailureDownstreamHtlcAdd:
return "could not add downstream htlc"
case OutgoingFailureForwardsDisabled:
return "node configured to disallow forwards"
default:
return "unknown failure detail"
}
}
package htlcswitch
import (
"errors"
"fmt"
"github.com/lightningnetwork/lnd/graph/db/models"
)
// heldHtlcSet keeps track of outstanding intercepted forwards. It exposes
// several methods to manipulate the underlying map structure in a consistent
// way.
type heldHtlcSet struct {
set map[models.CircuitKey]InterceptedForward
}
func newHeldHtlcSet() *heldHtlcSet {
return &heldHtlcSet{
set: make(map[models.CircuitKey]InterceptedForward),
}
}
// forEach iterates over all held forwards and calls the given callback for each
// of them.
func (h *heldHtlcSet) forEach(cb func(InterceptedForward)) {
for _, fwd := range h.set {
cb(fwd)
}
}
// popAll calls the callback for each forward and removes them from the set.
func (h *heldHtlcSet) popAll(cb func(InterceptedForward)) {
for _, fwd := range h.set {
cb(fwd)
}
h.set = make(map[models.CircuitKey]InterceptedForward)
}
// popAutoFails calls the callback for each forward whose auto-fail height is
// equal to or less than the specified pop height and removes them from the
// set.
func (h *heldHtlcSet) popAutoFails(height uint32, cb func(InterceptedForward)) {
for key, fwd := range h.set {
if uint32(fwd.Packet().AutoFailHeight) > height {
continue
}
cb(fwd)
delete(h.set, key)
}
}
// pop returns the specified forward and removes it from the set.
func (h *heldHtlcSet) pop(key models.CircuitKey) (InterceptedForward, error) {
intercepted, ok := h.set[key]
if !ok {
return nil, fmt.Errorf("fwd %v not found", key)
}
delete(h.set, key)
return intercepted, nil
}
// exists tests whether the specified forward is part of the set.
func (h *heldHtlcSet) exists(key models.CircuitKey) bool {
_, ok := h.set[key]
return ok
}
// push adds the specified forward to the set. An error is returned if the
// forward exists already.
func (h *heldHtlcSet) push(key models.CircuitKey,
fwd InterceptedForward) error {
if fwd == nil {
return errors.New("nil fwd pushed")
}
if h.exists(key) {
return errors.New("htlc already exists in set")
}
h.set[key] = fwd
return nil
}
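// Taken together, the set supports the interceptor's typical flow (a sketch;
// key and fwd stand in for a real circuit key and intercepted forward):
//
//	held := newHeldHtlcSet()
//	if err := held.push(key, fwd); err != nil {
//		// duplicate or nil forward
//	}
//	// On each new block, expire forwards whose auto-fail height has been
//	// reached; f.Packet() identifies the HTLC to fail or resume.
//	held.popAutoFails(height, func(f InterceptedForward) {
//		_ = f.Packet()
//	})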
//go:build !dev
// +build !dev
package hodl
// Config is an empty struct disabling command line hodl flags in production.
type Config struct{}
// Mask in production always returns MaskNone.
func (c *Config) Mask() Mask {
return MaskNone
}
package hodl
import "fmt"
// MaskNone represents the empty Mask, in which no breakpoints are
// active.
const MaskNone = Mask(0)
type (
// Flag represents a single breakpoint where an HTLC should be dropped
// during forwarding. Flags can be composed into a Mask to express more
// complex combinations.
Flag uint32
// Mask is a bitvector combining multiple Flags that can be queried to
// see which breakpoints are active.
Mask uint32
)
const (
// ExitSettle drops an incoming ADD for which we are the exit node,
// before processing in the link.
ExitSettle Flag = 1 << iota
// AddIncoming drops an incoming ADD before processing if we are not
// the exit node.
AddIncoming
// SettleIncoming drops an incoming SETTLE before processing if we
// are not the exit node.
SettleIncoming
// FailIncoming drops an incoming FAIL before processing if we are
// not the exit node.
FailIncoming
// TODO(conner): add modes for switch breakpoints
// AddOutgoing drops an outgoing ADD before it is added to the
// in-memory commitment state of the link.
AddOutgoing
// SettleOutgoing drops an outgoing SETTLE before it is added to the
// in-memory commitment state of the link.
SettleOutgoing
// FailOutgoing drops an outgoing FAIL before it is added to the
// in-memory commitment state of the link.
FailOutgoing
// Commit drops all HTLCs after any outgoing circuits have been
// opened, but before the in-memory commitment state is persisted.
Commit
// BogusSettle attempts to settle back any incoming HTLC for which we
// are the exit node with a bogus preimage.
BogusSettle
)
// String returns a human-readable identifier for a given Flag.
func (f Flag) String() string {
switch f {
case ExitSettle:
return "ExitSettle"
case AddIncoming:
return "AddIncoming"
case SettleIncoming:
return "SettleIncoming"
case FailIncoming:
return "FailIncoming"
case AddOutgoing:
return "AddOutgoing"
case SettleOutgoing:
return "SettleOutgoing"
case FailOutgoing:
return "FailOutgoing"
case Commit:
return "Commit"
case BogusSettle:
return "BogusSettle"
default:
return "UnknownHodlFlag"
}
}
// Warning generates a warning message to log if a particular breakpoint is
// triggered during execution.
func (f Flag) Warning() string {
var msg string
switch f {
case ExitSettle:
msg = "will not attempt to settle ADD with sender"
case AddIncoming:
msg = "will not attempt to forward ADD to switch"
case SettleIncoming:
msg = "will not attempt to forward SETTLE to switch"
case FailIncoming:
msg = "will not attempt to forward FAIL to switch"
case AddOutgoing:
msg = "will not update channel state with downstream ADD"
case SettleOutgoing:
msg = "will not update channel state with downstream SETTLE"
case FailOutgoing:
msg = "will not update channel state with downstream FAIL"
case Commit:
msg = "will not commit pending channel updates"
case BogusSettle:
msg = "will settle HTLC with bogus preimage"
default:
msg = "incorrect hodl flag usage"
}
return fmt.Sprintf("%s mode enabled -- %s", f, msg)
}
// Mask returns the Mask consisting solely of this Flag.
func (f Flag) Mask() Mask {
return Mask(f)
}
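// In dev builds these flags compose into a Mask that the link consults at
// each breakpoint, e.g. (a sketch; the production stubs below compile all of
// this away to MaskNone):
//
//	mask := hodl.MaskFromFlags(hodl.ExitSettle, hodl.AddIncoming)
//	if mask.Active(hodl.ExitSettle) {
//		// log hodl.ExitSettle.Warning() and drop the ADD
//	}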
//go:build !dev
// +build !dev
package hodl
// MaskFromFlags in production always returns MaskNone.
func MaskFromFlags(_ ...Flag) Mask {
return MaskNone
}
// Active in production always returns false for all Flags.
func (m Mask) Active(_ Flag) bool {
return false
}
// String returns the human-readable identifier for MaskNone.
func (m Mask) String() string {
return "hodl.Mask(NONE)"
}
package hop
import (
"bytes"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec/v2"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
)
// EncrypterType establishes an enum used in serialization to indicate how to
// decode a concrete instance of the ErrorEncrypter interface.
type EncrypterType byte
const (
// EncrypterTypeNone signals that no error encrypter is present; this
// can happen if the htlc originates in the switch.
EncrypterTypeNone EncrypterType = 0
// EncrypterTypeSphinx is used to identify a sphinx onion error
// encrypter instance.
EncrypterTypeSphinx = 1
// EncrypterTypeMock is used to identify a mock obfuscator instance.
EncrypterTypeMock = 2
// EncrypterTypeIntroduction is used to identify a sphinx onion error
// encrypter where we are the introduction node in a blinded route. It
// has the same functionality as EncrypterTypeSphinx, but is used to
// mark our special-case error handling.
EncrypterTypeIntroduction = 3
// EncrypterTypeRelaying is used to identify a sphinx onion error
// encrypter where we are a relaying node in a blinded route. It has
// the same functionality as EncrypterTypeSphinx, but is used to mark
// our special-case error handling.
EncrypterTypeRelaying = 4
)
// IsBlinded returns a boolean indicating whether the error encrypter belongs
// to a blinded route.
func (e EncrypterType) IsBlinded() bool {
return e == EncrypterTypeIntroduction || e == EncrypterTypeRelaying
}
// ErrorEncrypterExtracter defines a function signature that extracts an
// ErrorEncrypter from the ephemeral public key of a sphinx OnionPacket.
type ErrorEncrypterExtracter func(*btcec.PublicKey) (ErrorEncrypter,
lnwire.FailCode)
// ErrorEncrypter is an interface that is used to encrypt HTLC related errors
// at the source of the error, and also at each intermediate hop all the way
// back to the source of the payment.
type ErrorEncrypter interface {
// EncryptFirstHop transforms a concrete failure message into an
// encrypted opaque failure reason. This method will be used at the
// source where the error occurs. It differs from IntermediateEncrypt
// slightly, in that it computes a proper MAC over the error.
EncryptFirstHop(lnwire.FailureMessage) (lnwire.OpaqueReason, error)
// EncryptMalformedError is similar to EncryptFirstHop (it adds the
// MAC), but it accepts an opaque failure reason rather than a failure
// message. This method is used when we receive an
// UpdateFailMalformedHTLC from the remote peer and then need to
// convert that into a proper error from only the raw bytes.
EncryptMalformedError(lnwire.OpaqueReason) lnwire.OpaqueReason
// IntermediateEncrypt wraps an already encrypted opaque reason error
// in an additional layer of onion encryption. This process repeats
// until the error arrives at the source of the payment.
IntermediateEncrypt(lnwire.OpaqueReason) lnwire.OpaqueReason
// Type returns an enum indicating the underlying concrete instance
// backing this interface.
Type() EncrypterType
// Encode serializes the encrypter's ephemeral public key to the given
// io.Writer.
Encode(io.Writer) error
// Decode deserializes the encrypter's ephemeral public key from the
// given io.Reader.
Decode(io.Reader) error
// Reextract rederives the encrypter using the extracter, performing an
// ECDH with the sphinx router's key and the ephemeral public key.
//
// NOTE: This should be called shortly after Decode to properly
// reinitialize the error encrypter.
Reextract(ErrorEncrypterExtracter) error
}
// SphinxErrorEncrypter is a concrete implementation of both the ErrorEncrypter
// interface backed by an implementation of the Sphinx packet format. As a
// result, all errors handled are themselves wrapped in layers of onion
// encryption and must be treated as such accordingly.
type SphinxErrorEncrypter struct {
*sphinx.OnionErrorEncrypter
EphemeralKey *btcec.PublicKey
}
// NewSphinxErrorEncrypter initializes a blank sphinx error encrypter that
// should be used to deserialize an encoded SphinxErrorEncrypter. Since the
// actual encrypter is not stored in plaintext while at rest, reconstructing the
// error encrypter requires:
// 1. Decode: to deserialize the ephemeral public key.
// 2. Reextract: to "unlock" the actual error encrypter using an active
// OnionProcessor.
func NewSphinxErrorEncrypter() *SphinxErrorEncrypter {
return &SphinxErrorEncrypter{
OnionErrorEncrypter: nil,
EphemeralKey: &btcec.PublicKey{},
}
}
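// A restore-from-disk sketch matching the two steps above (hedged: r is an
// io.Reader positioned at the serialized key, and extracter is obtained from
// an active OnionProcessor):
//
//	enc := NewSphinxErrorEncrypter()
//	if err := enc.Decode(r); err != nil {
//		return err // step 1: recover the ephemeral key
//	}
//	if err := enc.Reextract(extracter); err != nil {
//		return err // step 2: unlock the actual encrypter
//	}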
// EncryptFirstHop transforms a concrete failure message into an encrypted
// opaque failure reason. This method will be used at the source where the
// error occurs. It differs from IntermediateEncrypt slightly, in that it
// computes a proper MAC over the error.
//
// NOTE: Part of the ErrorEncrypter interface.
func (s *SphinxErrorEncrypter) EncryptFirstHop(
failure lnwire.FailureMessage) (lnwire.OpaqueReason, error) {
var b bytes.Buffer
if err := lnwire.EncodeFailure(&b, failure, 0); err != nil {
return nil, err
}
// We pass a true as the first parameter to indicate that a MAC should
// be added.
return s.EncryptError(true, b.Bytes()), nil
}
// EncryptMalformedError is similar to EncryptFirstHop (it adds the MAC), but
// it accepts an opaque failure reason rather than a failure message. This
// method is used when we receive an UpdateFailMalformedHTLC from the remote
// peer and then need to convert that into a proper error from only the raw
// bytes.
//
// NOTE: Part of the ErrorEncrypter interface.
func (s *SphinxErrorEncrypter) EncryptMalformedError(
reason lnwire.OpaqueReason) lnwire.OpaqueReason {
return s.EncryptError(true, reason)
}
// IntermediateEncrypt wraps an already encrypted opaque reason error in an
// additional layer of onion encryption. This process repeats until the error
// arrives at the source of the payment. We re-encrypt the message on the
// backwards path to ensure that the error is indistinguishable from any other
// error seen.
//
// NOTE: Part of the ErrorEncrypter interface.
func (s *SphinxErrorEncrypter) IntermediateEncrypt(
reason lnwire.OpaqueReason) lnwire.OpaqueReason {
return s.EncryptError(false, reason)
}
// Type returns the identifier for a sphinx error encrypter.
func (s *SphinxErrorEncrypter) Type() EncrypterType {
return EncrypterTypeSphinx
}
// Encode serializes the error encrypter's ephemeral public key to the provided
// io.Writer.
func (s *SphinxErrorEncrypter) Encode(w io.Writer) error {
ephemeral := s.EphemeralKey.SerializeCompressed()
_, err := w.Write(ephemeral)
return err
}
// Decode reconstructs the error encrypter's ephemeral public key from the
// provided io.Reader.
func (s *SphinxErrorEncrypter) Decode(r io.Reader) error {
var ephemeral [33]byte
if _, err := io.ReadFull(r, ephemeral[:]); err != nil {
return err
}
var err error
s.EphemeralKey, err = btcec.ParsePubKey(ephemeral[:])
if err != nil {
return err
}
return nil
}
// Reextract rederives the error encrypter from the currently held EphemeralKey.
// This is intended to be used shortly after Decode, to fully initialize a
// SphinxErrorEncrypter.
func (s *SphinxErrorEncrypter) Reextract(
extract ErrorEncrypterExtracter) error {
obfuscator, failcode := extract(s.EphemeralKey)
if failcode != lnwire.CodeNone {
// This should never happen, since we already validated that
// this obfuscator can be extracted when it was received in the
// link.
return fmt.Errorf("unable to reconstruct onion "+
"obfuscator, got failcode: %d", failcode)
}
sphinxEncrypter, ok := obfuscator.(*SphinxErrorEncrypter)
if !ok {
return fmt.Errorf("incorrect onion error extracter")
}
// Copy the freshly extracted encrypter.
s.OnionErrorEncrypter = sphinxEncrypter.OnionErrorEncrypter
return nil
}
// A compile time check to ensure SphinxErrorEncrypter implements the
// ErrorEncrypter interface.
var _ ErrorEncrypter = (*SphinxErrorEncrypter)(nil)
// A compile time check to ensure that IntroductionErrorEncrypter implements
// the ErrorEncrypter interface.
var _ ErrorEncrypter = (*IntroductionErrorEncrypter)(nil)
// IntroductionErrorEncrypter is a wrapper type on SphinxErrorEncrypter which
// is used to signal that we have special HTLC error handling for this hop.
type IntroductionErrorEncrypter struct {
// ErrorEncrypter is the underlying error encrypter, embedded
// directly in the struct so that we don't have to re-implement the
// ErrorEncrypter interface.
ErrorEncrypter
}
// NewIntroductionErrorEncrypter returns a blank IntroductionErrorEncrypter.
func NewIntroductionErrorEncrypter() *IntroductionErrorEncrypter {
return &IntroductionErrorEncrypter{
ErrorEncrypter: NewSphinxErrorEncrypter(),
}
}
// Type returns the identifier for an introduction error encrypter.
func (i *IntroductionErrorEncrypter) Type() EncrypterType {
return EncrypterTypeIntroduction
}
// Reextract rederives the error encrypter from the currently held EphemeralKey,
// relying on the logic in the underlying SphinxErrorEncrypter.
func (i *IntroductionErrorEncrypter) Reextract(
extract ErrorEncrypterExtracter) error {
return i.ErrorEncrypter.Reextract(extract)
}
// A compile time check to ensure that RelayingErrorEncrypter implements
// the ErrorEncrypter interface.
var _ ErrorEncrypter = (*RelayingErrorEncrypter)(nil)
// RelayingErrorEncrypter is a wrapper type on SphinxErrorEncrypter which
// is used to signal that we have special HTLC error handling for this hop.
type RelayingErrorEncrypter struct {
ErrorEncrypter
}
// NewRelayingErrorEncrypter returns a blank RelayingErrorEncrypter with
// an underlying SphinxErrorEncrypter.
func NewRelayingErrorEncrypter() *RelayingErrorEncrypter {
return &RelayingErrorEncrypter{
ErrorEncrypter: NewSphinxErrorEncrypter(),
}
}
// Type returns the identifier for a relaying error encrypter.
func (r *RelayingErrorEncrypter) Type() EncrypterType {
return EncrypterTypeRelaying
}
// Reextract rederives the error encrypter from the currently held EphemeralKey,
// relying on the logic in the underlying SphinxErrorEncrypter.
func (r *RelayingErrorEncrypter) Reextract(
extract ErrorEncrypterExtracter) error {
return r.ErrorEncrypter.Reextract(extract)
}
package hop
import (
"bytes"
"errors"
"fmt"
"io"
"sync"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// ErrDecodeFailed is returned when we can't decode blinded data.
ErrDecodeFailed = errors.New("could not decode blinded data")
// ErrNoBlindingPoint is returned when we have not provided a blinding
// point for a validated payload with encrypted data set.
ErrNoBlindingPoint = errors.New("no blinding point set for validated " +
"blinded hop")
)
// RouteRole represents the different types of roles a node can have as a
// recipient of an HTLC.
type RouteRole uint8
const (
// RouteRoleCleartext represents a regular route hop.
RouteRoleCleartext RouteRole = iota
// RouteRoleIntroduction represents an introduction node in a blinded
// path, characterized by a blinding point in the onion payload.
RouteRoleIntroduction
// RouteRoleRelaying represents a relaying node in a blinded path,
// characterized by a blinding point in update_add_htlc.
RouteRoleRelaying
)
// String returns a string representation of a role in a route.
func (h RouteRole) String() string {
switch h {
case RouteRoleCleartext:
return "cleartext"
case RouteRoleRelaying:
return "blinded relay"
case RouteRoleIntroduction:
return "introduction node"
default:
return fmt.Sprintf("unknown route role: %d", h)
}
}
// NewRouteRole returns the role we're playing in a route depending on the
// blinding points set (or not). If we are in the situation where we received
// blinding points in both the update add message and the payload:
// - We must have had a valid update add blinding point, because we were able
// to decrypt our onion to get the payload blinding point.
// - We return a relaying node role, because an introduction node (by
// definition) does not receive a blinding point in update add.
// - We assume the sending node to be buggy (including a payload blinding
// point where it shouldn't), and rely on validation elsewhere to handle
// this.
func NewRouteRole(updateAddBlinding, payloadBlinding bool) RouteRole {
switch {
case updateAddBlinding:
return RouteRoleRelaying
case payloadBlinding:
return RouteRoleIntroduction
default:
return RouteRoleCleartext
}
}
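// The mapping is effectively a small truth table:
//
//	NewRouteRole(false, false) == RouteRoleCleartext
//	NewRouteRole(false, true)  == RouteRoleIntroduction
//	NewRouteRole(true, false)  == RouteRoleRelaying
//	NewRouteRole(true, true)   == RouteRoleRelaying // buggy-sender case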
// Iterator is an interface that abstracts away the routing information
// included in HTLC's, which encodes the entirety of the payment path of an
// HTLC. This interface provides two basic methods: how to interpret the
// forwarding information encoded within the HTLC packet, and how to encode
// the forwarding information for the _next_ hop.
type Iterator interface {
// HopPayload returns the set of fields that detail exactly _how_ this
// hop should forward the HTLC to the next hop. Additionally, the
// information encoded within the returned ForwardingInfo is to be used
// by each hop to authenticate the information given to it by the prior
// hop. The payload will also contain any additional TLV fields provided
// by the sender. The role that this hop plays in the context of
// route blinding (regular, introduction or relaying) is returned
// whenever the payload is successfully parsed, even if we subsequently
// face a validation error.
HopPayload() (*Payload, RouteRole, error)
// EncodeNextHop encodes the onion packet destined for the next hop
// into the passed io.Writer.
EncodeNextHop(w io.Writer) error
// ExtractErrorEncrypter returns the ErrorEncrypter needed for this hop,
// along with a failure code to signal if the decoding was successful.
ExtractErrorEncrypter(extractor ErrorEncrypterExtracter,
introductionNode bool) (ErrorEncrypter, lnwire.FailCode)
}
// sphinxHopIterator is the Sphinx implementation of the hop iterator, which
// uses onion routing to encode the payment route in such a way that each node
// can see only the next hop in the route.
type sphinxHopIterator struct {
// ogPacket is the original packet from which the processed packet is
// derived.
ogPacket *sphinx.OnionPacket
// processedPacket is the outcome of processing an onion packet. It
// includes the information required to properly forward the packet to
// the next hop.
processedPacket *sphinx.ProcessedPacket
// blindingKit contains the elements required to process hops that are
// part of a blinded route.
blindingKit BlindingKit
// rHash holds the payment hash for this payment. This is needed for
// when a new hop iterator is constructed.
rHash []byte
// router holds the router which can be used to decrypt onion payloads.
// This is required for peeling of dummy hops in a blinded path where
// the same node will iteratively need to unwrap the onion.
router *sphinx.Router
}
// makeSphinxHopIterator converts a processed packet returned from a sphinx
// router into a hop iterator for usage in the link. A blinding kit is passed
// through for the link to obtain forwarding information for blinded routes.
func makeSphinxHopIterator(router *sphinx.Router, ogPacket *sphinx.OnionPacket,
packet *sphinx.ProcessedPacket, blindingKit BlindingKit,
rHash []byte) *sphinxHopIterator {
return &sphinxHopIterator{
router: router,
ogPacket: ogPacket,
processedPacket: packet,
blindingKit: blindingKit,
rHash: rHash,
}
}
// A compile time check to ensure sphinxHopIterator implements the Iterator
// interface.
var _ Iterator = (*sphinxHopIterator)(nil)
// EncodeNextHop encodes the onion packet destined for the next hop and writes
// it to the passed io.Writer.
//
// NOTE: Part of the Iterator interface.
func (r *sphinxHopIterator) EncodeNextHop(w io.Writer) error {
return r.processedPacket.NextPacket.Encode(w)
}
// HopPayload returns the set of fields that detail exactly _how_ this hop
// should forward the HTLC to the next hop. Additionally, the information
// encoded within the returned ForwardingInfo is to be used by each hop to
// authenticate the information given to it by the prior hop. The role that
// this hop plays in the context of route blinding (regular, introduction or
// relaying) is returned whenever the payload is successfully parsed, even if
// we subsequently face a validation error. The payload will also contain any
// additional TLV fields provided by the sender.
//
// NOTE: Part of the Iterator interface.
func (r *sphinxHopIterator) HopPayload() (*Payload, RouteRole, error) {
switch r.processedPacket.Payload.Type {
// If this is the legacy payload, then we'll extract the information
// directly from the pre-populated ForwardingInstructions field.
case sphinx.PayloadLegacy:
fwdInst := r.processedPacket.ForwardingInstructions
return NewLegacyPayload(fwdInst), RouteRoleCleartext, nil
// Otherwise, if this is the TLV payload, then we'll make a new stream
// to decode only what we need to make routing decisions.
case sphinx.PayloadTLV:
return extractTLVPayload(r)
default:
return nil, RouteRoleCleartext,
fmt.Errorf("unknown sphinx payload type: %v",
r.processedPacket.Payload.Type)
}
}
// extractTLVPayload parses the hop payload and assumes that it uses the TLV
// format. It returns the parsed payload along with the RouteRole that this hop
// plays given the contents of the payload.
func extractTLVPayload(r *sphinxHopIterator) (*Payload, RouteRole, error) {
isFinal := r.processedPacket.Action == sphinx.ExitNode
// Initial payload parsing and validation.
payload, routeRole, recipientData, err := parseAndValidateSenderPayload(
r.processedPacket.Payload.Payload, isFinal,
r.blindingKit.UpdateAddBlinding.IsSome(),
)
if err != nil {
return nil, routeRole, err
}
// If the payload contained no recipient data, then we can exit now.
if !recipientData {
return payload, routeRole, nil
}
return parseAndValidateRecipientData(r, payload, isFinal, routeRole)
}
// parseAndValidateRecipientData decrypts the payload from the recipient and
// then continues handling and validation based on if we are a forwarding node
// in this blinded path or the final destination node.
func parseAndValidateRecipientData(r *sphinxHopIterator, payload *Payload,
isFinal bool, routeRole RouteRole) (*Payload, RouteRole, error) {
// Decrypt and validate the blinded route data.
routeData, blindingPoint, err := decryptAndValidateBlindedRouteData(
r, payload,
)
if err != nil {
return nil, routeRole, err
}
// This is the final node in the blinded route.
if isFinal {
return deriveBlindedRouteFinalHopForwardingInfo(
routeData, payload, routeRole,
)
}
// Else, we are a forwarding node in this blinded path.
return deriveBlindedRouteForwardingInfo(
r, routeData, payload, routeRole, blindingPoint,
)
}
// deriveBlindedRouteFinalHopForwardingInfo extracts the PathID from the
// routeData and constructs the ForwardingInfo accordingly.
func deriveBlindedRouteFinalHopForwardingInfo(
routeData *record.BlindedRouteData, payload *Payload,
routeRole RouteRole) (*Payload, RouteRole, error) {
var pathID *chainhash.Hash
routeData.PathID.WhenSome(func(r tlv.RecordT[tlv.TlvType6, []byte]) {
var id chainhash.Hash
copy(id[:], r.Val)
pathID = &id
})
if pathID == nil {
return nil, routeRole, ErrInvalidPayload{
Type: tlv.Type(6),
Violation: InsufficientViolation,
}
}
payload.FwdInfo = ForwardingInfo{
PathID: pathID,
}
return payload, routeRole, nil
}
// deriveBlindedRouteForwardingInfo uses the parsed BlindedRouteData from the
// recipient to derive the ForwardingInfo for the payment.
func deriveBlindedRouteForwardingInfo(r *sphinxHopIterator,
routeData *record.BlindedRouteData, payload *Payload,
routeRole RouteRole, blindingPoint *btcec.PublicKey) (*Payload,
RouteRole, error) {
relayInfo, err := routeData.RelayInfo.UnwrapOrErr(
fmt.Errorf("relay info not set for non-final blinded hop"),
)
if err != nil {
return nil, routeRole, err
}
fwdAmt, err := calculateForwardingAmount(
r.blindingKit.IncomingAmount, relayInfo.Val.BaseFee,
relayInfo.Val.FeeRate,
)
if err != nil {
return nil, routeRole, err
}
nextEph, err := routeData.NextBlindingOverride.UnwrapOrFuncErr(
func() (tlv.RecordT[tlv.TlvType8, *btcec.PublicKey], error) {
next, err := r.blindingKit.Processor.NextEphemeral(
blindingPoint,
)
if err != nil {
return routeData.NextBlindingOverride.Zero(),
err
}
return tlv.NewPrimitiveRecord[tlv.TlvType8](next), nil
})
if err != nil {
return nil, routeRole, err
}
// If the payload signals that the following hop is a dummy hop, then
// we will iteratively peel the dummy hop until we reach the final
// payload.
if checkForDummyHop(routeData, r.router.OnionPublicKey()) {
return peelBlindedPathDummyHop(
r, uint32(relayInfo.Val.CltvExpiryDelta), fwdAmt,
routeRole, nextEph,
)
}
nextSCID, err := routeData.ShortChannelID.UnwrapOrErr(
fmt.Errorf("next SCID not set for non-final blinded hop"),
)
if err != nil {
return nil, routeRole, err
}
payload.FwdInfo = ForwardingInfo{
NextHop: nextSCID.Val,
AmountToForward: fwdAmt,
OutgoingCTLV: r.blindingKit.IncomingCltv - uint32(
relayInfo.Val.CltvExpiryDelta,
),
// Remap from blinding override type to blinding point type.
NextBlinding: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
nextEph.Val,
),
),
}
return payload, routeRole, nil
}
// checkForDummyHop returns whether the given BlindedRouteData packet indicates
// the presence of a dummy hop.
func checkForDummyHop(routeData *record.BlindedRouteData,
routerPubKey *btcec.PublicKey) bool {
var isDummy bool
routeData.NextNodeID.WhenSome(
func(r tlv.RecordT[tlv.TlvType4, *btcec.PublicKey]) {
isDummy = r.Val.IsEqual(routerPubKey)
},
)
return isDummy
}
// peelBlindedPathDummyHop packages the next onion packet, constructs a new
// hop iterator using our router, and then proceeds to process the next
// packet. This can only be done for blinded route dummy hops since we expect
// to be the final hop on the path.
func peelBlindedPathDummyHop(r *sphinxHopIterator, cltvExpiryDelta uint32,
fwdAmt lnwire.MilliSatoshi, routeRole RouteRole,
nextEph tlv.RecordT[tlv.TlvType8, *btcec.PublicKey]) (*Payload,
RouteRole, error) {
onionPkt := r.processedPacket.NextPacket
sphinxPacket, err := r.router.ReconstructOnionPacket(
onionPkt, r.rHash, sphinx.WithBlindingPoint(nextEph.Val),
)
if err != nil {
return nil, routeRole, err
}
iterator := makeSphinxHopIterator(
r.router, onionPkt, sphinxPacket, BlindingKit{
Processor: r.router,
UpdateAddBlinding: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( //nolint:ll
nextEph.Val,
),
),
IncomingAmount: fwdAmt,
IncomingCltv: r.blindingKit.IncomingCltv -
cltvExpiryDelta,
}, r.rHash,
)
return extractTLVPayload(iterator)
}
// decryptAndValidateBlindedRouteData decrypts the encrypted payload from the
// payment recipient using a blinding key. The incoming HTLC amount and CLTV
// values are then verified against the policy values from the recipient.
func decryptAndValidateBlindedRouteData(r *sphinxHopIterator,
payload *Payload) (*record.BlindedRouteData, *btcec.PublicKey, error) {
blindingPoint, err := r.blindingKit.getBlindingPoint(
payload.blindingPoint,
)
if err != nil {
return nil, nil, err
}
decrypted, err := r.blindingKit.Processor.DecryptBlindedHopData(
blindingPoint, payload.encryptedData,
)
if err != nil {
return nil, nil, fmt.Errorf("decrypt blinded data: %w", err)
}
buf := bytes.NewBuffer(decrypted)
routeData, err := record.DecodeBlindedRouteData(buf)
if err != nil {
return nil, nil, fmt.Errorf("%w: %w", ErrDecodeFailed, err)
}
err = ValidateBlindedRouteData(
routeData, r.blindingKit.IncomingAmount,
r.blindingKit.IncomingCltv,
)
if err != nil {
return nil, nil, err
}
return routeData, blindingPoint, nil
}
// parseAndValidateSenderPayload parses the payload bytes received from the
// onion constructor (the sender) and validates that various fields have been
// set. It also uses the presence of a blinding key in either the
// update_add_htlc message or in the payload to determine the RouteRole.
// The RouteRole is returned even if an error is returned. The boolean return
// value indicates that the sender payload includes encrypted data from the
// recipient that should be parsed.
func parseAndValidateSenderPayload(payloadBytes []byte, isFinalHop,
updateAddBlindingSet bool) (*Payload, RouteRole, bool, error) {
// Extract TLVs from the packet constructor (the sender).
payload, parsed, err := ParseTLVPayload(bytes.NewReader(payloadBytes))
if err != nil {
// If we couldn't even parse our payload then we make a
// best-effort attempt at determining our role in a blinded
// route, accepting that we can't know whether we were the
// introduction node (as the payload is not parseable).
routeRole := RouteRoleCleartext
if updateAddBlindingSet {
routeRole = RouteRoleRelaying
}
return nil, routeRole, false, err
}
// Now that we've parsed our payload we can determine which role we're
// playing in the route.
_, payloadBlinding := parsed[record.BlindingPointOnionType]
routeRole := NewRouteRole(updateAddBlindingSet, payloadBlinding)
// Validate the presence of the various payload fields we received from
// the sender.
err = ValidateTLVPayload(parsed, isFinalHop, updateAddBlindingSet)
if err != nil {
return nil, routeRole, false, err
}
// If there is no encrypted data from the receiver then return the
// payload as is since the forwarding info would have been received
// from the sender.
if payload.encryptedData == nil {
return payload, routeRole, false, nil
}
// Validate the presence of various fields in the sender payload given
// that we now know that this is a hop with instructions from the
// recipient.
err = ValidatePayloadWithBlinded(isFinalHop, parsed)
if err != nil {
return payload, routeRole, true, err
}
return payload, routeRole, true, nil
}
// ExtractErrorEncrypter decodes and returns the ErrorEncrypter for this hop,
// along with a failure code to signal if the decoding was successful. The
// ErrorEncrypter is used to encrypt errors back to the sender in the event that
// a payment fails.
//
// NOTE: Part of the Iterator interface.
func (r *sphinxHopIterator) ExtractErrorEncrypter(
extracter ErrorEncrypterExtracter, introductionNode bool) (
ErrorEncrypter, lnwire.FailCode) {
encrypter, errCode := extracter(r.ogPacket.EphemeralKey)
if errCode != lnwire.CodeNone {
return nil, errCode
}
// If we're in a blinded path, wrap the error encrypter that we just
// derived in a "marker" type which we'll use to know what type of
// error we're handling.
switch {
case introductionNode:
return &IntroductionErrorEncrypter{
ErrorEncrypter: encrypter,
}, errCode
case r.blindingKit.UpdateAddBlinding.IsSome():
return &RelayingErrorEncrypter{
ErrorEncrypter: encrypter,
}, errCode
default:
return encrypter, errCode
}
}
// BlindingProcessor is an interface that provides the cryptographic operations
// required for processing blinded hops.
//
// This interface is extracted to allow more granular testing of blinded
// forwarding calculations.
type BlindingProcessor interface {
// DecryptBlindedHopData decrypts a blinded blob of data using the
// ephemeral key provided.
DecryptBlindedHopData(ephemPub *btcec.PublicKey,
encryptedData []byte) ([]byte, error)
// NextEphemeral returns the next hop's ephemeral key, calculated
// from the current ephemeral key provided.
NextEphemeral(*btcec.PublicKey) (*btcec.PublicKey, error)
}
// BlindingKit contains the components required to extract forwarding
// information for hops in a blinded route.
type BlindingKit struct {
// Processor provides the low-level cryptographic operations to
// handle an encrypted blob of data in a blinded forward.
Processor BlindingProcessor
// UpdateAddBlinding holds a blinding point that was passed to the
// node via update_add_htlc's TLVs.
UpdateAddBlinding lnwire.BlindingPointRecord
// IncomingCltv is the expiry of the incoming HTLC.
IncomingCltv uint32
// IncomingAmount is the amount of the incoming HTLC.
IncomingAmount lnwire.MilliSatoshi
}
// getBlindingPoint returns either the payload or update_add_htlc blinding
// point, assuming that validation of these values has already been handled
// elsewhere.
func (b *BlindingKit) getBlindingPoint(payloadBlinding *btcec.PublicKey) (
*btcec.PublicKey, error) {
payloadBlindingSet := payloadBlinding != nil
updateBlindingSet := b.UpdateAddBlinding.IsSome()
switch {
case payloadBlindingSet:
return payloadBlinding, nil
case updateBlindingSet:
pk, err := b.UpdateAddBlinding.UnwrapOrErr(
fmt.Errorf("expected update add blinding"),
)
if err != nil {
return nil, err
}
return pk.Val, nil
default:
return nil, ErrNoBlindingPoint
}
}
// calculateForwardingAmount calculates the amount to forward for a blinded
// hop based on the incoming amount and forwarding parameters.
//
// When forwarding a payment, the fee we take is calculated, not on the
// incoming amount, but rather on the amount we forward. We charge fees based
// on our own liquidity we are forwarding downstream.
//
// With route blinding, we are NOT given the amount to forward. This
// unintuitive looking formula comes from the fact that without the amount to
// forward, we cannot compute the fees taken directly.
//
// The amount to be forwarded can be computed as follows:
//
// amt_to_forward = incoming_amount - total_fees
// total_fees = base_fee + amt_to_forward*(fee_rate/1000000)
//
// Solving for amt_to_forward:
// amt_to_forward = incoming_amount - base_fee - (amt_to_forward * fee_rate)/1e6
// amt_to_forward + (amt_to_forward * fee_rate) / 1e6 = incoming_amount - base_fee
// amt_to_forward * 1e6 + (amt_to_forward * fee_rate) = (incoming_amount - base_fee) * 1e6
// amt_to_forward * (1e6 + fee_rate) = (incoming_amount - base_fee) * 1e6
// amt_to_forward = ((incoming_amount - base_fee) * 1e6) / (1e6 + fee_rate)
//
// From there we use a ceiling formula for integer division so that we always
// round up, otherwise the sender may receive slightly less than intended:
//
// ceil(a/b) = (a + b - 1)/(b).
//
//nolint:ll,dupword
func calculateForwardingAmount(incomingAmount, baseFee lnwire.MilliSatoshi,
proportionalFee uint32) (lnwire.MilliSatoshi, error) {
// Sanity check to prevent overflow.
if incomingAmount < baseFee {
return 0, fmt.Errorf("incoming amount: %v < base fee: %v",
incomingAmount, baseFee)
}
numerator := (uint64(incomingAmount) - uint64(baseFee)) * 1e6
denominator := 1e6 + uint64(proportionalFee)
ceiling := (numerator + denominator - 1) / denominator
return lnwire.MilliSatoshi(ceiling), nil
}
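// A worked example with hypothetical values: an incoming HTLC of 1_000_000
// msat through a hop charging base_fee=1_000 msat and fee_rate=100 ppm:
//
//	fwdAmt, _ := calculateForwardingAmount(1_000_000, 1_000, 100)
//	// numerator   = (1_000_000 - 1_000) * 1e6 = 999_000_000_000
//	// denominator = 1e6 + 100                 = 1_000_100
//	// fwdAmt      = ceil(999_000_000_000 / 1_000_100) = 998_901 msat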
// OnionProcessor is responsible for keeping all sphinx-dependent parts
// contained and exposing only a decoding function. With such an approach,
// subsystems that want to decode a sphinx path do not need to depend on
// sphinx at all.
//
// NOTE: The reason for keeping the decoder separate from the hop iterator is
// to maintain the hop iterator abstraction. Without it, the structures using
// the hop iterator would have to contain a sphinx router, which would make
// their creation in tests dependent on the sphinx internal parts.
type OnionProcessor struct {
router *sphinx.Router
}
// NewOnionProcessor creates a new instance of the decoder.
func NewOnionProcessor(router *sphinx.Router) *OnionProcessor {
return &OnionProcessor{router}
}
// Start spins up the onion processor's sphinx router.
func (p *OnionProcessor) Start() error {
log.Info("Onion processor starting")
return p.router.Start()
}
// Stop shuts down the onion processor's sphinx router.
func (p *OnionProcessor) Stop() error {
log.Info("Onion processor shutting down...")
defer log.Debug("Onion processor shutdown complete")
p.router.Stop()
return nil
}
// ReconstructBlindingInfo contains the information required to reconstruct a
// blinded onion.
type ReconstructBlindingInfo struct {
// BlindingKey is the blinding point set in UpdateAddHTLC.
BlindingKey lnwire.BlindingPointRecord
// IncomingAmt is the amount for the incoming HTLC.
IncomingAmt lnwire.MilliSatoshi
// IncomingExpiry is the expiry height of the incoming HTLC.
IncomingExpiry uint32
}
// ReconstructHopIterator attempts to decode a valid sphinx packet from the
// passed io.Reader instance using the rHash as the associated data when
// checking the relevant MACs during the decoding process.
func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte,
blindingInfo ReconstructBlindingInfo) (Iterator, error) {
onionPkt := &sphinx.OnionPacket{}
if err := onionPkt.Decode(r); err != nil {
return nil, err
}
var opts []sphinx.ProcessOnionOpt
blindingInfo.BlindingKey.WhenSome(func(
r tlv.RecordT[lnwire.BlindingPointTlvType, *btcec.PublicKey]) {
opts = append(opts, sphinx.WithBlindingPoint(r.Val))
})
// Attempt to process the Sphinx packet. We include the payment hash of
// the HTLC as it's authenticated within the Sphinx packet itself as
// associated data in order to thwart replay attacks. In the
// case of a replay, an attacker is *forced* to use the same payment
// hash twice, thereby losing their money entirely.
sphinxPacket, err := p.router.ReconstructOnionPacket(
onionPkt, rHash, opts...,
)
if err != nil {
return nil, err
}
return makeSphinxHopIterator(p.router, onionPkt, sphinxPacket,
BlindingKit{
Processor: p.router,
UpdateAddBlinding: blindingInfo.BlindingKey,
IncomingAmount: blindingInfo.IncomingAmt,
IncomingCltv: blindingInfo.IncomingExpiry,
}, rHash,
), nil
}
// DecodeHopIteratorRequest encapsulates all data necessary to process an onion
// packet, perform sphinx replay detection, and schedule the entry for garbage
// collection.
type DecodeHopIteratorRequest struct {
OnionReader io.Reader
RHash []byte
IncomingCltv uint32
IncomingAmount lnwire.MilliSatoshi
BlindingPoint lnwire.BlindingPointRecord
}
// DecodeHopIteratorResponse encapsulates the outcome of a batched sphinx onion
// processing.
type DecodeHopIteratorResponse struct {
HopIterator Iterator
FailCode lnwire.FailCode
}
// Result returns the (HopIterator, lnwire.FailCode) tuple, which should
// correspond to the index of a particular DecodeHopIteratorRequest.
//
// NOTE: The HopIterator should be considered invalid if the fail code is
// anything but lnwire.CodeNone.
func (r *DecodeHopIteratorResponse) Result() (Iterator, lnwire.FailCode) {
return r.HopIterator, r.FailCode
}
// DecodeHopIterators performs batched decoding and validation of incoming
// sphinx packets. For the same `id`, this method will return the same iterators
// and failcodes upon subsequent invocations.
//
// NOTE: In order for the responses to be valid, the caller must guarantee that
// the presented readers and rhashes *NEVER* deviate across invocations for the
// same id.
func (p *OnionProcessor) DecodeHopIterators(id []byte,
reqs []DecodeHopIteratorRequest) ([]DecodeHopIteratorResponse, error) {
var (
batchSize = len(reqs)
onionPkts = make([]sphinx.OnionPacket, batchSize)
resps = make([]DecodeHopIteratorResponse, batchSize)
)
tx := p.router.BeginTxn(id, batchSize)
decode := func(seqNum uint16, onionPkt *sphinx.OnionPacket,
req DecodeHopIteratorRequest) lnwire.FailCode {
err := onionPkt.Decode(req.OnionReader)
switch err {
case nil:
// success
case sphinx.ErrInvalidOnionVersion:
return lnwire.CodeInvalidOnionVersion
case sphinx.ErrInvalidOnionKey:
return lnwire.CodeInvalidOnionKey
default:
log.Errorf("unable to decode onion packet: %v", err)
return lnwire.CodeInvalidOnionKey
}
var opts []sphinx.ProcessOnionOpt
req.BlindingPoint.WhenSome(func(
b tlv.RecordT[lnwire.BlindingPointTlvType,
*btcec.PublicKey]) {
opts = append(opts, sphinx.WithBlindingPoint(
b.Val,
))
})
err = tx.ProcessOnionPacket(
seqNum, onionPkt, req.RHash, req.IncomingCltv, opts...,
)
switch err {
case nil:
// success
return lnwire.CodeNone
case sphinx.ErrInvalidOnionVersion:
return lnwire.CodeInvalidOnionVersion
case sphinx.ErrInvalidOnionHMAC:
return lnwire.CodeInvalidOnionHmac
case sphinx.ErrInvalidOnionKey:
return lnwire.CodeInvalidOnionKey
default:
log.Errorf("unable to process onion packet: %v", err)
return lnwire.CodeInvalidOnionKey
}
}
// Execute cpu-heavy onion decoding in parallel.
var wg sync.WaitGroup
for i := range reqs {
wg.Add(1)
go func(seqNum uint16) {
defer wg.Done()
onionPkt := &onionPkts[seqNum]
resps[seqNum].FailCode = decode(
seqNum, onionPkt, reqs[seqNum],
)
}(uint16(i))
}
wg.Wait()
// With that batch created, we will now attempt to write the shared
// secrets to disk. This operation returns the set of indices that
// were detected as replays, and the computed sphinx packets for all
// indices that did not fail the above loop. Only indices that are not
// in the replay set should be considered valid, as they are
// opportunistically computed.
packets, replays, err := tx.Commit()
if err != nil {
log.Errorf("unable to process onion packet batch %x: %v",
id, err)
// If we failed to commit the batch to the secret share log, we
// will mark all not-yet-failed channels with a temporary
// channel failure and exit since we cannot proceed.
for i := range resps {
resp := &resps[i]
// Skip any indexes that already failed onion decoding.
if resp.FailCode != lnwire.CodeNone {
continue
}
log.Errorf("unable to process onion packet %x-%v",
id, i)
resp.FailCode = lnwire.CodeTemporaryChannelFailure
}
// TODO(conner): return real errors to caller so link can fail?
return resps, err
}
// Otherwise, the commit was successful. Now we will post process any
// remaining packets, additionally failing any that were included in the
// replay set.
for i := range resps {
resp := &resps[i]
// Skip any indexes that already failed onion decoding.
if resp.FailCode != lnwire.CodeNone {
continue
}
// If this index is contained in the replay set, mark it with a
// temporary channel failure error code. We infer that the
// offending error was due to a replayed packet because this
// index was found in the replay set.
if replays.Contains(uint16(i)) {
log.Errorf("unable to process onion packet: %v",
sphinx.ErrReplayedPacket)
// We set FailCode to CodeInvalidOnionVersion even
// though the ephemeral key isn't the problem. We need
// to set the BADONION bit since we're sending back a
// malformed packet, but as there isn't a specific
// failure code for replays, we reuse one of the
// failure codes that has BADONION.
resp.FailCode = lnwire.CodeInvalidOnionVersion
continue
}
// Finally, construct a hop iterator from our processed sphinx
// packet, simultaneously caching the original onion packet.
resp.HopIterator = makeSphinxHopIterator(
p.router, &onionPkts[i], &packets[i], BlindingKit{
Processor: p.router,
UpdateAddBlinding: reqs[i].BlindingPoint,
IncomingAmount: reqs[i].IncomingAmount,
IncomingCltv: reqs[i].IncomingCltv,
}, reqs[i].RHash,
)
}
return resps, nil
}
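// The following example is an illustrative sketch, not part of the package
// API: it shows how a caller (in practice, a channel link) might drive
// DecodeHopIterators. The readers, rHashes and cltvs slices are assumptions
// standing in for the per-HTLC data a link would supply, and the id must stay
// stable across retries so that replays resolve to the same responses.
func exampleDecodeHopIterators(p *OnionProcessor, id []byte,
readers []io.Reader, rHashes [][]byte, cltvs []uint32) []Iterator {
reqs := make([]DecodeHopIteratorRequest, len(readers))
for i := range readers {
reqs[i] = DecodeHopIteratorRequest{
OnionReader: readers[i],
RHash: rHashes[i],
IncomingCltv: cltvs[i],
}
}
resps, err := p.DecodeHopIterators(id, reqs)
if err != nil {
log.Errorf("unable to decode onion batch %x: %v", id, err)
return nil
}
var iterators []Iterator
for i := range resps {
iterator, failCode := resps[i].Result()
// Per Result's documentation, the iterator is only valid when
// the fail code is CodeNone.
if failCode != lnwire.CodeNone {
continue
}
iterators = append(iterators, iterator)
}
return iterators
}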
// ExtractErrorEncrypter takes the ephemeral key of a forwarded htlc and
// creates an ErrorEncrypter instance using the derived shared secret. In the
// case that an error occurs, an lnwire failure code detailing the parsing
// failure will be returned.
func (p *OnionProcessor) ExtractErrorEncrypter(ephemeralKey *btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) {
onionObfuscator, err := sphinx.NewOnionErrorEncrypter(
p.router, ephemeralKey,
)
if err != nil {
switch err {
case sphinx.ErrInvalidOnionVersion:
return nil, lnwire.CodeInvalidOnionVersion
case sphinx.ErrInvalidOnionHMAC:
return nil, lnwire.CodeInvalidOnionHmac
case sphinx.ErrInvalidOnionKey:
return nil, lnwire.CodeInvalidOnionKey
default:
log.Errorf("unable to process onion packet: %v", err)
return nil, lnwire.CodeInvalidOnionKey
}
}
return &SphinxErrorEncrypter{
OnionErrorEncrypter: onionObfuscator,
EphemeralKey: ephemeralKey,
}, lnwire.CodeNone
}
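// exampleExtractErrorEncrypter is an illustrative sketch, not part of the
// package API: it derives an error encrypter from a forwarded htlc's
// ephemeral onion key and treats any fail code other than CodeNone as "no
// encrypter available", which is how callers are expected to interpret the
// returned code.
func exampleExtractErrorEncrypter(p *OnionProcessor,
ephemeralKey *btcec.PublicKey) (ErrorEncrypter, bool) {
encrypter, failCode := p.ExtractErrorEncrypter(ephemeralKey)
if failCode != lnwire.CodeNone {
// The shared secret could not be derived, so failures for this
// htlc cannot be encrypted.
return nil, false
}
return encrypter, true
}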
package hop
import (
"github.com/btcsuite/btclog/v2"
)
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
// UseLogger uses a specified Logger to output package logging info. This
// function is called from the parent package htlcswitch logger initialization.
func UseLogger(logger btclog.Logger) {
log = logger
}
package hop
import (
"encoding/binary"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
)
// PayloadViolation is an enum encapsulating the possible invalid payload
// violations that can occur when processing or validating a payload.
type PayloadViolation byte
const (
// OmittedViolation indicates that a type was expected to be found in the
// payload but was absent.
OmittedViolation PayloadViolation = iota
// IncludedViolation indicates that a type was expected to be omitted
// from the payload but was present.
IncludedViolation
// RequiredViolation indicates that an unknown even type was found in
// the payload that we could not process.
RequiredViolation
// InsufficientViolation indicates that the provided type does
// not satisfy constraints.
InsufficientViolation
)
// String returns a human-readable description of the violation as a verb.
func (v PayloadViolation) String() string {
switch v {
case OmittedViolation:
return "omitted"
case IncludedViolation:
return "included"
case RequiredViolation:
return "required"
case InsufficientViolation:
return "insufficient"
default:
return "unknown violation"
}
}
// ErrInvalidPayload is an error returned when a parsed onion payload either
// included or omitted incorrect records for a particular hop type.
type ErrInvalidPayload struct {
// Type is the record's type that caused the violation.
Type tlv.Type
// Violation is an enum indicating the type of violation detected in
// processing Type.
Violation PayloadViolation
// FinalHop if true, indicates that the violation is for the final hop
// in the route (identified by next hop id), otherwise the violation is
// for an intermediate hop.
FinalHop bool
}
// Error returns a human-readable description of the invalid payload error.
func (e ErrInvalidPayload) Error() string {
hopType := "intermediate"
if e.FinalHop {
hopType = "final"
}
return fmt.Sprintf("onion payload for %s hop %v record with type %d",
hopType, e.Violation, e.Type)
}
// Payload encapsulates all information delivered to a hop in an onion payload.
// A Hop can represent either a TLV or legacy payload. The primary forwarding
// instruction can be accessed via ForwardingInfo, and additional records can be
// accessed by other member functions.
type Payload struct {
// FwdInfo holds the basic parameters required for HTLC forwarding, e.g.
// amount, cltv, and next hop.
FwdInfo ForwardingInfo
// MPP holds the info provided in an option_mpp record when parsed from
// a TLV onion payload.
MPP *record.MPP
// AMP holds the info provided in an option_amp record when parsed from
// a TLV onion payload.
AMP *record.AMP
// customRecords are user-defined records in the custom type range that
// were included in the payload.
customRecords record.CustomSet
// encryptedData is a blob of data encrypted by the receiver for use
// in blinded routes.
encryptedData []byte
// blindingPoint is an ephemeral pubkey for use in blinded routes.
blindingPoint *btcec.PublicKey
// metadata is additional data that is sent along with the payment to
// the payee.
metadata []byte
// totalAmtMsat holds the info provided in total_amount_msat when
// parsed from a TLV onion payload.
totalAmtMsat lnwire.MilliSatoshi
}
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop
// parameters provided by legacy onion payloads.
func NewLegacyPayload(f *sphinx.HopData) *Payload {
nextHop := binary.BigEndian.Uint64(f.NextAddress[:])
return &Payload{
FwdInfo: ForwardingInfo{
NextHop: lnwire.NewShortChanIDFromInt(nextHop),
AmountToForward: lnwire.MilliSatoshi(f.ForwardAmount),
OutgoingCTLV: f.OutgoingCltv,
},
customRecords: make(record.CustomSet),
}
}
// ParseTLVPayload builds a new Hop from the passed io.Reader and returns
// a map of all the types that were found in the payload. This function
// does not perform validation of TLV types included in the payload.
func ParseTLVPayload(r io.Reader) (*Payload, map[tlv.Type][]byte, error) {
var (
cid uint64
amt uint64
totalAmtMsat uint64
cltv uint32
mpp = &record.MPP{}
amp = &record.AMP{}
encryptedData []byte
blindingPoint *btcec.PublicKey
metadata []byte
)
tlvStream, err := tlv.NewStream(
record.NewAmtToFwdRecord(&amt),
record.NewLockTimeRecord(&cltv),
record.NewNextHopIDRecord(&cid),
mpp.Record(),
record.NewEncryptedDataRecord(&encryptedData),
record.NewBlindingPointRecord(&blindingPoint),
amp.Record(),
record.NewMetadataRecord(&metadata),
record.NewTotalAmtMsatBlinded(&totalAmtMsat),
)
if err != nil {
return nil, nil, err
}
// Since this data is provided by a potentially malicious peer, pass it
// into the P2P decoding variant.
parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
if err != nil {
return nil, nil, err
}
// If no MPP field was parsed, set the MPP field on the resulting
// payload to nil.
if _, ok := parsedTypes[record.MPPOnionType]; !ok {
mpp = nil
}
// If no AMP field was parsed, set the AMP field on the resulting
// payload to nil.
if _, ok := parsedTypes[record.AMPOnionType]; !ok {
amp = nil
}
// If no encrypted data was parsed, set the field on our resulting
// payload to nil.
if _, ok := parsedTypes[record.EncryptedDataOnionType]; !ok {
encryptedData = nil
}
// If no metadata field was parsed, set the metadata field on the
// resulting payload to nil.
if _, ok := parsedTypes[record.MetadataOnionType]; !ok {
metadata = nil
}
// Filter out the custom records.
customRecords := NewCustomRecords(parsedTypes)
return &Payload{
FwdInfo: ForwardingInfo{
NextHop: lnwire.NewShortChanIDFromInt(cid),
AmountToForward: lnwire.MilliSatoshi(amt),
OutgoingCTLV: cltv,
},
MPP: mpp,
AMP: amp,
metadata: metadata,
encryptedData: encryptedData,
blindingPoint: blindingPoint,
customRecords: customRecords,
totalAmtMsat: lnwire.MilliSatoshi(totalAmtMsat),
}, parsedTypes, nil
}
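// exampleParseAndValidate is an illustrative sketch, not part of the package
// API: it shows the intended two-step flow in which a payload is first parsed
// without validation and the parsed types are then validated for the hop's
// position. The finalHop and updateAddBlinding flags are assumed to be known
// to the caller from the route context.
func exampleParseAndValidate(r io.Reader, finalHop,
updateAddBlinding bool) (*Payload, error) {
// Parsing is deliberately separate from validation so that blinded
// hops can apply a stricter rule set to the same parsed types.
payload, parsedTypes, err := ParseTLVPayload(r)
if err != nil {
return nil, err
}
err = ValidateTLVPayload(parsedTypes, finalHop, updateAddBlinding)
if err != nil {
return nil, err
}
return payload, nil
}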
// ValidateTLVPayload validates the TLV fields that were included in a TLV
// payload.
func ValidateTLVPayload(parsedTypes map[tlv.Type][]byte,
finalHop bool, updateAddBlinding bool) error {
// Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04.
err := ValidateParsedPayloadTypes(
parsedTypes, finalHop, updateAddBlinding,
)
if err != nil {
return err
}
// Check for violation of the rules for mandatory fields.
violatingType := getMinRequiredViolation(parsedTypes)
if violatingType != nil {
return ErrInvalidPayload{
Type: *violatingType,
Violation: RequiredViolation,
FinalHop: finalHop,
}
}
return nil
}
// ForwardingInfo returns the basic parameters required for HTLC forwarding,
// e.g. amount, cltv, and next hop.
func (h *Payload) ForwardingInfo() ForwardingInfo {
return h.FwdInfo
}
// NewCustomRecords filters the types parsed from the tlv stream for custom
// records.
func NewCustomRecords(parsedTypes tlv.TypeMap) record.CustomSet {
customRecords := make(record.CustomSet)
for t, parseResult := range parsedTypes {
if parseResult == nil || t < record.CustomTypeStart {
continue
}
customRecords[uint64(t)] = parseResult
}
return customRecords
}
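// exampleCustomRecords is an illustrative sketch, not part of the package
// API: only types at or above record.CustomTypeStart with a non-nil parse
// result survive the filter. The concrete type values below are assumptions
// chosen for the example.
func exampleCustomRecords() record.CustomSet {
parsedTypes := tlv.TypeMap{
// A known type that was parsed into its destination has a nil
// parse result and is skipped.
record.AmtOnionType: nil,
// A type below the custom range is skipped even when unparsed.
tlv.Type(1000): {0x01},
// A type in the custom range with raw bytes is kept.
tlv.Type(record.CustomTypeStart): {0x02},
}
// The returned set contains only the custom-range entry.
return NewCustomRecords(parsedTypes)
}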
// ValidateParsedPayloadTypes checks the types parsed from a hop payload to
// ensure that the proper fields are either included or omitted. The finalHop
// boolean should be true if the payload was parsed for an exit hop. The
// requirements for this method are described in BOLT 04.
func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
isFinalHop, updateAddBlinding bool) error {
_, hasAmt := parsedTypes[record.AmtOnionType]
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
_, hasNextHop := parsedTypes[record.NextHopOnionType]
_, hasMPP := parsedTypes[record.MPPOnionType]
_, hasAMP := parsedTypes[record.AMPOnionType]
_, hasEncryptedData := parsedTypes[record.EncryptedDataOnionType]
_, hasBlinding := parsedTypes[record.BlindingPointOnionType]
// All cleartext hops (including final hop) and the final hop in a
// blinded path require the forwarding amount and expiry TLVs to be set.
needFwdInfo := isFinalHop || !hasEncryptedData
// No blinded hops should have a next hop specified, and only the final
// hop in a cleartext route should exclude it.
needNextHop := !(hasEncryptedData || isFinalHop)
switch {
// Having both blinding points set is invalid.
case hasBlinding && updateAddBlinding:
return ErrInvalidPayload{
Type: record.BlindingPointOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
// If encrypted data is not provided, blinding points should not be
// set.
case !hasEncryptedData && (hasBlinding || updateAddBlinding):
return ErrInvalidPayload{
Type: record.EncryptedDataOnionType,
Violation: OmittedViolation,
FinalHop: isFinalHop,
}
// If encrypted data is present, we require that one blinding point
// is set.
case hasEncryptedData && !(hasBlinding || updateAddBlinding):
return ErrInvalidPayload{
Type: record.EncryptedDataOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
// Hops that need forwarding info must include an amount to forward.
case needFwdInfo && !hasAmt:
return ErrInvalidPayload{
Type: record.AmtOnionType,
Violation: OmittedViolation,
FinalHop: isFinalHop,
}
// Hops that need forwarding info must include a cltv expiry.
case needFwdInfo && !hasLockTime:
return ErrInvalidPayload{
Type: record.LockTimeOnionType,
Violation: OmittedViolation,
FinalHop: isFinalHop,
}
// Hops that don't need forwarding info shouldn't have an amount TLV.
case !needFwdInfo && hasAmt:
return ErrInvalidPayload{
Type: record.AmtOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
// Hops that don't need forwarding info shouldn't have a cltv TLV.
case !needFwdInfo && hasLockTime:
return ErrInvalidPayload{
Type: record.LockTimeOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
// The exit hop and all blinded hops should omit the next hop id.
case !needNextHop && hasNextHop:
return ErrInvalidPayload{
Type: record.NextHopOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
// Require that the next hop is set for intermediate hops in regular
// routes.
case needNextHop && !hasNextHop:
return ErrInvalidPayload{
Type: record.NextHopOnionType,
Violation: OmittedViolation,
FinalHop: isFinalHop,
}
// Intermediate nodes should never receive MPP fields.
case !isFinalHop && hasMPP:
return ErrInvalidPayload{
Type: record.MPPOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
// Intermediate nodes should never receive AMP fields.
case !isFinalHop && hasAMP:
return ErrInvalidPayload{
Type: record.AMPOnionType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
}
return nil
}
// MultiPath returns the record corresponding to the option_mpp parsed from the
// onion payload.
func (h *Payload) MultiPath() *record.MPP {
return h.MPP
}
// AMPRecord returns the record corresponding to option_amp parsed from the
// onion payload.
func (h *Payload) AMPRecord() *record.AMP {
return h.AMP
}
// CustomRecords returns the custom tlv type records that were parsed from the
// payload.
func (h *Payload) CustomRecords() record.CustomSet {
return h.customRecords
}
// EncryptedData returns the route blinding encrypted data parsed from the
// onion payload.
func (h *Payload) EncryptedData() []byte {
return h.encryptedData
}
// BlindingPoint returns the route blinding point parsed from the onion payload.
func (h *Payload) BlindingPoint() *btcec.PublicKey {
return h.blindingPoint
}
// PathID returns the path ID that was encoded in the final hop payload of a
// blinded payment.
func (h *Payload) PathID() *chainhash.Hash {
return h.FwdInfo.PathID
}
// Metadata returns the additional data that is sent along with the
// payment to the payee.
func (h *Payload) Metadata() []byte {
return h.metadata
}
// TotalAmtMsat returns the total amount sent to the final hop, as set by the
// payee.
func (h *Payload) TotalAmtMsat() lnwire.MilliSatoshi {
return h.totalAmtMsat
}
// getMinRequiredViolation checks for unrecognized required (even) fields in the
// standard range and returns the lowest required type. Always returning the
// lowest required type allows a failure message to be deterministic.
func getMinRequiredViolation(set tlv.TypeMap) *tlv.Type {
var (
requiredViolation bool
minRequiredViolationType tlv.Type
)
for t, parseResult := range set {
// If a type is even but not known to us, we cannot process the
// payload. We are required to understand a field that we don't
// support.
//
// We always accept custom fields, because a higher level
// application may understand them.
if parseResult == nil || t%2 != 0 ||
t >= record.CustomTypeStart {
continue
}
if !requiredViolation || t < minRequiredViolationType {
minRequiredViolationType = t
}
requiredViolation = true
}
if requiredViolation {
return &minRequiredViolationType
}
return nil
}
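// exampleMinRequiredViolation is an illustrative sketch, not part of the
// package API: per the "it's OK to be odd" rule, an unknown even type in the
// standard range must be rejected, while unknown odd types and custom-range
// types are tolerated. The concrete type values are assumptions chosen for
// the example.
func exampleMinRequiredViolation() {
set := tlv.TypeMap{
// Unknown even type in the standard range: a violation.
tlv.Type(100): {},
// Unknown odd type: tolerated.
tlv.Type(101): {},
// Unknown even type in the custom range: tolerated.
tlv.Type(record.CustomTypeStart): {},
}
// Reports the lowest violating type, 100, which keeps the resulting
// failure message deterministic regardless of map iteration order.
if violatingType := getMinRequiredViolation(set); violatingType != nil {
fmt.Printf("required type %d not understood\n", *violatingType)
}
}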
// ValidateBlindedRouteData performs the additional validation that is
// required for payments that rely on data provided in an encrypted blob to
// be forwarded. We enforce the blinded route's maximum expiry height so that
// the route "expires" and a malicious party does not have endless opportunity
// to probe the blinded route and compare it to updated channel policies in
// the network.
func ValidateBlindedRouteData(blindedData *record.BlindedRouteData,
incomingAmount lnwire.MilliSatoshi, incomingTimelock uint32) error {
// BOLT 04 notes that we should enforce payment constraints _if_ they
// are present, so we do not fail if not provided.
var err error
blindedData.Constraints.WhenSome(
func(c tlv.RecordT[tlv.TlvType12, record.PaymentConstraints]) {
// MUST fail if the expiry is greater than
// max_cltv_expiry.
if incomingTimelock > c.Val.MaxCltvExpiry {
err = ErrInvalidPayload{
Type: record.LockTimeOnionType,
Violation: InsufficientViolation,
}
}
// MUST fail if the amount is below htlc_minimum_msat.
if incomingAmount < c.Val.HtlcMinimumMsat {
err = ErrInvalidPayload{
Type: record.AmtOnionType,
Violation: InsufficientViolation,
}
}
},
)
if err != nil {
return err
}
// Fail if we don't understand any features (even or odd), because we
// expect the features to have been set from our announcement. If the
// feature vector TLV is not included, it's interpreted as an empty
// vector (no validation required).
//
// Note that we do not yet check the features that the blinded payment
// is using against our own features, because there are currently no
// payment-related features that they utilize other than tlv-onion,
// which is implicitly supported.
blindedData.Features.WhenSome(
func(f tlv.RecordT[tlv.TlvType14, lnwire.FeatureVector]) {
if f.Val.UnknownFeatures() {
err = ErrInvalidPayload{
Type: 14,
Violation: IncludedViolation,
}
}
},
)
if err != nil {
return err
}
return nil
}
// ValidatePayloadWithBlinded validates a payload against the contents of
// its encrypted data blob.
func ValidatePayloadWithBlinded(isFinalHop bool,
payloadParsed map[tlv.Type][]byte) error {
// Blinded routes restrict the presence of TLVs more strictly than
// regular routes, check that intermediate and final hops only have
// the TLVs the spec allows them to have.
allowedTLVs := map[tlv.Type]bool{
record.EncryptedDataOnionType: true,
record.BlindingPointOnionType: true,
}
if isFinalHop {
allowedTLVs[record.AmtOnionType] = true
allowedTLVs[record.LockTimeOnionType] = true
allowedTLVs[record.TotalAmtMsatBlindedType] = true
}
for tlvType := range payloadParsed {
if _, ok := allowedTLVs[tlvType]; ok {
continue
}
return ErrInvalidPayload{
Type: tlvType,
Violation: IncludedViolation,
FinalHop: isFinalHop,
}
}
return nil
}
package htlcswitch
import (
"fmt"
"strings"
"sync"
"time"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/subscribe"
)
// HtlcNotifier notifies clients of htlc forwards, failures and settles for
// htlcs that the switch handles. It takes subscriptions for its events and
// notifies them when htlc events occur. These are served on a best-effort
// basis; events are not persisted, delivery is not guaranteed (in the event
// of a crash in the switch, forward events may be lost) and some events may
// be replayed upon restart. Events consumed from this package should be
// de-duplicated by the htlc's unique combination of incoming+outgoing circuit
// and not relied upon for critical operations.
//
// The htlc notifier sends the following kinds of events:
// Forwarding Event:
// - Represents a htlc which is forwarded onward from our node.
// - Present for htlc forwards through our node and local sends.
//
// Link Failure Event:
// - Indicates that a htlc has failed on our incoming or outgoing link,
// with an incoming boolean which indicates where the failure occurred.
// - Incoming link failures are present for failed attempts to pay one of
// our invoices (insufficient amount or mpp timeout, for example) and for
// forwards that we cannot decode to forward onwards.
// - Outgoing link failures are present for forwards or local payments that
// do not meet our outgoing link's policy (insufficient fees, for example)
// and when we fail to forward the payment on (insufficient outgoing
// capacity, or an unknown outgoing link).
//
// Forwarding Failure Event:
// - Forwarding failures indicate that a htlc we forwarded has failed at
// another node down the route.
// - Present for local sends and htlc forwards which fail after they left
// our node.
//
// Settle event:
// - Settle events are present when a htlc which we added is settled through
// the release of a preimage.
// - Present for local receives, and successful local sends or forwards.
//
// Each htlc is identified by its incoming and outgoing circuit key. Htlcs,
// and their subsequent settles or fails, can be identified by the combination
// of incoming and outgoing circuits. Note that receives to our node will
// have a zero outgoing circuit key because the htlc terminates at our
// node, and sends from our node will have a zero incoming circuit key because
// the send originates at our node.
type HtlcNotifier struct {
started sync.Once
stopped sync.Once
// now returns the current time. It is set in the htlcnotifier to allow
// for timestamp mocking in tests.
now func() time.Time
ntfnServer *subscribe.Server
}
// NewHtlcNotifier creates a new HtlcNotifier which gets htlc forwarded,
// failed and settled events from links our node has established with peers
// and sends notifications to subscribing clients.
func NewHtlcNotifier(now func() time.Time) *HtlcNotifier {
return &HtlcNotifier{
now: now,
ntfnServer: subscribe.NewServer(),
}
}
// Start starts the HtlcNotifier and all goroutines it needs to consume
// events and provide subscriptions to clients.
func (h *HtlcNotifier) Start() error {
var err error
h.started.Do(func() {
log.Info("HtlcNotifier starting")
err = h.ntfnServer.Start()
})
return err
}
// Stop signals the notifier for a graceful shutdown.
func (h *HtlcNotifier) Stop() error {
var err error
h.stopped.Do(func() {
log.Info("HtlcNotifier shutting down...")
defer log.Debug("HtlcNotifier shutdown complete")
if err = h.ntfnServer.Stop(); err != nil {
log.Warnf("error stopping htlc notifier: %v", err)
}
})
return err
}
// SubscribeHtlcEvents returns a subscribe.Client that will receive updates
// any time the server is made aware of a new event.
func (h *HtlcNotifier) SubscribeHtlcEvents() (*subscribe.Client, error) {
return h.ntfnServer.Subscribe()
}
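// exampleConsumeHtlcEvents is an illustrative sketch, not part of the package
// API: a best-effort consumer subscribes, type-switches on the event kinds
// described above, and cancels its subscription on exit. Real consumers
// should de-duplicate events by circuit key as noted in the HtlcNotifier
// documentation.
func exampleConsumeHtlcEvents(h *HtlcNotifier) error {
client, err := h.SubscribeHtlcEvents()
if err != nil {
return err
}
defer client.Cancel()
for event := range client.Updates() {
switch e := event.(type) {
case *ForwardingEvent:
log.Infof("forwarded: %v", e.HtlcKey)
case *LinkFailEvent:
log.Infof("link failure (incoming=%v): %v", e.Incoming,
e.LinkError)
case *ForwardingFailEvent:
log.Infof("failed downstream: %v", e.HtlcKey)
case *SettleEvent:
log.Infof("settled: %v", e.HtlcKey)
}
}
return nil
}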
// HtlcKey uniquely identifies the htlc.
type HtlcKey struct {
// IncomingCircuit is the channel and htlc id of the incoming htlc.
IncomingCircuit models.CircuitKey
// OutgoingCircuit is the channel and htlc id of the outgoing htlc.
OutgoingCircuit models.CircuitKey
}
// String returns a string representation of a htlc key.
func (k HtlcKey) String() string {
switch {
case k.IncomingCircuit.ChanID == hop.Source:
return k.OutgoingCircuit.String()
case k.OutgoingCircuit.ChanID == hop.Exit:
return k.IncomingCircuit.String()
default:
return fmt.Sprintf("%v -> %v", k.IncomingCircuit,
k.OutgoingCircuit)
}
}
// HtlcInfo provides the details of a htlc that our node has processed. For
// forwards, incoming and outgoing values are set, whereas sends and receives
// will only have outgoing or incoming details set.
type HtlcInfo struct {
// IncomingTimeLock is the time lock of the htlc on our incoming
// channel.
IncomingTimeLock uint32
// OutgoingTimeLock is the time lock of the htlc on our outgoing channel.
OutgoingTimeLock uint32
// IncomingAmt is the amount of the htlc on our incoming channel.
IncomingAmt lnwire.MilliSatoshi
// OutgoingAmt is the amount of the htlc on our outgoing channel.
OutgoingAmt lnwire.MilliSatoshi
}
// String returns a string representation of a htlc.
func (h HtlcInfo) String() string {
var details []string
// If the incoming information is non-zero (it is zero in the case of a
// send), we include the incoming amount and timelock.
if h.IncomingAmt != 0 || h.IncomingTimeLock != 0 {
str := fmt.Sprintf("incoming amount: %v, "+
"incoming timelock: %v", h.IncomingAmt,
h.IncomingTimeLock)
details = append(details, str)
}
// If the outgoing information is non-zero (it is zero in the case of a
// receive), we include the outgoing amount and timelock.
if h.OutgoingAmt != 0 || h.OutgoingTimeLock != 0 {
str := fmt.Sprintf("outgoing amount: %v, "+
"outgoing timelock: %v", h.OutgoingAmt,
h.OutgoingTimeLock)
details = append(details, str)
}
return strings.Join(details, ", ")
}
// HtlcEventType represents the type of event that a htlc was part of.
type HtlcEventType int
const (
// HtlcEventTypeSend represents a htlc that was part of a send from
// our node.
HtlcEventTypeSend HtlcEventType = iota
// HtlcEventTypeReceive represents a htlc that was part of a receive
// to our node.
HtlcEventTypeReceive
// HtlcEventTypeForward represents a htlc that was forwarded through
// our node.
HtlcEventTypeForward
)
// String returns a string representation of a htlc event type.
func (h HtlcEventType) String() string {
switch h {
case HtlcEventTypeSend:
return "send"
case HtlcEventTypeReceive:
return "receive"
case HtlcEventTypeForward:
return "forward"
default:
return "unknown"
}
}
// ForwardingEvent represents a htlc that was forwarded onwards from our node.
// Sends which originate from our node will report forward events with zero
// incoming circuits in their htlc key.
type ForwardingEvent struct {
// HtlcKey uniquely identifies the htlc, and can be used to match the
// forwarding event with subsequent settle/fail events.
HtlcKey
// HtlcInfo contains details about the htlc.
HtlcInfo
// HtlcEventType classifies the event as part of a local send or
// receive, or as part of a forward.
HtlcEventType
// Timestamp is the time when this htlc was forwarded.
Timestamp time.Time
}
// LinkFailEvent describes a htlc that failed on our incoming or outgoing
// link. The incoming bool is true for failures on incoming links, and false
// for failures on outgoing links. The failure reason is provided by a lnwire
// failure message which is enriched with a failure detail in the cases where
// the wire failure message does not contain full information about the
// failure.
type LinkFailEvent struct {
// HtlcKey uniquely identifies the htlc.
HtlcKey
// HtlcInfo contains details about the htlc.
HtlcInfo
// HtlcEventType classifies the event as part of a local send or
// receive, or as part of a forward.
HtlcEventType
// LinkError is the reason that we failed the htlc.
LinkError *LinkError
// Incoming is true if the htlc was failed on an incoming link.
// If it failed on the outgoing link, it is false.
Incoming bool
// Timestamp is the time when the link failure occurred.
Timestamp time.Time
}
// ForwardingFailEvent represents a htlc failure which occurred down the line
// after we forwarded a htlc onwards. An error is not included in this event
// because errors returned down the route are encrypted. HtlcInfo is not
// reliably available for forwarding failures, so it is omitted. These events
// should be matched with their corresponding forward event to obtain this
// information.
type ForwardingFailEvent struct {
// HtlcKey uniquely identifies the htlc, and can be used to match the
// htlc with its corresponding forwarding event.
HtlcKey
// HtlcEventType classifies the event as part of a local send or
// receive, or as part of a forward.
HtlcEventType
// Timestamp is the time when the forwarding failure was received.
Timestamp time.Time
}
// SettleEvent represents a htlc that was settled. HtlcInfo is not reliably
// available for settle events, so it is omitted. These events should
// be matched with corresponding forward events or invoices (for receives)
// to obtain additional information about the htlc.
type SettleEvent struct {
// HtlcKey uniquely identifies the htlc, and can be used to match
// forwards with their corresponding forwarding event.
HtlcKey
// Preimage that was released for settling the htlc.
Preimage lntypes.Preimage
// HtlcEventType classifies the event as part of a local send or
// receive, or as part of a forward.
HtlcEventType
// Timestamp is the time when this htlc was settled.
Timestamp time.Time
}
// FinalHtlcEvent describes the final outcome of a htlc: whether it was
// settled or failed, and whether it was resolved off-chain or on-chain.
type FinalHtlcEvent struct {
// CircuitKey identifies the htlc that this final resolution applies to.
CircuitKey
// Settled is true if the htlc was settled and false if it was failed.
Settled bool
// Offchain indicates whether the htlc was resolved off-chain.
Offchain bool
// Timestamp is the time when this htlc was resolved.
Timestamp time.Time
}
// NotifyForwardingEvent notifies the HtlcNotifier that a htlc has been
// forwarded.
//
// Note this is part of the htlcNotifier interface.
func (h *HtlcNotifier) NotifyForwardingEvent(key HtlcKey, info HtlcInfo,
eventType HtlcEventType) {
event := &ForwardingEvent{
HtlcKey: key,
HtlcInfo: info,
HtlcEventType: eventType,
Timestamp: h.now(),
}
log.Tracef("Notifying forward event: %v over %v, %v", eventType, key,
info)
if err := h.ntfnServer.SendUpdate(event); err != nil {
log.Warnf("Unable to send forwarding event: %v", err)
}
}
// NotifyLinkFailEvent notifies that a htlc has failed on our incoming
// or outgoing link.
//
// Note this is part of the htlcNotifier interface.
func (h *HtlcNotifier) NotifyLinkFailEvent(key HtlcKey, info HtlcInfo,
eventType HtlcEventType, linkErr *LinkError, incoming bool) {
event := &LinkFailEvent{
HtlcKey: key,
HtlcInfo: info,
HtlcEventType: eventType,
LinkError: linkErr,
Incoming: incoming,
Timestamp: h.now(),
}
log.Tracef("Notifying link failure event: %v over %v, %v", eventType,
key, info)
if err := h.ntfnServer.SendUpdate(event); err != nil {
log.Warnf("Unable to send link fail event: %v", err)
}
}
// NotifyForwardingFailEvent notifies the HtlcNotifier that a htlc we
// forwarded has failed down the line.
//
// Note this is part of the htlcNotifier interface.
func (h *HtlcNotifier) NotifyForwardingFailEvent(key HtlcKey,
eventType HtlcEventType) {
event := &ForwardingFailEvent{
HtlcKey: key,
HtlcEventType: eventType,
Timestamp: h.now(),
}
log.Tracef("Notifying forwarding failure event: %v over %v", eventType,
key)
if err := h.ntfnServer.SendUpdate(event); err != nil {
log.Warnf("Unable to send forwarding fail event: %v", err)
}
}
// NotifySettleEvent notifies the HtlcNotifier that a htlc that we committed
// to as part of a forward or a receive to our node has been settled.
//
// Note this is part of the htlcNotifier interface.
func (h *HtlcNotifier) NotifySettleEvent(key HtlcKey,
preimage lntypes.Preimage, eventType HtlcEventType) {
event := &SettleEvent{
HtlcKey: key,
Preimage: preimage,
HtlcEventType: eventType,
Timestamp: h.now(),
}
log.Tracef("Notifying settle event: %v over %v", eventType, key)
if err := h.ntfnServer.SendUpdate(event); err != nil {
log.Warnf("Unable to send settle event: %v", err)
}
}
// NotifyFinalHtlcEvent notifies the HtlcNotifier that the final outcome for an
// htlc has been determined.
//
// Note this is part of the htlcNotifier interface.
func (h *HtlcNotifier) NotifyFinalHtlcEvent(key models.CircuitKey,
info channeldb.FinalHtlcInfo) {
event := &FinalHtlcEvent{
CircuitKey: key,
Settled: info.Settled,
Offchain: info.Offchain,
Timestamp: h.now(),
}
log.Tracef("Notifying final settle event: %v", key)
if err := h.ntfnServer.SendUpdate(event); err != nil {
log.Warnf("Unable to send settle event: %v", err)
}
}
// newHtlcKey returns a htlc key for the packet provided. If the packet
// has a zero incoming channel ID, the packet is for one of our own sends,
// which has the payment id stashed in the incoming htlc id. If this is the
// case, we replace the incoming htlc id with zero so that the notifier
// consistently reports zero circuit keys for events that terminate or
// originate at our node.
func newHtlcKey(pkt *htlcPacket) HtlcKey {
htlcKey := HtlcKey{
IncomingCircuit: models.CircuitKey{
ChanID: pkt.incomingChanID,
HtlcID: pkt.incomingHTLCID,
},
OutgoingCircuit: CircuitKey{
ChanID: pkt.outgoingChanID,
HtlcID: pkt.outgoingHTLCID,
},
}
// If the packet has a zero incoming channel ID, it is a send that was
// initiated at our node. If this is the case, our internal pid is in
// the incoming htlc ID, so we overwrite it with 0 for notification
// purposes.
if pkt.incomingChanID == hop.Source {
htlcKey.IncomingCircuit.HtlcID = 0
}
return htlcKey
}
// newHtlcInfo returns HtlcInfo for the packet provided.
func newHtlcInfo(pkt *htlcPacket) HtlcInfo {
return HtlcInfo{
IncomingTimeLock: pkt.incomingTimeout,
OutgoingTimeLock: pkt.outgoingTimeout,
IncomingAmt: pkt.incomingAmount,
OutgoingAmt: pkt.amount,
}
}
// getEventType returns the htlc type based on the fields set in the htlc
// packet. Sends that originate at our node have the source (zero) incoming
// channel ID. Receives to our node have the exit (zero) outgoing channel ID
// and forwards have both fields set.
func getEventType(pkt *htlcPacket) HtlcEventType {
switch {
case pkt.incomingChanID == hop.Source:
return HtlcEventTypeSend
case pkt.outgoingChanID == hop.Exit:
return HtlcEventTypeReceive
default:
return HtlcEventTypeForward
}
}
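// exampleClassifyPackets is an illustrative sketch, not part of the package
// API: it shows how the zero-valued source and exit channel IDs classify
// packets, mirroring getEventType. The short channel ID used here is an
// assumption chosen for the example.
func exampleClassifyPackets() {
send := &htlcPacket{
incomingChanID: hop.Source,
outgoingChanID: lnwire.NewShortChanIDFromInt(123),
}
receive := &htlcPacket{
incomingChanID: lnwire.NewShortChanIDFromInt(123),
outgoingChanID: hop.Exit,
}
// Sends originate at our node and receives terminate at it, so the
// respective circuit halves are zero.
log.Infof("classified as %v and %v", getEventType(send),
getEventType(receive))
}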
package htlcswitch
import (
"crypto/sha256"
"fmt"
"sync"
"sync/atomic"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrFwdNotExists is an error returned when the caller tries to resolve
// a forward that doesn't exist anymore.
ErrFwdNotExists = errors.New("forward does not exist")
// ErrUnsupportedFailureCode when processing of an unsupported failure
// code is attempted.
ErrUnsupportedFailureCode = errors.New("unsupported failure code")
errBlockStreamStopped = errors.New("block epoch stream stopped")
)
// InterceptableSwitch is an implementation of the ForwardingSwitch interface.
// This implementation is used like a proxy that wraps the switch and
// intercepts forward requests. A reference to the Switch is held in order
// to communicate back the interception result where the options are:
// Resume - forwards the original request to the switch as is.
// Settle - routes UpdateFulfillHTLC to the originating link.
// Fail - routes UpdateFailHTLC to the originating link.
type InterceptableSwitch struct {
started atomic.Bool
stopped atomic.Bool
// htlcSwitch is the underlying switch.
htlcSwitch *Switch
// intercepted is where we stream all intercepted packets coming from
// the switch.
intercepted chan *interceptedPackets
// resolutionChan is where we stream all responses coming from the
// interceptor client.
resolutionChan chan *fwdResolution
onchainIntercepted chan InterceptedForward
// interceptorRegistration is a channel that we use to synchronize
// client connect and disconnect.
interceptorRegistration chan ForwardInterceptor
// requireInterceptor indicates whether processing should block if no
// interceptor is connected.
requireInterceptor bool
// interceptor is the handler for intercepted packets.
interceptor ForwardInterceptor
// heldHtlcSet keeps track of outstanding intercepted forwards.
heldHtlcSet *heldHtlcSet
// cltvRejectDelta defines the number of blocks before the expiry of the
// htlc where we no longer intercept it and instead cancel it back.
cltvRejectDelta uint32
// cltvInterceptDelta defines the number of blocks before the expiry of
// the htlc where we don't intercept anymore. This value must be greater
// than CltvRejectDelta, because we don't want to offer htlcs to the
// interceptor client for which there is no time left to resolve them
// anymore.
cltvInterceptDelta uint32
// notifier is an instance of a chain notifier that we'll use to signal
// the switch when a new block has arrived.
notifier chainntnfs.ChainNotifier
// blockEpochStream is an active block epoch event stream backed by an
// active ChainNotifier instance. This will be used to retrieve the
// latest height of the chain.
blockEpochStream *chainntnfs.BlockEpochEvent
// currentHeight is the currently best known height.
currentHeight int32
wg sync.WaitGroup
quit chan struct{}
}
type interceptedPackets struct {
packets []*htlcPacket
linkQuit <-chan struct{}
isReplay bool
}
// FwdAction defines the various resolution types.
type FwdAction int
const (
// FwdActionResume forwards the intercepted packet to the switch.
FwdActionResume FwdAction = iota
// FwdActionSettle settles the intercepted packet with a preimage.
FwdActionSettle
// FwdActionFail fails the intercepted packet back to the sender.
FwdActionFail
// FwdActionResumeModified forwards the intercepted packet to the switch
// with modifications.
FwdActionResumeModified
)
// FwdResolution defines the action to be taken on an intercepted packet.
type FwdResolution struct {
// Key is the incoming circuit key of the htlc.
Key models.CircuitKey
// Action is the action to take on the intercepted htlc.
Action FwdAction
// Preimage is the preimage that is to be used for settling if Action is
// FwdActionSettle.
Preimage lntypes.Preimage
// InAmountMsat is the amount that is to be used for validating if
// Action is FwdActionResumeModified.
InAmountMsat fn.Option[lnwire.MilliSatoshi]
// OutAmountMsat is the amount that is to be used for forwarding if
// Action is FwdActionResumeModified.
OutAmountMsat fn.Option[lnwire.MilliSatoshi]
// OutWireCustomRecords is the custom records that are to be used for
// forwarding if Action is FwdActionResumeModified.
OutWireCustomRecords fn.Option[lnwire.CustomRecords]
// FailureMessage is the encrypted failure message that is to be passed
// back to the sender if action is FwdActionFail.
FailureMessage []byte
// FailureCode is the failure code that is to be passed back to the
// sender if action is FwdActionFail.
FailureCode lnwire.FailCode
}
type fwdResolution struct {
resolution *FwdResolution
errChan chan error
}
// InterceptableSwitchConfig contains the configuration of InterceptableSwitch.
type InterceptableSwitchConfig struct {
// Switch is a reference to the actual switch implementation that
// packets get sent to on resume.
Switch *Switch
// Notifier is an instance of a chain notifier that we'll use to signal
// the switch when a new block has arrived.
Notifier chainntnfs.ChainNotifier
// CltvRejectDelta defines the number of blocks before the expiry of the
// htlc where we auto-fail an intercepted htlc to prevent channel
// force-closure.
CltvRejectDelta uint32
// CltvInterceptDelta defines the number of blocks before the expiry of
// the htlc where we don't intercept anymore. This value must be greater
// than CltvRejectDelta, because we don't want to offer htlcs to the
// interceptor client for which there is no time left to resolve them
// anymore.
CltvInterceptDelta uint32
// RequireInterceptor indicates whether processing should block if no
// interceptor is connected.
RequireInterceptor bool
}
// NewInterceptableSwitch returns an instance of InterceptableSwitch.
func NewInterceptableSwitch(cfg *InterceptableSwitchConfig) (
*InterceptableSwitch, error) {
if cfg.CltvInterceptDelta <= cfg.CltvRejectDelta {
return nil, fmt.Errorf("cltv intercept delta %v not greater "+
"than cltv reject delta %v",
cfg.CltvInterceptDelta, cfg.CltvRejectDelta)
}
return &InterceptableSwitch{
htlcSwitch: cfg.Switch,
intercepted: make(chan *interceptedPackets),
onchainIntercepted: make(chan InterceptedForward),
interceptorRegistration: make(chan ForwardInterceptor),
heldHtlcSet: newHeldHtlcSet(),
resolutionChan: make(chan *fwdResolution),
requireInterceptor: cfg.RequireInterceptor,
cltvRejectDelta: cfg.CltvRejectDelta,
cltvInterceptDelta: cfg.CltvInterceptDelta,
notifier: cfg.Notifier,
quit: make(chan struct{}),
}, nil
}
// SetInterceptor sets the ForwardInterceptor to be used. A nil argument
// unregisters the current interceptor.
func (s *InterceptableSwitch) SetInterceptor(
interceptor ForwardInterceptor) {
// Synchronize setting the handler with the main loop to prevent race
// conditions.
select {
case s.interceptorRegistration <- interceptor:
case <-s.quit:
}
}
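// exampleRegisterInterceptor is an illustrative sketch, not part of the
// package API: it registers an interceptor that hands packets off to a worker
// goroutine, which resumes them via Resolve. The handoff is asynchronous on
// purpose: the interceptor callback runs on the switch's event loop, and
// calling Resolve from that same goroutine would deadlock. The channel
// capacity of 100 is an assumption chosen for the example.
func exampleRegisterInterceptor(s *InterceptableSwitch) {
packets := make(chan InterceptedPacket, 100)
go func() {
for packet := range packets {
err := s.Resolve(&FwdResolution{
Key: packet.IncomingCircuit,
Action: FwdActionResume,
})
if err != nil {
log.Errorf("unable to resolve %v: %v",
packet.IncomingCircuit, err)
}
}
}()
s.SetInterceptor(func(packet InterceptedPacket) error {
packets <- packet
return nil
})
}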
// Start starts the InterceptableSwitch and launches its main event loop.
func (s *InterceptableSwitch) Start() error {
log.Info("InterceptableSwitch starting...")
if s.started.Swap(true) {
return fmt.Errorf("InterceptableSwitch started more than once")
}
blockEpochStream, err := s.notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return err
}
s.blockEpochStream = blockEpochStream
s.wg.Add(1)
go func() {
defer s.wg.Done()
err := s.run()
if err != nil {
log.Errorf("InterceptableSwitch stopped: %v", err)
}
}()
log.Debug("InterceptableSwitch started")
return nil
}
// Stop shuts down the InterceptableSwitch and waits for its event loop to
// exit.
func (s *InterceptableSwitch) Stop() error {
log.Info("InterceptableSwitch shutting down...")
if s.stopped.Swap(true) {
return fmt.Errorf("InterceptableSwitch stopped more than once")
}
close(s.quit)
s.wg.Wait()
// We need to check whether the start routine ran and initialized the
// `blockEpochStream`.
if s.blockEpochStream != nil {
s.blockEpochStream.Cancel()
}
log.Debug("InterceptableSwitch shutdown complete")
return nil
}
func (s *InterceptableSwitch) run() error {
// The block epoch stream will immediately stream the current height.
// Read it out here.
select {
case currentBlock, ok := <-s.blockEpochStream.Epochs:
if !ok {
return errBlockStreamStopped
}
s.currentHeight = currentBlock.Height
case <-s.quit:
return nil
}
log.Debugf("InterceptableSwitch running: height=%v, "+
"requireInterceptor=%v", s.currentHeight, s.requireInterceptor)
for {
select {
// An interceptor registration or de-registration came in.
case interceptor := <-s.interceptorRegistration:
s.setInterceptor(interceptor)
case packets := <-s.intercepted:
var notIntercepted []*htlcPacket
for _, p := range packets.packets {
intercepted, err := s.interceptForward(
p, packets.isReplay,
)
if err != nil {
return err
}
if !intercepted {
notIntercepted = append(
notIntercepted, p,
)
}
}
err := s.htlcSwitch.ForwardPackets(
packets.linkQuit, notIntercepted...,
)
if err != nil {
log.Errorf("Cannot forward packets: %v", err)
}
case fwd := <-s.onchainIntercepted:
// For on-chain interceptions, we don't know if it has
// already been offered before. This information is in
// the forwarding package which isn't easily accessible
// from contractcourt. It is likely though that it was
// already intercepted in the off-chain flow. And even
// if not, it is safe to signal replay so that we won't
// unexpectedly skip over this htlc.
if _, err := s.forward(fwd, true); err != nil {
return err
}
case res := <-s.resolutionChan:
res.errChan <- s.resolve(res.resolution)
case currentBlock, ok := <-s.blockEpochStream.Epochs:
if !ok {
return errBlockStreamStopped
}
s.currentHeight = currentBlock.Height
// A new block is appended. Fail any held htlcs that
// expire at this height to prevent channel force-close.
s.failExpiredHtlcs()
case <-s.quit:
return nil
}
}
}
func (s *InterceptableSwitch) failExpiredHtlcs() {
s.heldHtlcSet.popAutoFails(
uint32(s.currentHeight),
func(fwd InterceptedForward) {
err := fwd.FailWithCode(
lnwire.CodeTemporaryChannelFailure,
)
if err != nil {
log.Errorf("Cannot fail packet: %v", err)
}
},
)
}
func (s *InterceptableSwitch) sendForward(fwd InterceptedForward) {
err := s.interceptor(fwd.Packet())
if err != nil {
// Only log the error. If we couldn't send the packet, we assume
// that the interceptor will reconnect so that we can retry.
log.Debugf("Interceptor cannot handle forward: %v", err)
}
}
func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
s.interceptor = interceptor
// Replay all currently held htlcs. When an interceptor is not required,
// there may be none because they've been cleared after the previous
// disconnect.
if interceptor != nil {
log.Debugf("Interceptor connected")
s.heldHtlcSet.forEach(s.sendForward)
return
}
// The interceptor disconnects. If an interceptor is required, keep the
// held htlcs.
if s.requireInterceptor {
log.Infof("Interceptor disconnected, retaining held packets")
return
}
// Interceptor is not required. Release held forwards.
log.Infof("Interceptor disconnected, resolving held packets")
s.heldHtlcSet.popAll(func(fwd InterceptedForward) {
err := fwd.Resume()
if err != nil {
log.Errorf("Failed to resume hold forward %v", err)
}
})
}
// resolve processes a HTLC given the resolution type specified by the
// intercepting client.
func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
intercepted, err := s.heldHtlcSet.pop(res.Key)
if err != nil {
return err
}
switch res.Action {
case FwdActionResume:
return intercepted.Resume()
case FwdActionResumeModified:
return intercepted.ResumeModified(
res.InAmountMsat, res.OutAmountMsat,
res.OutWireCustomRecords,
)
case FwdActionSettle:
return intercepted.Settle(res.Preimage)
case FwdActionFail:
if len(res.FailureMessage) > 0 {
return intercepted.Fail(res.FailureMessage)
}
return intercepted.FailWithCode(res.FailureCode)
default:
return fmt.Errorf("unrecognized action %v", res.Action)
}
}
// Resolve resolves an intercepted packet.
func (s *InterceptableSwitch) Resolve(res *FwdResolution) error {
internalRes := &fwdResolution{
resolution: res,
errChan: make(chan error, 1),
}
select {
case s.resolutionChan <- internalRes:
case <-s.quit:
return errors.New("switch shutting down")
}
select {
case err := <-internalRes.errChan:
return err
case <-s.quit:
return errors.New("switch shutting down")
}
}
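// exampleResolveHeldHtlc is an illustrative sketch, not part of the package
// API: once an interceptor client has decided the fate of a held htlc, it
// reports that decision through Resolve. The circuit key, preimage and the
// settle decision are assumed to come from the client's own bookkeeping.
func exampleResolveHeldHtlc(s *InterceptableSwitch, key models.CircuitKey,
preimage lntypes.Preimage, settle bool) error {
if settle {
return s.Resolve(&FwdResolution{
Key: key,
Action: FwdActionSettle,
Preimage: preimage,
})
}
// Without a pre-encrypted failure message, a failure code lets the
// switch construct and encrypt the failure on our behalf.
return s.Resolve(&FwdResolution{
Key: key,
Action: FwdActionFail,
FailureCode: lnwire.CodeTemporaryChannelFailure,
})
}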
// ForwardPackets attempts to forward the batch of htlcs to a connected
// interceptor. If the interceptor signals the resume action, the htlcs are
// forwarded to the switch. The link's quit signal should be provided to allow
// cancellation of forwarding during link shutdown.
func (s *InterceptableSwitch) ForwardPackets(linkQuit <-chan struct{},
isReplay bool, packets ...*htlcPacket) error {
// Synchronize with the main event loop. This should be light in the
// case where there is no interceptor.
select {
case s.intercepted <- &interceptedPackets{
packets: packets,
linkQuit: linkQuit,
isReplay: isReplay,
}:
case <-linkQuit:
log.Debugf("Forward cancelled because link quit")
case <-s.quit:
return errors.New("interceptable switch quit")
}
return nil
}
// ForwardPacket forwards a single htlc to the external interceptor.
func (s *InterceptableSwitch) ForwardPacket(
fwd InterceptedForward) error {
select {
case s.onchainIntercepted <- fwd:
case <-s.quit:
return errors.New("interceptable switch quit")
}
return nil
}
// interceptForward forwards the packet to the external interceptor after
// checking the interception criteria.
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
isReplay bool) (bool, error) {
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
// We are not interested in intercepting initiated payments.
if packet.incomingChanID == hop.Source {
return false, nil
}
intercepted := &interceptedForward{
htlc: htlc,
packet: packet,
htlcSwitch: s.htlcSwitch,
autoFailHeight: int32(packet.incomingTimeout -
s.cltvRejectDelta),
}
// Handle forwards that are too close to expiry.
handled, err := s.handleExpired(intercepted)
if err != nil {
log.Errorf("Error handling intercepted htlc "+
"that expires too soon: circuit=%v, "+
"incoming_timeout=%v, err=%v",
packet.inKey(), packet.incomingTimeout, err)
// Return false so that the packet is offered as normal
// to the switch. This isn't ideal because interception
// may be configured as always-on and is skipped now.
// Returning true isn't great either, because the htlc
// will remain stuck and potentially force-close the
// channel. But in the end, we should never get here, so
// the actual return value doesn't matter that much.
return false, nil
}
if handled {
return true, nil
}
return s.forward(intercepted, isReplay)
default:
return false, nil
}
}
// forward records the intercepted htlc and forwards it to the interceptor.
func (s *InterceptableSwitch) forward(
fwd InterceptedForward, isReplay bool) (bool, error) {
inKey := fwd.Packet().IncomingCircuit
// Ignore already held htlcs.
if s.heldHtlcSet.exists(inKey) {
return true, nil
}
// If there is no interceptor currently registered, configuration and packet
// replay status determine how the packet is handled.
if s.interceptor == nil {
// Process normally if an interceptor is not required.
if !s.requireInterceptor {
return false, nil
}
// We are in interceptor-required mode. If this is a new packet, it is
// still safe to fail back. The interceptor has never seen this packet
// yet. This limits the backlog of htlcs when the interceptor is down.
if !isReplay {
err := fwd.FailWithCode(
lnwire.CodeTemporaryChannelFailure,
)
if err != nil {
log.Errorf("Cannot fail packet: %v", err)
}
return true, nil
}
// This packet is a replay. It is not safe to fail back, because the
// interceptor may still signal otherwise upon reconnect. Keep the
// packet in the queue until then.
if err := s.heldHtlcSet.push(inKey, fwd); err != nil {
return false, err
}
return true, nil
}
// There is an interceptor registered. We can forward the packet right now.
// Hold it in the queue too to track what is outstanding.
if err := s.heldHtlcSet.push(inKey, fwd); err != nil {
return false, err
}
s.sendForward(fwd)
return true, nil
}
// handleExpired checks that the htlc isn't too close to the channel
// force-close broadcast height. If it is, it is cancelled back.
func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) (
bool, error) {
height := uint32(s.currentHeight)
if fwd.packet.incomingTimeout >= height+s.cltvInterceptDelta {
return false, nil
}
log.Debugf("Interception rejected because htlc "+
"expires too soon: circuit=%v, "+
"height=%v, incoming_timeout=%v",
fwd.packet.inKey(), height,
fwd.packet.incomingTimeout)
err := fwd.FailWithCode(
lnwire.CodeExpiryTooSoon,
)
if err != nil {
return false, err
}
return true, nil
}
// interceptedForward implements the InterceptedForward interface.
// It is passed from the switch to external interceptors that are interested
// in holding forwards and resolve them manually.
type interceptedForward struct {
htlc *lnwire.UpdateAddHTLC
packet *htlcPacket
htlcSwitch *Switch
autoFailHeight int32
}
// Packet returns the intercepted htlc packet.
func (f *interceptedForward) Packet() InterceptedPacket {
return InterceptedPacket{
IncomingCircuit: models.CircuitKey{
ChanID: f.packet.incomingChanID,
HtlcID: f.packet.incomingHTLCID,
},
OutgoingChanID: f.packet.outgoingChanID,
Hash: f.htlc.PaymentHash,
OutgoingExpiry: f.htlc.Expiry,
OutgoingAmount: f.htlc.Amount,
IncomingAmount: f.packet.incomingAmount,
IncomingExpiry: f.packet.incomingTimeout,
InOnionCustomRecords: f.packet.inOnionCustomRecords,
OnionBlob: f.htlc.OnionBlob,
AutoFailHeight: f.autoFailHeight,
InWireCustomRecords: f.packet.inWireCustomRecords,
}
}
// Resume resumes the default behavior as if the packet was not intercepted.
func (f *interceptedForward) Resume() error {
// Forward to the switch. A link quit channel isn't needed, because we
// are on a different thread now.
return f.htlcSwitch.ForwardPackets(nil, f.packet)
}
// ResumeModified resumes the default behavior with field modifications. The
// input amount (if provided) specifies that the value of the inbound HTLC
// should be interpreted differently from the on-chain amount during further
// validation. The presence of an output amount and/or custom records indicates
// that those values should be modified on the outgoing HTLC.
func (f *interceptedForward) ResumeModified(
inAmountMsat fn.Option[lnwire.MilliSatoshi],
outAmountMsat fn.Option[lnwire.MilliSatoshi],
outWireCustomRecords fn.Option[lnwire.CustomRecords]) error {
// Convert the optional custom records to the correct type and validate
// them.
validatedRecords, err := fn.MapOptionZ(
outWireCustomRecords,
func(cr lnwire.CustomRecords) fn.Result[lnwire.CustomRecords] {
if len(cr) == 0 {
return fn.Ok[lnwire.CustomRecords](nil)
}
// Type cast and validate custom records.
err := cr.Validate()
if err != nil {
return fn.Err[lnwire.CustomRecords](
fmt.Errorf("failed to validate "+
"custom records: %w", err),
)
}
return fn.Ok(cr)
},
).Unpack()
if err != nil {
return fmt.Errorf("failed to encode custom records: %w",
err)
}
// Set the incoming amount, if it is provided, on the packet.
inAmountMsat.WhenSome(func(amount lnwire.MilliSatoshi) {
f.packet.incomingAmount = amount
})
// Modify the wire message contained in the packet.
switch htlc := f.packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
outAmountMsat.WhenSome(func(amount lnwire.MilliSatoshi) {
f.packet.amount = amount
htlc.Amount = amount
})
// Merge custom records with any validated records that were
// added in the modify request, overwriting any existing values
// with those supplied in the modifier API.
htlc.CustomRecords = htlc.CustomRecords.MergedCopy(
validatedRecords,
)
case *lnwire.UpdateFulfillHTLC:
if len(validatedRecords) > 0 {
htlc.CustomRecords = validatedRecords
}
}
log.Tracef("Forwarding packet %v", lnutils.SpewLogClosure(f.packet))
// Forward to the switch. A link quit channel isn't needed, because we
// are on a different thread now.
return f.htlcSwitch.ForwardPackets(nil, f.packet)
}
// Fail notifies the intention to fail an existing hold forward with an
// encrypted failure reason.
func (f *interceptedForward) Fail(reason []byte) error {
obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason)
return f.resolve(&lnwire.UpdateFailHTLC{
Reason: obfuscatedReason,
})
}
// FailWithCode notifies the intention to fail an existing hold forward with the
// specified failure code.
func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error {
shaOnionBlob := func() [32]byte {
return sha256.Sum256(f.htlc.OnionBlob[:])
}
// Create a local failure.
var failureMsg lnwire.FailureMessage
switch code {
case lnwire.CodeInvalidOnionVersion:
failureMsg = &lnwire.FailInvalidOnionVersion{
OnionSHA256: shaOnionBlob(),
}
case lnwire.CodeInvalidOnionHmac:
failureMsg = &lnwire.FailInvalidOnionHmac{
OnionSHA256: shaOnionBlob(),
}
case lnwire.CodeInvalidOnionKey:
failureMsg = &lnwire.FailInvalidOnionKey{
OnionSHA256: shaOnionBlob(),
}
case lnwire.CodeTemporaryChannelFailure:
update := f.htlcSwitch.failAliasUpdate(
f.packet.incomingChanID, true,
)
if update == nil {
// Fallback to the original, non-alias behavior.
var err error
update, err = f.htlcSwitch.cfg.FetchLastChannelUpdate(
f.packet.incomingChanID,
)
if err != nil {
return err
}
}
failureMsg = lnwire.NewTemporaryChannelFailure(update)
case lnwire.CodeExpiryTooSoon:
update, err := f.htlcSwitch.cfg.FetchLastChannelUpdate(
f.packet.incomingChanID,
)
if err != nil {
return err
}
failureMsg = lnwire.NewExpiryTooSoon(*update)
default:
return ErrUnsupportedFailureCode
}
// Encrypt the failure for the first hop. This node will be the origin
// of the failure.
reason, err := f.packet.obfuscator.EncryptFirstHop(failureMsg)
if err != nil {
return fmt.Errorf("failed to encrypt failure reason %w", err)
}
return f.resolve(&lnwire.UpdateFailHTLC{
Reason: reason,
})
}
// Settle forwards a settled packet to the switch.
func (f *interceptedForward) Settle(preimage lntypes.Preimage) error {
if !preimage.Matches(f.htlc.PaymentHash) {
return errors.New("preimage does not match hash")
}
return f.resolve(&lnwire.UpdateFulfillHTLC{
PaymentPreimage: preimage,
})
}
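// examplePreimageCheck is an editorial sketch (not part of lnd) of the guard
// used by Settle above: a preimage settles an htlc only if its SHA256 equals
// the htlc's payment hash, which is what Matches verifies.
func examplePreimageCheck() bool {
	var preimage lntypes.Preimage // 32 zero bytes, for illustration.
	hash := preimage.Hash()
	// Matches hashes the preimage and compares it against hash.
	return preimage.Matches(hash) // true
}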
// resolve is used for both Settle and Fail and forwards the message to the
// switch.
func (f *interceptedForward) resolve(message lnwire.Message) error {
pkt := &htlcPacket{
incomingChanID: f.packet.incomingChanID,
incomingHTLCID: f.packet.incomingHTLCID,
outgoingChanID: f.packet.outgoingChanID,
outgoingHTLCID: f.packet.outgoingHTLCID,
isResolution: true,
circuit: f.packet.circuit,
htlc: message,
obfuscator: f.packet.obfuscator,
sourceRef: f.packet.sourceRef,
}
return f.htlcSwitch.mailOrchestrator.Deliver(pkt.incomingChanID, pkt)
}
package htlcswitch
import (
"bytes"
"context"
crand "crypto/rand"
"crypto/sha256"
"errors"
"fmt"
prand "math/rand"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/tlv"
)
func init() {
prand.Seed(time.Now().UnixNano())
}
const (
// DefaultMaxOutgoingCltvExpiry is the maximum outgoing time lock that
// the node accepts for forwarded payments. The value is relative to the
// current block height. The reason to have a maximum is to prevent
// funds getting locked up unreasonably long. Otherwise, an attacker
// willing to lock its own funds too, could force the funds of this node
// to be locked up for an indefinite (max int32) number of blocks.
//
// The value 2016 corresponds to on average two weeks worth of blocks
// and is based on the maximum number of hops (20), the default CLTV
// delta (40), and some extra margin to account for the other lightning
// implementations and past lnd versions which used to have a default
// CLTV delta of 144.
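	// (At roughly 6 blocks per hour, 6 * 24 * 14 = 2016 blocks in two
	// weeks; the hop budget alone is 20 hops * 40 blocks = 800 blocks.)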
DefaultMaxOutgoingCltvExpiry = 2016
// DefaultMinLinkFeeUpdateTimeout represents the minimum interval in
// which a link should propose to update its commitment fee rate.
DefaultMinLinkFeeUpdateTimeout = 10 * time.Minute
// DefaultMaxLinkFeeUpdateTimeout represents the maximum interval in
// which a link should propose to update its commitment fee rate.
DefaultMaxLinkFeeUpdateTimeout = 60 * time.Minute
// DefaultMaxLinkFeeAllocation is the highest allocation we'll allow
// a channel's commitment fee to be of its balance. This only applies to
// the initiator of the channel.
DefaultMaxLinkFeeAllocation float64 = 0.5
)
// ExpectedFee computes the expected fee for a given htlc amount. The value
// returned from this function is to be used as a sanity check when forwarding
// HTLC's to ensure that an incoming HTLC properly adheres to our propagated
// forwarding policy.
//
// TODO(roasbeef): also add in current available channel bandwidth, inverse
// func
func ExpectedFee(f models.ForwardingPolicy,
htlcAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
return f.BaseFee + (htlcAmt*f.FeeRate)/1000000
}
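// exampleExpectedFee is an editorial sketch (not part of lnd) of the
// fixed-point arithmetic above: FeeRate is expressed in millionths (parts
// per million) of the forwarded amount, added on top of the flat BaseFee.
func exampleExpectedFee() lnwire.MilliSatoshi {
	policy := models.ForwardingPolicy{
		BaseFee: 1_000, // Flat fee of 1 sat, expressed in msat.
		FeeRate: 100,   // Proportional fee of 100 ppm.
	}
	// 1_000 + (1_000_000 * 100) / 1_000_000 = 1_100 msat.
	return ExpectedFee(policy, 1_000_000)
}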
// ChannelLinkConfig defines the configuration for the channel link. ALL
// elements within the configuration MUST be non-nil for channel link to carry
// out its duties.
type ChannelLinkConfig struct {
// FwrdingPolicy is the initial forwarding policy to be used when
// deciding whether to forward incoming HTLC's or not. This value
// can be updated with subsequent calls to UpdateForwardingPolicy
// targeted at a given ChannelLink concrete interface implementation.
FwrdingPolicy models.ForwardingPolicy
// Circuits provides restricted access to the switch's circuit map,
// allowing the link to open and close circuits.
Circuits CircuitModifier
// BestHeight returns the best known height.
BestHeight func() uint32
// ForwardPackets attempts to forward the batch of htlcs through the
// switch. The function returns an error if it fails to send one or
// more packets. The link's quit signal should be provided to allow
// cancellation of forwarding during link shutdown.
ForwardPackets func(<-chan struct{}, bool, ...*htlcPacket) error
// DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion
// blobs, which are then used to inform how to forward an HTLC.
//
// NOTE: This function assumes the same set of readers and preimages
// are always presented for the same identifier.
DecodeHopIterators func([]byte, []hop.DecodeHopIteratorRequest) (
[]hop.DecodeHopIteratorResponse, error)
// ExtractErrorEncrypter function is responsible for decoding HTLC
// Sphinx onion blob, and creating onion failure obfuscator.
ExtractErrorEncrypter hop.ErrorEncrypterExtracter
// FetchLastChannelUpdate retrieves the latest routing policy for a
// target channel. This channel will typically be the outgoing channel
// specified when we receive an incoming HTLC. This will be used to
// provide payment senders our latest policy when sending encrypted
// error messages.
FetchLastChannelUpdate func(lnwire.ShortChannelID) (
*lnwire.ChannelUpdate1, error)
// Peer is a lightning network node with which we have the channel link
// opened.
Peer lnpeer.Peer
// Registry is a sub-system which is responsible for managing the
// invoices in a thread-safe manner.
Registry InvoiceDatabase
// PreimageCache is a global witness beacon that houses any new
// preimages discovered by other links. We'll use this to add new
// witnesses that we discover which will notify any sub-systems
// subscribed to new events.
PreimageCache contractcourt.WitnessBeacon
// OnChannelFailure is a function closure that we'll call if the
// channel failed for some reason. Depending on the severity of the
// error, the closure potentially must force close this channel and
// disconnect the peer.
//
// NOTE: The method must return in order for the ChannelLink to be able
// to shut down properly.
OnChannelFailure func(lnwire.ChannelID, lnwire.ShortChannelID,
LinkFailureError)
// UpdateContractSignals is a function closure that we'll use to update
// outside sub-systems with this channel's latest ShortChannelID.
UpdateContractSignals func(*contractcourt.ContractSignals) error
// NotifyContractUpdate is a function closure that we'll use to update
// the contractcourt and more specifically the ChannelArbitrator of the
// latest channel state.
NotifyContractUpdate func(*contractcourt.ContractUpdate) error
// ChainEvents is an active subscription to the chain watcher for this
// channel to be notified of any on-chain activity related to this
// channel.
ChainEvents *contractcourt.ChainEventSubscription
// FeeEstimator is an instance of a live fee estimator which will be
// used to dynamically regulate the current fee of the commitment
// transaction to ensure timely confirmation.
FeeEstimator chainfee.Estimator
// HodlMask is a bitvector composed of hodl.Flags, specifying breakpoints
// for HTLC forwarding internal to the switch.
//
// NOTE: This should only be used for testing.
HodlMask hodl.Mask
// SyncStates is used to indicate that we need to send the channel
// reestablishment message to the remote peer. It should be done if our
// client has been restarted, or the remote peer has reconnected.
SyncStates bool
// BatchTicker is the ticker that determines the interval that we'll
// use to check the batch to see if there're any updates we should
// flush out. By batching updates into a single commit, we attempt to
// increase throughput by maximizing the number of updates coalesced
// into a single commit.
BatchTicker ticker.Ticker
// FwdPkgGCTicker is the ticker determining the frequency at which
// garbage collection of forwarding packages occurs. We use a
// time-based approach, as opposed to block epochs, as to not hinder
// syncing.
FwdPkgGCTicker ticker.Ticker
// PendingCommitTicker is a ticker that allows the link to determine if
// a locally initiated commitment dance gets stuck waiting for the
// remote party to revoke.
PendingCommitTicker ticker.Ticker
// BatchSize is the max size of a batch of updates done to the link
// before we do a state update.
BatchSize uint32
// UnsafeReplay will cause a link to replay the adds in its latest
// commitment txn after the link is restarted. This should only be used
// in testing, it is here to ensure the sphinx replay detection on the
// receiving node is persistent.
UnsafeReplay bool
// MinUpdateTimeout represents the minimum interval in which a link
// will propose to update its commitment fee rate. A random timeout will
// be selected between this and MaxUpdateTimeout.
MinUpdateTimeout time.Duration
// MaxUpdateTimeout represents the maximum interval in which a link
// will propose to update its commitment fee rate. A random timeout will
// be selected between this and MinUpdateTimeout.
MaxUpdateTimeout time.Duration
// OutgoingCltvRejectDelta defines the number of blocks before expiry of
// an htlc where we don't offer an htlc anymore. This should be at least
// the outgoing broadcast delta, because in any case we don't want to
// risk offering an htlc that triggers channel closure.
OutgoingCltvRejectDelta uint32
// TowerClient is an optional engine that manages the signing,
// encrypting, and uploading of justice transactions to the daemon's
// configured set of watchtowers for legacy channels.
TowerClient TowerClient
// MaxOutgoingCltvExpiry is the maximum outgoing timelock that the link
// should accept for a forwarded HTLC. The value is relative to the
// current block height.
MaxOutgoingCltvExpiry uint32
// MaxFeeAllocation is the highest allocation we'll allow a channel's
// commitment fee to be of its balance. This only applies to the
// initiator of the channel.
MaxFeeAllocation float64
// MaxAnchorsCommitFeeRate is the max commitment fee rate we'll use as
// the initiator for channels of the anchor type.
MaxAnchorsCommitFeeRate chainfee.SatPerKWeight
// NotifyActiveLink allows the link to tell the ChannelNotifier when a
// link is first started.
NotifyActiveLink func(wire.OutPoint)
// NotifyActiveChannel allows the link to tell the ChannelNotifier when
// a channel becomes active.
NotifyActiveChannel func(wire.OutPoint)
// NotifyInactiveChannel allows the switch to tell the ChannelNotifier
// when channels become inactive.
NotifyInactiveChannel func(wire.OutPoint)
// NotifyInactiveLinkEvent allows the switch to tell the
// ChannelNotifier when a channel link becomes inactive.
NotifyInactiveLinkEvent func(wire.OutPoint)
// HtlcNotifier is an instance of a htlcNotifier which we will pipe htlc
// events through.
HtlcNotifier htlcNotifier
// FailAliasUpdate is a function used to fail an HTLC for an
// option_scid_alias channel.
FailAliasUpdate func(sid lnwire.ShortChannelID,
incoming bool) *lnwire.ChannelUpdate1
// GetAliases is used by the link and switch to fetch the set of
// aliases for a given link.
GetAliases func(base lnwire.ShortChannelID) []lnwire.ShortChannelID
// PreviouslySentShutdown is an optional value that is set if, at the
// time of the link being started, persisted shutdown info was found for
// the channel. This value being set means that we previously sent a
// Shutdown message to our peer, and so we should do so again on
// re-establish and should not allow any more HTLC adds on the outgoing
// direction of the link.
PreviouslySentShutdown fn.Option[lnwire.Shutdown]
// DisallowRouteBlinding adds the option to disable forwarding payments
// in blinded routes by failing back any blinding-related payloads as if
// they were invalid.
DisallowRouteBlinding bool
// DisallowQuiescence is a flag that can be used to disable the
// quiescence protocol.
DisallowQuiescence bool
// MaxFeeExposure is the threshold in milli-satoshis after which we'll
// restrict the flow of HTLCs and fee updates.
MaxFeeExposure lnwire.MilliSatoshi
// ShouldFwdExpEndorsement is a closure that indicates whether the link
// should forward experimental endorsement signals.
ShouldFwdExpEndorsement func() bool
// AuxTrafficShaper is an optional auxiliary traffic shaper that can be
// used to manage the bandwidth of the link.
AuxTrafficShaper fn.Option[AuxTrafficShaper]
}
// channelLink is the service which drives a channel's commitment update
// state-machine. In the event that an HTLC needs to be propagated to another
// link, the forward handler from config is used which sends HTLC to the
// switch. Additionally, the link encapsulates the logic of commitment
// protocol message ordering and updates.
type channelLink struct {
// The following fields are only meant to be used *atomically*
started int32
reestablished int32
shutdown int32
// failed should be set to true in case a link error happens, making
// sure we don't process any more updates.
failed bool
// keystoneBatch represents a volatile list of keystones that must be
// written before attempting to sign the next commitment txn. These
// represent all the HTLC's forwarded to the link from the switch. Once
// we lock them into our outgoing commitment, then the circuit has a
// keystone, and is fully opened.
keystoneBatch []Keystone
// openedCircuits is the set of all payment circuits that will be open
// once we make our next commitment. After making the commitment we'll
// ACK all these from our mailbox to ensure that they don't get
// re-delivered if we reconnect.
openedCircuits []CircuitKey
// closedCircuits is the set of all payment circuits that will be
// closed once we make our next commitment. After taking the commitment
// we'll ACK all these to ensure that they don't get re-delivered if we
// reconnect.
closedCircuits []CircuitKey
// channel is a lightning network channel to which we apply htlc
// updates.
channel *lnwallet.LightningChannel
// cfg is a structure which carries all the dependencies/handlers
// which may affect the behaviour of the service.
cfg ChannelLinkConfig
// mailBox is the main interface between the outside world and the
// link. All incoming messages will be sent over this mailBox. Messages
// include new updates from our connected peer, and new packets to be
// forwarded that are sent by the switch.
mailBox MailBox
// upstream is a channel that new messages sent from the remote peer to
// the local peer will be sent across.
upstream chan lnwire.Message
// downstream is a channel in which new multi-hop HTLC's to be
// forwarded will be sent across. Messages from this channel are sent
// by the HTLC switch.
downstream chan *htlcPacket
// updateFeeTimer is the timer responsible for updating the link's
// commitment fee every time it fires.
updateFeeTimer *time.Timer
// uncommittedPreimages stores a list of all preimages that have been
// learned since receiving the last CommitSig from the remote peer. The
// batch will be flushed just before accepting the subsequent CommitSig
// or on shutdown to avoid doing a write for each preimage received.
uncommittedPreimages []lntypes.Preimage
sync.RWMutex
// hodlQueue is used to receive exit hop htlc resolutions from invoice
// registry.
hodlQueue *queue.ConcurrentQueue
// hodlMap stores related htlc data for a circuit key. It allows
// resolving those htlcs when we receive a message on hodlQueue.
hodlMap map[models.CircuitKey]hodlHtlc
// log is a link-specific logging instance.
log btclog.Logger
// isOutgoingAddBlocked tracks whether the channelLink can send an
// UpdateAddHTLC.
isOutgoingAddBlocked atomic.Bool
// isIncomingAddBlocked tracks whether the channelLink can receive an
// UpdateAddHTLC.
isIncomingAddBlocked atomic.Bool
// flushHooks is a hookMap that is triggered when we reach a channel
// state with no live HTLCs.
flushHooks hookMap
// outgoingCommitHooks is a hookMap that is triggered after we send our
// next CommitSig.
outgoingCommitHooks hookMap
// incomingCommitHooks is a hookMap that is triggered after we receive
// our next CommitSig.
incomingCommitHooks hookMap
// quiescer is the state machine that tracks where this channel is with
// respect to the quiescence protocol.
quiescer Quiescer
// quiescenceReqs is a queue of requests to quiesce this link. The
// members of the queue are send-only channels we should call back with
// the result.
quiescenceReqs chan StfuReq
// cg is a helper that encapsulates a wait group and quit channel and
// allows contexts that either block or cancel on those depending on
// the use case.
cg *fn.ContextGuard
}
// hookMap is a data structure that is used to track the hooks that need to be
// called in various parts of the channelLink's lifecycle.
//
// WARNING: NOT thread-safe.
type hookMap struct {
// allocIdx keeps track of the next id we haven't yet allocated.
allocIdx atomic.Uint64
// transient is a map of hooks that are only called the next time invoke
// is called. These hooks are deleted during invoke.
transient map[uint64]func()
// newTransients is a channel that we use to accept new hooks into the
// hookMap.
newTransients chan func()
}
// newHookMap initializes a new empty hookMap.
func newHookMap() hookMap {
return hookMap{
allocIdx: atomic.Uint64{},
transient: make(map[uint64]func()),
newTransients: make(chan func()),
}
}
// alloc allocates space in the hook map for the supplied hook and returns
// the ID under which it was registered in the transient set.
func (m *hookMap) alloc(hook func()) uint64 {
// We assume we never overflow a uint64. Seems OK.
hookID := m.allocIdx.Add(1)
if hookID == 0 {
panic("hookMap allocIdx overflow")
}
m.transient[hookID] = hook
return hookID
}
// invoke is used on a hook map to call all the registered hooks and then clear
// out the transient hooks so they are not called again.
func (m *hookMap) invoke() {
for _, hook := range m.transient {
hook()
}
m.transient = make(map[uint64]func())
}
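// exampleHookMap is an editorial sketch (not part of lnd) of the transient
// hook lifecycle above: a hook registered via alloc runs on the next invoke
// only, since invoke clears the transient set afterwards.
func exampleHookMap() {
	m := newHookMap()
	m.alloc(func() { fmt.Println("channel flushed") })
	m.invoke() // Runs the hook once and clears it.
	m.invoke() // No-op: the transient set is now empty.
}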
// hodlHtlc contains htlc data that is required for resolution.
type hodlHtlc struct {
add lnwire.UpdateAddHTLC
sourceRef channeldb.AddRef
obfuscator hop.ErrorEncrypter
}
// NewChannelLink creates a new instance of a ChannelLink given a configuration
// and active channel that will be used to verify/apply updates to.
func NewChannelLink(cfg ChannelLinkConfig,
channel *lnwallet.LightningChannel) ChannelLink {
logPrefix := fmt.Sprintf("ChannelLink(%v):", channel.ChannelPoint())
// If the max fee exposure isn't set, use the default.
if cfg.MaxFeeExposure == 0 {
cfg.MaxFeeExposure = DefaultMaxFeeExposure
}
var qsm Quiescer
if !cfg.DisallowQuiescence {
qsm = NewQuiescer(QuiescerCfg{
chanID: lnwire.NewChanIDFromOutPoint(
channel.ChannelPoint(),
),
channelInitiator: channel.Initiator(),
sendMsg: func(s lnwire.Stfu) error {
return cfg.Peer.SendMessage(false, &s)
},
timeoutDuration: defaultQuiescenceTimeout,
onTimeout: func() {
cfg.Peer.Disconnect(ErrQuiescenceTimeout)
},
})
} else {
qsm = &quiescerNoop{}
}
quiescenceReqs := make(
chan fn.Req[fn.Unit, fn.Result[lntypes.ChannelParty]], 1,
)
return &channelLink{
cfg: cfg,
channel: channel,
hodlMap: make(map[models.CircuitKey]hodlHtlc),
hodlQueue: queue.NewConcurrentQueue(10),
log: log.WithPrefix(logPrefix),
flushHooks: newHookMap(),
outgoingCommitHooks: newHookMap(),
incomingCommitHooks: newHookMap(),
quiescer: qsm,
quiescenceReqs: quiescenceReqs,
cg: fn.NewContextGuard(),
}
}
// A compile time check to ensure channelLink implements the ChannelLink
// interface.
var _ ChannelLink = (*channelLink)(nil)
// Start starts all helper goroutines required for the operation of the channel
// link.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) Start() error {
if !atomic.CompareAndSwapInt32(&l.started, 0, 1) {
err := fmt.Errorf("channel link(%v): already started", l)
l.log.Warn("already started")
return err
}
l.log.Info("starting")
// If the config supplied watchtower client, ensure the channel is
// registered before trying to use it during operation.
if l.cfg.TowerClient != nil {
err := l.cfg.TowerClient.RegisterChannel(
l.ChanID(), l.channel.State().ChanType,
)
if err != nil {
return err
}
}
l.mailBox.ResetMessages()
l.hodlQueue.Start()
	// Before launching the htlcManager goroutine, revert any circuits that
// were marked open in the switch's circuit map, but did not make it
// into a commitment txn. We use the next local htlc index as the cut
// off point, since all indexes below that are committed. This action
// is only performed if the link's final short channel ID has been
// assigned, otherwise we would try to trim the htlcs belonging to the
// all-zero, hop.Source ID.
if l.ShortChanID() != hop.Source {
localHtlcIndex, err := l.channel.NextLocalHtlcIndex()
if err != nil {
return fmt.Errorf("unable to retrieve next local "+
"htlc index: %v", err)
}
// NOTE: This is automatically done by the switch when it
// starts up, but is necessary to prevent inconsistencies in
// the case that the link flaps. This is a result of a link's
// life-cycle being shorter than that of the switch.
chanID := l.ShortChanID()
err = l.cfg.Circuits.TrimOpenCircuits(chanID, localHtlcIndex)
if err != nil {
return fmt.Errorf("unable to trim circuits above "+
"local htlc index %d: %v", localHtlcIndex, err)
}
// Since the link is live, before we start the link we'll update
// the ChainArbitrator with the set of new channel signals for
// this channel.
//
// TODO(roasbeef): split goroutines within channel arb to avoid
go func() {
signals := &contractcourt.ContractSignals{
ShortChanID: l.channel.ShortChanID(),
}
err := l.cfg.UpdateContractSignals(signals)
if err != nil {
l.log.Errorf("unable to update signals")
}
}()
}
l.updateFeeTimer = time.NewTimer(l.randomFeeUpdateTimeout())
l.cg.WgAdd(1)
go l.htlcManager(context.TODO())
return nil
}
// Stop gracefully stops all active helper goroutines, then waits until they've
// exited.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) Stop() {
if !atomic.CompareAndSwapInt32(&l.shutdown, 0, 1) {
l.log.Warn("already stopped")
return
}
l.log.Info("stopping")
// As the link is stopping, we are no longer interested in htlc
// resolutions coming from the invoice registry.
l.cfg.Registry.HodlUnsubscribeAll(l.hodlQueue.ChanIn())
if l.cfg.ChainEvents.Cancel != nil {
l.cfg.ChainEvents.Cancel()
}
// Ensure the channel for the timer is drained.
if l.updateFeeTimer != nil {
if !l.updateFeeTimer.Stop() {
select {
case <-l.updateFeeTimer.C:
default:
}
}
}
if l.hodlQueue != nil {
l.hodlQueue.Stop()
}
l.cg.Quit()
l.cg.WgWait()
// Now that the htlcManager has completely exited, reset the packet
	// courier. This allows the mailbox to re-evaluate any lingering Adds that
// were delivered but didn't make it on a commitment to be failed back
// if the link is offline for an extended period of time. The error is
// ignored since it can only fail when the daemon is exiting.
_ = l.mailBox.ResetPackets()
// As a final precaution, we will attempt to flush any uncommitted
// preimages to the preimage cache. The preimages should be re-delivered
// after channel reestablishment, however this adds an extra layer of
// protection in case the peer never returns. Without this, we will be
// unable to settle any contracts depending on the preimages even though
// we had learned them at some point.
err := l.cfg.PreimageCache.AddPreimages(l.uncommittedPreimages...)
if err != nil {
l.log.Errorf("unable to add preimages=%v to cache: %v",
l.uncommittedPreimages, err)
}
}
// WaitForShutdown blocks until the link finishes shutting down, which includes
// termination of all dependent goroutines.
func (l *channelLink) WaitForShutdown() {
l.cg.WgWait()
}
// EligibleToForward returns a bool indicating if the channel is able to
// actively accept requests to forward HTLC's. We're able to forward HTLC's if
// we are eligible to update AND the channel isn't currently flushing the
// outgoing half of the channel.
//
// NOTE: MUST NOT be called from the main event loop.
func (l *channelLink) EligibleToForward() bool {
l.RLock()
defer l.RUnlock()
return l.eligibleToForward()
}
// eligibleToForward returns a bool indicating if the channel is able to
// actively accept requests to forward HTLC's. We're able to forward HTLC's if
// we are eligible to update AND the channel isn't currently flushing the
// outgoing half of the channel.
//
// NOTE: MUST be called from the main event loop.
func (l *channelLink) eligibleToForward() bool {
return l.eligibleToUpdate() && !l.IsFlushing(Outgoing)
}
// eligibleToUpdate returns a bool indicating if the channel is able to update
// channel state. We're able to update channel state if we know the remote
// party's next revocation point. Otherwise, we can't initiate new channel
// state. We also require that the short channel ID not be the all-zero source
// ID, meaning that the channel has had its ID finalized.
//
// NOTE: MUST be called from the main event loop.
func (l *channelLink) eligibleToUpdate() bool {
return l.channel.RemoteNextRevocation() != nil &&
l.channel.ShortChanID() != hop.Source &&
l.isReestablished() &&
l.quiescer.CanSendUpdates()
}
// EnableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in
// the specified direction. It returns true if the state was changed and false
// if the desired state was already set before the method was called.
func (l *channelLink) EnableAdds(linkDirection LinkDirection) bool {
if linkDirection == Outgoing {
return l.isOutgoingAddBlocked.Swap(false)
}
return l.isIncomingAddBlocked.Swap(false)
}
// DisableAdds sets the ChannelUpdateHandler state to disallow
// UpdateAddHtlc's in the specified direction. It returns true if the state
// was changed and false
// if the desired state was already set before the method was called.
func (l *channelLink) DisableAdds(linkDirection LinkDirection) bool {
if linkDirection == Outgoing {
return !l.isOutgoingAddBlocked.Swap(true)
}
return !l.isIncomingAddBlocked.Swap(true)
}
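// exampleAddsToggle is an editorial sketch (not part of lnd) of the Swap
// semantics above: only the first call in a given direction reports a state
// change, which makes EnableAdds/DisableAdds idempotent.
func exampleAddsToggle() {
	var l channelLink
	_ = l.DisableAdds(Outgoing) // true: state changed to blocked.
	_ = l.DisableAdds(Outgoing) // false: already blocked.
	_ = l.EnableAdds(Outgoing)  // true: state changed back.
}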
// IsFlushing returns true when UpdateAddHtlc's are disabled in the direction of
// the argument.
func (l *channelLink) IsFlushing(linkDirection LinkDirection) bool {
if linkDirection == Outgoing {
return l.isOutgoingAddBlocked.Load()
}
return l.isIncomingAddBlocked.Load()
}
// OnFlushedOnce adds a hook that will be called the next time the channel
// state reaches zero htlcs. This hook will only ever be called once. If the
// channel state already has zero htlcs, then this will be called immediately.
func (l *channelLink) OnFlushedOnce(hook func()) {
select {
case l.flushHooks.newTransients <- hook:
case <-l.cg.Done():
}
}
// OnCommitOnce adds a hook that will be called the next time a CommitSig
// message is sent in the argument's LinkDirection. This hook will only ever be
// called once. If no CommitSig is owed in the argument's LinkDirection, then
// the hook will be run immediately.
func (l *channelLink) OnCommitOnce(direction LinkDirection, hook func()) {
var queue chan func()
if direction == Outgoing {
queue = l.outgoingCommitHooks.newTransients
} else {
queue = l.incomingCommitHooks.newTransients
}
select {
case queue <- hook:
case <-l.cg.Done():
}
}
// InitStfu allows us to initiate quiescence on this link. It returns a receive
// only channel that will block until quiescence has been achieved, or
// definitively fails.
//
// This operation has been added to allow channels to be quiesced via RPC. It
// may be removed or reworked in the future as RPC initiated quiescence is a
// holdover until we have downstream protocols that use it.
func (l *channelLink) InitStfu() <-chan fn.Result[lntypes.ChannelParty] {
req, out := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
select {
case l.quiescenceReqs <- req:
case <-l.cg.Done():
req.Resolve(fn.Err[lntypes.ChannelParty](ErrLinkShuttingDown))
}
return out
}
// isReestablished returns true if the link has successfully completed the
// channel reestablishment dance.
func (l *channelLink) isReestablished() bool {
return atomic.LoadInt32(&l.reestablished) == 1
}
// markReestablished signals that the remote peer has successfully exchanged
// channel reestablish messages and that the channel is ready to process
// subsequent messages.
func (l *channelLink) markReestablished() {
atomic.StoreInt32(&l.reestablished, 1)
}
// IsUnadvertised returns true if the underlying channel is unadvertised.
func (l *channelLink) IsUnadvertised() bool {
state := l.channel.State()
return state.ChannelFlags&lnwire.FFAnnounceChannel == 0
}
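// exampleAnnounceBit is an editorial sketch (not part of lnd) of the bit
// test above: a channel is unadvertised iff FFAnnounceChannel is unset in
// its channel flags.
func exampleAnnounceBit() bool {
	flags := lnwire.FFAnnounceChannel // Announce bit set.
	return flags&lnwire.FFAnnounceChannel == 0 // false: advertised.
}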
// sampleNetworkFee samples the current fee rate on the network to get into the
// chain in a timely manner. The returned value is expressed in fee-per-kw, as
// this is the native rate used when computing the fee for commitment
// transactions, and the second-level HTLC transactions.
func (l *channelLink) sampleNetworkFee() (chainfee.SatPerKWeight, error) {
// We'll first query for the sat/kw recommended to be confirmed within 3
// blocks.
feePerKw, err := l.cfg.FeeEstimator.EstimateFeePerKW(3)
if err != nil {
return 0, err
}
l.log.Debugf("sampled fee rate for 3 block conf: %v sat/kw",
int64(feePerKw))
return feePerKw, nil
}
// shouldAdjustCommitFee returns true if we should update our commitment fee to
// match that of the network fee. We'll only update our commitment fee if the
// network fee is +/- 10% to our commitment fee or if our current commitment
// fee is below the minimum relay fee.
func shouldAdjustCommitFee(netFee, chanFee,
minRelayFee chainfee.SatPerKWeight) bool {
switch {
// If the network fee is greater than our current commitment fee and
// our current commitment fee is below the minimum relay fee then
// we should switch to it no matter if it is less than a 10% increase.
case netFee > chanFee && chanFee < minRelayFee:
return true
// If the network fee is greater than the commitment fee, then we'll
// switch to it if it's at least 10% greater than the commit fee.
case netFee > chanFee && netFee >= (chanFee+(chanFee*10)/100):
return true
// If the network fee is less than our commitment fee, then we'll
// switch to it if it's at least 10% less than the commitment fee.
case netFee < chanFee && netFee <= (chanFee-(chanFee*10)/100):
return true
// Otherwise, we won't modify our fee.
default:
return false
}
}
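// exampleShouldAdjustCommitFee is an editorial sketch (not part of lnd) of
// the +/-10% band above, using an assumed relay floor of 253 sat/kw.
func exampleShouldAdjustCommitFee() {
	const minRelay = chainfee.SatPerKWeight(253)
	_ = shouldAdjustCommitFee(1100, 1000, minRelay) // true: +10%.
	_ = shouldAdjustCommitFee(1050, 1000, minRelay) // false: within band.
	_ = shouldAdjustCommitFee(900, 1000, minRelay)  // true: -10%.
	_ = shouldAdjustCommitFee(260, 250, minRelay)   // true: below floor.
}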
// failCb is used to cut down on the argument verbosity.
type failCb func(update *lnwire.ChannelUpdate1) lnwire.FailureMessage
// createFailureWithUpdate creates a ChannelUpdate when failing an incoming or
// outgoing HTLC. It may return a FailureMessage that references a channel's
// alias. If the channel does not have an alias, then the regular channel
// update from disk will be returned.
func (l *channelLink) createFailureWithUpdate(incoming bool,
outgoingScid lnwire.ShortChannelID, cb failCb) lnwire.FailureMessage {
// Determine which SCID to use in case we need to use aliases in the
// ChannelUpdate.
scid := outgoingScid
if incoming {
scid = l.ShortChanID()
}
// Try using the FailAliasUpdate function. If it returns nil, fallback
// to the non-alias behavior.
update := l.cfg.FailAliasUpdate(scid, incoming)
if update == nil {
// Fallback to the non-alias behavior.
var err error
update, err = l.cfg.FetchLastChannelUpdate(l.ShortChanID())
if err != nil {
return &lnwire.FailTemporaryNodeFailure{}
}
}
return cb(update)
}
// syncChanStates attempts to synchronize channel states with the remote
// party. This method is to be called upon reconnection after the initial
// funding flow. We'll compare our commitment chains with the remote party,
// and re-send either a dangling commit signature, a revocation, or both.
func (l *channelLink) syncChanStates(ctx context.Context) error {
chanState := l.channel.State()
l.log.Infof("Attempting to re-synchronize channel: %v", chanState)
// First, we'll generate our ChanSync message to send to the other
// side. Based on this message, the remote party will decide if they
// need to retransmit any data or not.
localChanSyncMsg, err := chanState.ChanSyncMsg()
if err != nil {
return fmt.Errorf("unable to generate chan sync message for "+
"ChannelPoint(%v)", l.channel.ChannelPoint())
}
if err := l.cfg.Peer.SendMessage(true, localChanSyncMsg); err != nil {
return fmt.Errorf("unable to send chan sync message for "+
"ChannelPoint(%v): %v", l.channel.ChannelPoint(), err)
}
var msgsToReSend []lnwire.Message
// Next, we'll wait indefinitely to receive the ChanSync message. The
// first message sent MUST be the ChanSync message.
select {
case msg := <-l.upstream:
l.log.Tracef("Received msg=%v from peer(%x)", msg.MsgType(),
l.cfg.Peer.PubKey())
remoteChanSyncMsg, ok := msg.(*lnwire.ChannelReestablish)
if !ok {
return fmt.Errorf("first message sent to sync "+
"should be ChannelReestablish, instead "+
"received: %T", msg)
}
// If the remote party indicates that they think we haven't
// done any state updates yet, then we'll retransmit the
// channel_ready message first. We do this, as at this point
// we can't be sure if they've really received the
// ChannelReady message.
if remoteChanSyncMsg.NextLocalCommitHeight == 1 &&
localChanSyncMsg.NextLocalCommitHeight == 1 &&
!l.channel.IsPending() {
l.log.Infof("resending ChannelReady message to peer")
nextRevocation, err := l.channel.NextRevocationKey()
if err != nil {
return fmt.Errorf("unable to create next "+
"revocation: %v", err)
}
channelReadyMsg := lnwire.NewChannelReady(
l.ChanID(), nextRevocation,
)
// If this is a taproot channel, then we'll send the
// very same nonce that we sent above, as they should
// take the latest verification nonce we send.
if chanState.ChanType.IsTaproot() {
//nolint:ll
channelReadyMsg.NextLocalNonce = localChanSyncMsg.LocalNonce
}
// For channels that negotiated the option-scid-alias
// feature bit, ensure that we send over the alias in
// the channel_ready message. We'll send the first
// alias we find for the channel since it does not
// matter which alias we send. We'll error out if no
// aliases are found.
if l.negotiatedAliasFeature() {
aliases := l.getAliases()
if len(aliases) == 0 {
// This shouldn't happen since we
// always add at least one alias before
// the channel reaches the link.
return fmt.Errorf("no aliases found")
}
// getAliases returns a copy of the alias slice
// so it is ok to use a pointer to the first
// entry.
channelReadyMsg.AliasScid = &aliases[0]
}
err = l.cfg.Peer.SendMessage(false, channelReadyMsg)
if err != nil {
return fmt.Errorf("unable to re-send "+
"ChannelReady: %v", err)
}
}
// In any case, we'll then process their ChanSync message.
l.log.Info("received re-establishment message from remote side")
var (
openedCircuits []CircuitKey
closedCircuits []CircuitKey
)
// We've just received a ChanSync message from the remote
// party, so we'll process the message in order to determine
// if we need to re-transmit any messages to the remote party.
ctx, cancel := l.cg.Create(ctx)
defer cancel()
msgsToReSend, openedCircuits, closedCircuits, err =
l.channel.ProcessChanSyncMsg(ctx, remoteChanSyncMsg)
if err != nil {
return err
}
// Repopulate any identifiers for circuits that may have been
// opened or unclosed. This may happen if we needed to
// retransmit a commitment signature message.
l.openedCircuits = openedCircuits
l.closedCircuits = closedCircuits
		// Ensure that all packets have been removed from the link's
		// mailbox.
if err := l.ackDownStreamPackets(); err != nil {
return err
}
if len(msgsToReSend) > 0 {
l.log.Infof("sending %v updates to synchronize the "+
"state", len(msgsToReSend))
}
// If we have any messages to retransmit, we'll do so
// immediately so we return to a synchronized state as soon as
// possible.
for _, msg := range msgsToReSend {
l.cfg.Peer.SendMessage(false, msg)
}
case <-l.cg.Done():
return ErrLinkShuttingDown
}
return nil
}
// resolveFwdPkgs loads any forwarding packages for this link from disk, and
// reprocesses them in order. The primary goal is to make sure that any HTLCs
// we previously received are reinstated in memory, and forwarded to the switch
// if necessary. After a restart, this will also delete any previously
// completed packages.
func (l *channelLink) resolveFwdPkgs(ctx context.Context) error {
fwdPkgs, err := l.channel.LoadFwdPkgs()
if err != nil {
return err
}
l.log.Debugf("loaded %d fwd pks", len(fwdPkgs))
for _, fwdPkg := range fwdPkgs {
if err := l.resolveFwdPkg(fwdPkg); err != nil {
return err
}
}
// If any of our reprocessing steps require an update to the commitment
// txn, we initiate a state transition to capture all relevant changes.
if l.channel.NumPendingUpdates(lntypes.Local, lntypes.Remote) > 0 {
return l.updateCommitTx(ctx)
}
return nil
}
// resolveFwdPkg interprets the FwdState of the provided package, either
// reprocesses any outstanding htlcs in the package, or performs garbage
// collection on the package.
func (l *channelLink) resolveFwdPkg(fwdPkg *channeldb.FwdPkg) error {
// Remove any completed packages to clear up space.
if fwdPkg.State == channeldb.FwdStateCompleted {
l.log.Debugf("removing completed fwd pkg for height=%d",
fwdPkg.Height)
err := l.channel.RemoveFwdPkgs(fwdPkg.Height)
if err != nil {
l.log.Errorf("unable to remove fwd pkg for height=%d: "+
"%v", fwdPkg.Height, err)
return err
}
}
// Otherwise this is either a new package or one has gone through
// processing, but contains htlcs that need to be restored in memory.
// We replay this forwarding package to make sure our local mem state
// is resurrected, we mimic any original responses back to the remote
// party, and re-forward the relevant HTLCs to the switch.
// If the package is fully acked but not completed, it must still have
// settles and fails to propagate.
if !fwdPkg.SettleFailFilter.IsFull() {
l.processRemoteSettleFails(fwdPkg)
}
// Finally, replay *ALL ADDS* in this forwarding package. The
// downstream logic is able to filter out any duplicates, but we must
// shove the entire, original set of adds down the pipeline so that the
// batch of adds presented to the sphinx router does not ever change.
if !fwdPkg.AckFilter.IsFull() {
l.processRemoteAdds(fwdPkg)
		// If the link failed while processing the adds, we must
		// return to ensure we won't attempt to update the state
		// further.
if l.failed {
return fmt.Errorf("link failed while " +
"processing remote adds")
}
}
return nil
}
// fwdPkgGarbager periodically reads all forwarding packages from disk and
// removes those that can be discarded. It is safe to do this entirely in the
// background, since all state is coordinated on disk. This also ensures the
// link can continue to process messages and interleave database accesses.
//
// NOTE: This MUST be run as a goroutine.
func (l *channelLink) fwdPkgGarbager() {
defer l.cg.WgDone()
l.cfg.FwdPkgGCTicker.Resume()
defer l.cfg.FwdPkgGCTicker.Stop()
if err := l.loadAndRemove(); err != nil {
l.log.Warnf("unable to run initial fwd pkgs gc: %v", err)
}
for {
select {
case <-l.cfg.FwdPkgGCTicker.Ticks():
if err := l.loadAndRemove(); err != nil {
l.log.Warnf("unable to remove fwd pkgs: %v",
err)
continue
}
case <-l.cg.Done():
return
}
}
}
// loadAndRemove loads all of the channel's forwarding packages and
// determines if they can be removed. It is called once before the
// FwdPkgGCTicker ticks so that a longer tick interval can be used.
func (l *channelLink) loadAndRemove() error {
fwdPkgs, err := l.channel.LoadFwdPkgs()
if err != nil {
return err
}
var removeHeights []uint64
for _, fwdPkg := range fwdPkgs {
if fwdPkg.State != channeldb.FwdStateCompleted {
continue
}
removeHeights = append(removeHeights, fwdPkg.Height)
}
// If removeHeights is empty, return early so we don't use a db
// transaction.
if len(removeHeights) == 0 {
return nil
}
return l.channel.RemoveFwdPkgs(removeHeights...)
}
// handleChanSyncErr performs the error handling logic in the case where we
// could not successfully syncChanStates with our channel peer.
func (l *channelLink) handleChanSyncErr(err error) {
l.log.Warnf("error when syncing channel states: %v", err)
var errDataLoss *lnwallet.ErrCommitSyncLocalDataLoss
switch {
case errors.Is(err, ErrLinkShuttingDown):
l.log.Debugf("unable to sync channel states, link is " +
"shutting down")
return
// We failed syncing the commit chains, probably because the remote has
// lost state. We should force close the channel.
case errors.Is(err, lnwallet.ErrCommitSyncRemoteDataLoss):
fallthrough
// The remote sent us an invalid last commit secret, we should force
// close the channel.
// TODO(halseth): and permanently ban the peer?
case errors.Is(err, lnwallet.ErrInvalidLastCommitSecret):
fallthrough
// The remote sent us a commit point different from what they sent us
// before.
// TODO(halseth): ban peer?
case errors.Is(err, lnwallet.ErrInvalidLocalUnrevokedCommitPoint):
// We'll fail the link and tell the peer to force close the
// channel. Note that the database state is not updated here,
// but will be updated when the close transaction is ready to
// avoid that we go down before storing the transaction in the
// db.
l.failf(
LinkFailureError{
code: ErrSyncError,
FailureAction: LinkFailureForceClose,
},
"unable to synchronize channel states: %v", err,
)
// We have lost state and cannot safely force close the channel. Fail
// the channel and wait for the remote to hopefully force close it. The
// remote has sent us its latest unrevoked commitment point, and we'll
// store it in the database, such that we can attempt to recover the
// funds if the remote force closes the channel.
case errors.As(err, &errDataLoss):
err := l.channel.MarkDataLoss(
errDataLoss.CommitPoint,
)
if err != nil {
l.log.Errorf("unable to mark channel data loss: %v",
err)
}
// We determined the commit chains were not possible to sync. We
// cautiously fail the channel, but don't force close.
// TODO(halseth): can we safely force close in any cases where this
// error is returned?
case errors.Is(err, lnwallet.ErrCannotSyncCommitChains):
if err := l.channel.MarkBorked(); err != nil {
l.log.Errorf("unable to mark channel borked: %v", err)
}
// Other, unspecified error.
default:
}
l.failf(
LinkFailureError{
code: ErrRecoveryError,
FailureAction: LinkFailureForceNone,
},
"unable to synchronize channel states: %v", err,
)
}
// htlcManager is the primary goroutine which drives a channel's commitment
// update state-machine in response to messages received via several channels.
// This goroutine reads messages from the upstream (remote) peer, and also
// from the downstream channel managed by the channel link. In the event that
// an htlc needs to be forwarded, a send-only forward handler is used which
// sends htlc packets to the switch. Additionally, this goroutine handles
// acting upon all timeouts for any active HTLCs, manages the channel's
// revocation window, and also the htlc trickle queue+timer for this active
// channel.
//
// NOTE: This MUST be run as a goroutine.
//
//nolint:funlen
func (l *channelLink) htlcManager(ctx context.Context) {
defer func() {
l.cfg.BatchTicker.Stop()
l.cg.WgDone()
l.log.Infof("exited")
}()
l.log.Infof("HTLC manager started, bandwidth=%v", l.Bandwidth())
// Notify any clients that the link is now in the switch via an
// ActiveLinkEvent. We'll also defer an inactive link notification for
// when the link exits to ensure that every active notification is
// matched by an inactive one.
l.cfg.NotifyActiveLink(l.ChannelPoint())
defer l.cfg.NotifyInactiveLinkEvent(l.ChannelPoint())
// TODO(roasbeef): need to call wipe chan whenever D/C?
// If this isn't the first time that this channel link has been
// created, then we'll need to check to see if we need to
// re-synchronize state with the remote peer. settledHtlcs is a map of
// HTLC's that we re-settled as part of the channel state sync.
if l.cfg.SyncStates {
err := l.syncChanStates(ctx)
if err != nil {
l.handleChanSyncErr(err)
return
}
}
// If a shutdown message has previously been sent on this link, then we
// need to make sure that we have disabled any HTLC adds on the outgoing
	// direction of the link and that we re-send the same shutdown message
// that we previously sent.
l.cfg.PreviouslySentShutdown.WhenSome(func(shutdown lnwire.Shutdown) {
// Immediately disallow any new outgoing HTLCs.
if !l.DisableAdds(Outgoing) {
l.log.Warnf("Outgoing link adds already disabled")
}
		// Re-send the shutdown message to the peer. Since syncChanStates
// would have sent any outstanding CommitSig, it is fine for us
// to immediately queue the shutdown message now.
err := l.cfg.Peer.SendMessage(false, &shutdown)
if err != nil {
l.log.Warnf("Error sending shutdown message: %v", err)
}
})
// We've successfully reestablished the channel, mark it as such to
// allow the switch to forward HTLCs in the outbound direction.
l.markReestablished()
// Now that we've received both channel_ready and channel reestablish,
// we can go ahead and send the active channel notification. We'll also
// defer the inactive notification for when the link exits to ensure
// that every active notification is matched by an inactive one.
l.cfg.NotifyActiveChannel(l.ChannelPoint())
defer l.cfg.NotifyInactiveChannel(l.ChannelPoint())
// With the channel states synced, we now reset the mailbox to ensure
// we start processing all unacked packets in order. This is done here
// to ensure that all acknowledgments that occur during channel
	// resynchronization have taken effect, causing us only to pull unacked
// packets after starting to read from the downstream mailbox.
l.mailBox.ResetPackets()
// After cleaning up any memory pertaining to incoming packets, we now
// replay our forwarding packages to handle any htlcs that can be
// processed locally, or need to be forwarded out to the switch. We will
// only attempt to resolve packages if our short chan id indicates that
// the channel is not pending, otherwise we should have no htlcs to
// reforward.
if l.ShortChanID() != hop.Source {
err := l.resolveFwdPkgs(ctx)
switch err {
// No error was encountered, success.
case nil:
// If the duplicate keystone error was encountered, we'll fail
// without sending an Error message to the peer.
case ErrDuplicateKeystone:
l.failf(LinkFailureError{code: ErrCircuitError},
"temporary circuit error: %v", err)
return
// A non-nil error was encountered, send an Error message to
// the peer.
default:
l.failf(LinkFailureError{code: ErrInternalError},
"unable to resolve fwd pkgs: %v", err)
return
}
// With our link's in-memory state fully reconstructed, spawn a
// goroutine to manage the reclamation of disk space occupied by
// completed forwarding packages.
l.cg.WgAdd(1)
go l.fwdPkgGarbager()
}
for {
// We must always check if we failed at some point processing
// the last update before processing the next.
if l.failed {
l.log.Errorf("link failed, exiting htlcManager")
return
}
// If the previous event resulted in a non-empty batch, resume
// the batch ticker so that it can be cleared. Otherwise pause
// the ticker to prevent waking up the htlcManager while the
// batch is empty.
numUpdates := l.channel.NumPendingUpdates(
lntypes.Local, lntypes.Remote,
)
if numUpdates > 0 {
l.cfg.BatchTicker.Resume()
l.log.Tracef("BatchTicker resumed, "+
"NumPendingUpdates(Local, Remote)=%d",
numUpdates,
)
} else {
l.cfg.BatchTicker.Pause()
l.log.Trace("BatchTicker paused due to zero " +
"NumPendingUpdates(Local, Remote)")
}
select {
// We have a new hook that needs to be run when we reach a clean
// channel state.
case hook := <-l.flushHooks.newTransients:
if l.channel.IsChannelClean() {
hook()
} else {
l.flushHooks.alloc(hook)
}
// We have a new hook that needs to be run when we have
// committed all of our updates.
case hook := <-l.outgoingCommitHooks.newTransients:
if !l.channel.OweCommitment() {
hook()
} else {
l.outgoingCommitHooks.alloc(hook)
}
// We have a new hook that needs to be run when our peer has
// committed all of their updates.
case hook := <-l.incomingCommitHooks.newTransients:
if !l.channel.NeedCommitment() {
hook()
} else {
l.incomingCommitHooks.alloc(hook)
}
// Our update fee timer has fired, so we'll check the network
// fee to see if we should adjust our commitment fee.
case <-l.updateFeeTimer.C:
l.updateFeeTimer.Reset(l.randomFeeUpdateTimeout())
			// If we're not the initiator of the channel, we don't
			// control the fees, so we can ignore this.
if !l.channel.IsInitiator() {
continue
}
// If we are the initiator, then we'll sample the
// current fee rate to get into the chain within 3
// blocks.
netFee, err := l.sampleNetworkFee()
if err != nil {
l.log.Errorf("unable to sample network fee: %v",
err)
continue
}
minRelayFee := l.cfg.FeeEstimator.RelayFeePerKW()
newCommitFee := l.channel.IdealCommitFeeRate(
netFee, minRelayFee,
l.cfg.MaxAnchorsCommitFeeRate,
l.cfg.MaxFeeAllocation,
)
// We determine if we should adjust the commitment fee
// based on the current commitment fee, the suggested
// new commitment fee and the current minimum relay fee
// rate.
commitFee := l.channel.CommitFeeRate()
if !shouldAdjustCommitFee(
newCommitFee, commitFee, minRelayFee,
) {
continue
}
// If we do, then we'll send a new UpdateFee message to
// the remote party, to be locked in with a new update.
err = l.updateChannelFee(ctx, newCommitFee)
if err != nil {
l.log.Errorf("unable to update fee rate: %v",
err)
continue
}
// The underlying channel has notified us of a unilateral close
// carried out by the remote peer. In the case of such an
// event, we'll wipe the channel state from the peer, and mark
// the contract as fully settled. Afterwards we can exit.
//
// TODO(roasbeef): add force closure? also breach?
case <-l.cfg.ChainEvents.RemoteUnilateralClosure:
l.log.Warnf("remote peer has closed on-chain")
// TODO(roasbeef): remove all together
go func() {
chanPoint := l.channel.ChannelPoint()
l.cfg.Peer.WipeChannel(&chanPoint)
}()
return
case <-l.cfg.BatchTicker.Ticks():
// Attempt to extend the remote commitment chain
// including all the currently pending entries. If the
// send was unsuccessful, then abandon the update,
// waiting for the revocation window to open up.
if !l.updateCommitTxOrFail(ctx) {
return
}
case <-l.cfg.PendingCommitTicker.Ticks():
l.failf(
LinkFailureError{
code: ErrRemoteUnresponsive,
FailureAction: LinkFailureDisconnect,
},
"unable to complete dance",
)
return
// A message from the switch was just received. This indicates
// that the link is an intermediate hop in a multi-hop HTLC
// circuit.
case pkt := <-l.downstream:
l.handleDownstreamPkt(ctx, pkt)
// A message from the connected peer was just received. This
// indicates that we have a new incoming HTLC, either directly
// for us, or part of a multi-hop HTLC circuit.
case msg := <-l.upstream:
l.handleUpstreamMsg(ctx, msg)
// A htlc resolution is received. This means that we now have a
// resolution for a previously accepted htlc.
case hodlItem := <-l.hodlQueue.ChanOut():
htlcResolution := hodlItem.(invoices.HtlcResolution)
err := l.processHodlQueue(ctx, htlcResolution)
switch err {
// No error, success.
case nil:
// If the duplicate keystone error was encountered,
// fail back gracefully.
case ErrDuplicateKeystone:
l.failf(LinkFailureError{
code: ErrCircuitError,
}, "process hodl queue: "+
"temporary circuit error: %v",
err,
)
// Send an Error message to the peer.
default:
l.failf(LinkFailureError{
code: ErrInternalError,
}, "process hodl queue: unable to update "+
"commitment: %v", err,
)
}
case qReq := <-l.quiescenceReqs:
l.quiescer.InitStfu(qReq)
if l.noDanglingUpdates(lntypes.Local) {
err := l.quiescer.SendOwedStfu()
if err != nil {
l.stfuFailf(
"SendOwedStfu: %s", err.Error(),
)
res := fn.Err[lntypes.ChannelParty](err)
qReq.Resolve(res)
}
}
case <-l.cg.Done():
return
}
}
}
// processHodlQueue processes a received htlc resolution and continues reading
// from the hodl queue until no more resolutions remain. When this function
// returns without an error, the commit tx should be updated.
func (l *channelLink) processHodlQueue(ctx context.Context,
firstResolution invoices.HtlcResolution) error {
// Try to read all waiting resolution messages, so that they can all be
// processed in a single commitment tx update.
htlcResolution := firstResolution
loop:
for {
// Lookup all hodl htlcs that can be failed or settled with this event.
// The hodl htlc must be present in the map.
circuitKey := htlcResolution.CircuitKey()
hodlHtlc, ok := l.hodlMap[circuitKey]
if !ok {
return fmt.Errorf("hodl htlc not found: %v", circuitKey)
}
if err := l.processHtlcResolution(htlcResolution, hodlHtlc); err != nil {
return err
}
// Clean up hodl map.
delete(l.hodlMap, circuitKey)
select {
case item := <-l.hodlQueue.ChanOut():
htlcResolution = item.(invoices.HtlcResolution)
// No need to process it if the link is broken.
case <-l.cg.Done():
return ErrLinkShuttingDown
default:
break loop
}
}
// Update the commitment tx.
if err := l.updateCommitTx(ctx); err != nil {
return err
}
return nil
}
// processHtlcResolution applies a received htlc resolution to the provided
// htlc. When this function returns without an error, the commit tx should be
// updated.
func (l *channelLink) processHtlcResolution(resolution invoices.HtlcResolution,
htlc hodlHtlc) error {
circuitKey := resolution.CircuitKey()
// Determine required action for the resolution based on the type of
// resolution we have received.
switch res := resolution.(type) {
// Settle htlcs that returned a settle resolution using the preimage
// in the resolution.
case *invoices.HtlcSettleResolution:
l.log.Debugf("received settle resolution for %v "+
"with outcome: %v", circuitKey, res.Outcome)
return l.settleHTLC(
res.Preimage, htlc.add.ID, htlc.sourceRef,
)
// For htlc failures, we get the relevant failure message based
// on the failure resolution and then fail the htlc.
case *invoices.HtlcFailResolution:
l.log.Debugf("received cancel resolution for "+
"%v with outcome: %v", circuitKey, res.Outcome)
// Get the lnwire failure message based on the resolution
// result.
failure := getResolutionFailure(res, htlc.add.Amount)
l.sendHTLCError(
htlc.add, htlc.sourceRef, failure, htlc.obfuscator,
true,
)
return nil
	// Fail if we do not get a settle or fail resolution, since we
	// are only expecting to handle settles and fails.
default:
return fmt.Errorf("unknown htlc resolution type: %T",
resolution)
}
}
// getResolutionFailure returns the wire message that a htlc resolution should
// be failed with.
func getResolutionFailure(resolution *invoices.HtlcFailResolution,
amount lnwire.MilliSatoshi) *LinkError {
// If the resolution has been resolved as part of a MPP timeout,
// we need to fail the htlc with lnwire.FailMppTimeout.
if resolution.Outcome == invoices.ResultMppTimeout {
return NewDetailedLinkError(
&lnwire.FailMPPTimeout{}, resolution.Outcome,
)
}
// If the htlc is not a MPP timeout, we fail it with
// FailIncorrectDetails. This error is sent for invoice payment
	// failures such as underpayment or expiry too soon, and hodl invoices
// (which return FailIncorrectDetails to avoid leaking information).
incorrectDetails := lnwire.NewFailIncorrectDetails(
amount, uint32(resolution.AcceptHeight),
)
return NewDetailedLinkError(incorrectDetails, resolution.Outcome)
}
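// exampleResolutionFailure is an editorial sketch (not part of lnd) of the
// mapping above, constructing the fail resolution literally rather than
// receiving it from the invoice registry: an MPP timeout maps to
// FailMPPTimeout, everything else to FailIncorrectDetails carrying the htlc
// amount and accept height.
func exampleResolutionFailure() *LinkError {
	res := &invoices.HtlcFailResolution{
		Outcome:      invoices.ResultMppTimeout,
		AcceptHeight: 100,
	}
	// Returns a LinkError wrapping &lnwire.FailMPPTimeout{}.
	return getResolutionFailure(res, 20_000)
}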
// randomFeeUpdateTimeout returns a random timeout between the bounds defined
// within the link's configuration that will be used to determine when the link
// should propose an update to its commitment fee rate.
func (l *channelLink) randomFeeUpdateTimeout() time.Duration {
lower := int64(l.cfg.MinUpdateTimeout)
upper := int64(l.cfg.MaxUpdateTimeout)
return time.Duration(prand.Int63n(upper-lower) + lower)
}
// handleDownstreamUpdateAdd processes an UpdateAddHTLC packet sent from the
// downstream HTLC Switch.
func (l *channelLink) handleDownstreamUpdateAdd(ctx context.Context,
pkt *htlcPacket) error {
htlc, ok := pkt.htlc.(*lnwire.UpdateAddHTLC)
if !ok {
return errors.New("not an UpdateAddHTLC packet")
}
// If we are flushing the link in the outgoing direction or we have
// already sent Stfu, then we can't add new htlcs to the link and we
// need to bounce it.
if l.IsFlushing(Outgoing) || !l.quiescer.CanSendUpdates() {
l.mailBox.FailAdd(pkt)
return NewDetailedLinkError(
&lnwire.FailTemporaryChannelFailure{},
OutgoingFailureLinkNotEligible,
)
}
// If hodl.AddOutgoing mode is active, we exit early to simulate
// arbitrary delays between the switch adding an ADD to the
// mailbox, and the HTLC being added to the commitment state.
if l.cfg.HodlMask.Active(hodl.AddOutgoing) {
l.log.Warnf(hodl.AddOutgoing.Warning())
l.mailBox.AckPacket(pkt.inKey())
return nil
}
// Check if we can add the HTLC here without exceeding the max fee
// exposure threshold.
if l.isOverexposedWithHtlc(htlc, false) {
l.log.Debugf("Unable to handle downstream HTLC - max fee " +
"exposure exceeded")
l.mailBox.FailAdd(pkt)
return NewDetailedLinkError(
lnwire.NewTemporaryChannelFailure(nil),
OutgoingFailureDownstreamHtlcAdd,
)
}
// A new payment has been initiated via the downstream channel,
// so we add the new HTLC to our local log, then update the
// commitment chains.
htlc.ChanID = l.ChanID()
openCircuitRef := pkt.inKey()
// We enforce the fee buffer for the commitment transaction because
// we are in control of adding this htlc. Nothing has been locked in
// yet, so we can safely enforce the fee buffer, which is only
// relevant if we are the initiator of the channel.
index, err := l.channel.AddHTLC(htlc, &openCircuitRef)
if err != nil {
// The HTLC was unable to be added to the state machine,
// as a result, we'll signal the switch to cancel the
// pending payment.
l.log.Warnf("Unable to handle downstream add HTLC: %v",
err)
// Remove this packet from the link's mailbox, this
// prevents it from being reprocessed if the link
// restarts and resets its mailbox. If this response
// doesn't make it back to the originating link, it will
// be rejected upon attempting to reforward the Add to
// the switch, since the circuit was never fully opened,
// and the forwarding package shows it as
// unacknowledged.
l.mailBox.FailAdd(pkt)
return NewDetailedLinkError(
lnwire.NewTemporaryChannelFailure(nil),
OutgoingFailureDownstreamHtlcAdd,
)
}
l.log.Tracef("received downstream htlc: payment_hash=%x, "+
"local_log_index=%v, pend_updates=%v",
htlc.PaymentHash[:], index,
l.channel.NumPendingUpdates(lntypes.Local, lntypes.Remote))
pkt.outgoingChanID = l.ShortChanID()
pkt.outgoingHTLCID = index
htlc.ID = index
l.log.Debugf("queueing keystone of ADD open circuit: %s->%s",
pkt.inKey(), pkt.outKey())
l.openedCircuits = append(l.openedCircuits, pkt.inKey())
l.keystoneBatch = append(l.keystoneBatch, pkt.keystone())
_ = l.cfg.Peer.SendMessage(false, htlc)
// Send a forward event notification to htlcNotifier.
l.cfg.HtlcNotifier.NotifyForwardingEvent(
newHtlcKey(pkt),
HtlcInfo{
IncomingTimeLock: pkt.incomingTimeout,
IncomingAmt: pkt.incomingAmount,
OutgoingTimeLock: htlc.Expiry,
OutgoingAmt: htlc.Amount,
},
getEventType(pkt),
)
l.tryBatchUpdateCommitTx(ctx)
return nil
}
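// For intuition on the keystone queued above: if the HTLC arrived on some
// channel A as HTLC index 5 and was just added to this link with local
// index 9, the keystone pairs the incoming circuit key (A, 5) with the
// outgoing key (ourChan, 9). Persisting that pairing (via OpenCircuits in
// updateCommitTx) is what later allows the switch to route the eventual
// settle or fail back to channel A.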
// handleDownstreamPkt processes an HTLC packet sent from the downstream HTLC
// Switch. Possible messages sent by the switch include requests to forward new
// HTLCs, timeout previously cleared HTLCs, and finally to settle currently
// cleared HTLCs with the upstream peer.
//
// TODO(roasbeef): add sync ntfn to ensure switch always has consistent view?
func (l *channelLink) handleDownstreamPkt(ctx context.Context,
pkt *htlcPacket) {
if pkt.htlc.MsgType().IsChannelUpdate() &&
!l.quiescer.CanSendUpdates() {
l.log.Warnf("unable to process channel update. "+
"ChannelID=%v is quiescent.", l.ChanID)
return
}
switch htlc := pkt.htlc.(type) {
case *lnwire.UpdateAddHTLC:
// Handle add message. The returned error can be ignored,
// because it is also sent through the mailbox.
_ = l.handleDownstreamUpdateAdd(ctx, pkt)
case *lnwire.UpdateFulfillHTLC:
// If hodl.SettleOutgoing mode is active, we exit early to
// simulate arbitrary delays between the switch adding the
// SETTLE to the mailbox, and the HTLC being added to the
// commitment state.
if l.cfg.HodlMask.Active(hodl.SettleOutgoing) {
l.log.Warnf(hodl.SettleOutgoing.Warning())
l.mailBox.AckPacket(pkt.inKey())
return
}
// An HTLC we forward to the switch has just settled somewhere
// upstream. Therefore we settle the HTLC within our local
// state machine.
inKey := pkt.inKey()
err := l.channel.SettleHTLC(
htlc.PaymentPreimage,
pkt.incomingHTLCID,
pkt.sourceRef,
pkt.destRef,
&inKey,
)
if err != nil {
l.log.Errorf("unable to settle incoming HTLC for "+
"circuit-key=%v: %v", inKey, err)
// If the HTLC index for Settle response was not known
// to our commitment state, it has already been
// cleaned up by a prior response. We'll thus try to
// clean up any lingering state to ensure we don't
// continue reforwarding.
if _, ok := err.(lnwallet.ErrUnknownHtlcIndex); ok {
l.cleanupSpuriousResponse(pkt)
}
// Remove the packet from the link's mailbox to ensure
// it doesn't get replayed after a reconnection.
l.mailBox.AckPacket(inKey)
return
}
l.log.Debugf("queueing removal of SETTLE closed circuit: "+
"%s->%s", pkt.inKey(), pkt.outKey())
l.closedCircuits = append(l.closedCircuits, pkt.inKey())
// With the HTLC settled, we'll need to populate the wire
// message to target the specific channel and HTLC to be
// settled.
htlc.ChanID = l.ChanID()
htlc.ID = pkt.incomingHTLCID
// Then we send the HTLC settle message to the connected peer
// so we can continue the propagation of the settle message.
l.cfg.Peer.SendMessage(false, htlc)
// Send a settle event notification to htlcNotifier.
l.cfg.HtlcNotifier.NotifySettleEvent(
newHtlcKey(pkt),
htlc.PaymentPreimage,
getEventType(pkt),
)
// Immediately update the commitment tx to minimize latency.
l.updateCommitTxOrFail(ctx)
case *lnwire.UpdateFailHTLC:
// If hodl.FailOutgoing mode is active, we exit early to
// simulate arbitrary delays between the switch adding a FAIL to
// the mailbox, and the HTLC being added to the commitment
// state.
if l.cfg.HodlMask.Active(hodl.FailOutgoing) {
l.log.Warnf(hodl.FailOutgoing.Warning())
l.mailBox.AckPacket(pkt.inKey())
return
}
// An HTLC cancellation has been triggered somewhere upstream,
// we'll remove the HTLC from our local state machine.
inKey := pkt.inKey()
err := l.channel.FailHTLC(
pkt.incomingHTLCID,
htlc.Reason,
pkt.sourceRef,
pkt.destRef,
&inKey,
)
if err != nil {
l.log.Errorf("unable to cancel incoming HTLC for "+
"circuit-key=%v: %v", inKey, err)
// If the HTLC index for Fail response was not known to
// our commitment state, it has already been cleaned up
// by a prior response. We'll thus try to clean up any
// lingering state to ensure we don't continue
// reforwarding.
if _, ok := err.(lnwallet.ErrUnknownHtlcIndex); ok {
l.cleanupSpuriousResponse(pkt)
}
// Remove the packet from the link's mailbox to ensure
// it doesn't get replayed after a reconnection.
l.mailBox.AckPacket(inKey)
return
}
l.log.Debugf("queueing removal of FAIL closed circuit: %s->%s",
pkt.inKey(), pkt.outKey())
l.closedCircuits = append(l.closedCircuits, pkt.inKey())
// With the HTLC removed, we'll need to populate the wire
// message to target the specific channel and HTLC to be
// canceled. The "Reason" field will have already been set
// within the switch.
htlc.ChanID = l.ChanID()
htlc.ID = pkt.incomingHTLCID
// We send the HTLC message to the peer which initially created
// the HTLC. If the incoming blinding point is non-nil, we
// know that we are a relaying node in a blinded path.
// Otherwise, we're either an introduction node or not part of
// a blinded path at all.
if err := l.sendIncomingHTLCFailureMsg(
htlc.ID,
pkt.obfuscator,
htlc.Reason,
); err != nil {
l.log.Errorf("unable to send HTLC failure: %v",
err)
return
}
// If the packet does not have a link failure set, it failed
// further down the route so we notify a forwarding failure.
// Otherwise, we notify a link failure because it failed at our
// node.
if pkt.linkFailure != nil {
l.cfg.HtlcNotifier.NotifyLinkFailEvent(
newHtlcKey(pkt),
newHtlcInfo(pkt),
getEventType(pkt),
pkt.linkFailure,
false,
)
} else {
l.cfg.HtlcNotifier.NotifyForwardingFailEvent(
newHtlcKey(pkt), getEventType(pkt),
)
}
// Immediately update the commitment tx to minimize latency.
l.updateCommitTxOrFail(ctx)
}
}
// tryBatchUpdateCommitTx updates the commitment transaction if the batch is
// full.
func (l *channelLink) tryBatchUpdateCommitTx(ctx context.Context) {
pending := l.channel.NumPendingUpdates(lntypes.Local, lntypes.Remote)
if pending < uint64(l.cfg.BatchSize) {
return
}
l.updateCommitTxOrFail(ctx)
}
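// For example, with cfg.BatchSize = 10, the commitment is only re-signed
// here once at least ten local updates are pending on the remote
// commitment; smaller batches are left to other triggers, such as the
// link's batch ticker.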
// cleanupSpuriousResponse attempts to ack any AddRef or SettleFailRef
// associated with this packet. If successful in doing so, it will also purge
// the open circuit from the circuit map and remove the packet from the link's
// mailbox.
func (l *channelLink) cleanupSpuriousResponse(pkt *htlcPacket) {
inKey := pkt.inKey()
l.log.Debugf("cleaning up spurious response for incoming "+
"circuit-key=%v", inKey)
// If the htlc packet doesn't have a source reference, it is unsafe to
// proceed, as skipping this ack may cause the htlc to be reforwarded.
if pkt.sourceRef == nil {
l.log.Errorf("unable to cleanup response for incoming "+
"circuit-key=%v, does not contain source reference",
inKey)
return
}
// If the source reference is present, we will try to prevent this link
// from resending the packet to the switch. To do so, we ack the AddRef
// of the incoming HTLC belonging to this link.
err := l.channel.AckAddHtlcs(*pkt.sourceRef)
if err != nil {
l.log.Errorf("unable to ack AddRef for incoming "+
"circuit-key=%v: %v", inKey, err)
// If this operation failed, it is unsafe to attempt removal of
// the destination reference or circuit, so we exit early. The
// cleanup may proceed with a different packet in the future
// that succeeds on this step.
return
}
// Now that we know this link will stop retransmitting Adds to the
// switch, we can begin to teardown the response reference and circuit
// map.
//
// If the packet includes a destination reference, then a response for
// this HTLC was locked into the outgoing channel. Attempt to remove
// this reference, so we stop retransmitting the response internally.
// Even if this fails, we will proceed in trying to delete the circuit.
// When retransmitting responses, the destination references will be
// cleaned up if an open circuit is not found in the circuit map.
if pkt.destRef != nil {
err := l.channel.AckSettleFails(*pkt.destRef)
if err != nil {
l.log.Errorf("unable to ack SettleFailRef "+
"for incoming circuit-key=%v: %v",
inKey, err)
}
}
l.log.Debugf("deleting circuit for incoming circuit-key=%x", inKey)
// With all known references acked, we can now safely delete the circuit
// from the switch's circuit map, as the state is no longer needed.
err = l.cfg.Circuits.DeleteCircuits(inKey)
if err != nil {
l.log.Errorf("unable to delete circuit for "+
"circuit-key=%v: %v", inKey, err)
}
}
// handleUpstreamMsg processes wire messages related to commitment state
// updates from the upstream peer. The upstream peer is the peer with whom we
// have a direct channel, updating our respective commitment chains.
//
//nolint:funlen
func (l *channelLink) handleUpstreamMsg(ctx context.Context,
msg lnwire.Message) {
l.log.Tracef("receive upstream msg %v, handling now... ", msg.MsgType())
defer l.log.Tracef("handled upstream msg %v", msg.MsgType())
// First check if the message is an update and we are capable of
// receiving updates right now.
if msg.MsgType().IsChannelUpdate() && !l.quiescer.CanRecvUpdates() {
l.stfuFailf("update received after stfu: %T", msg)
return
}
switch msg := msg.(type) {
case *lnwire.UpdateAddHTLC:
if l.IsFlushing(Incoming) {
// This is forbidden by the protocol specification.
// The best chance we have to deal with this is to drop
// the connection. This should roll back the channel
// state to the last CommitSig. If the remote has
// already sent a CommitSig we haven't received yet,
// channel state will be re-synchronized with a
// ChannelReestablish message upon reconnection and the
// protocol state that caused us to flush the link will
// be rolled back. In the event that there was some
// non-deterministic behavior in the remote that caused
// them to violate the protocol, we have a decent shot
// at correcting it this way, since reconnecting will
// put us in the cleanest possible state to try again.
//
// In addition to the above, it is possible for us to
// hit this case in situations where we improperly
// handle message ordering due to concurrency choices.
// An issue has been filed to address this here:
// https://github.com/lightningnetwork/lnd/issues/8393
l.failf(
LinkFailureError{
code: ErrInvalidUpdate,
FailureAction: LinkFailureDisconnect,
PermanentFailure: false,
Warning: true,
},
"received add while link is flushing",
)
return
}
// Disallow htlcs with blinding points set if we haven't
// enabled the feature. This saves us from having to process
// the onion at all, but will only catch blinded payments
// where we are a relaying node (as the blinding point will
// be in the payload when we're the introduction node).
if msg.BlindingPoint.IsSome() && l.cfg.DisallowRouteBlinding {
l.failf(LinkFailureError{code: ErrInvalidUpdate},
"blinding point included when route blinding "+
"is disabled")
return
}
// We have to check the limit here rather than later in the
// switch because the counterparty can keep sending HTLC's
// without sending a revoke. This would mean that the switch
// check would only occur later.
if l.isOverexposedWithHtlc(msg, true) {
l.failf(LinkFailureError{code: ErrInternalError},
"peer sent us an HTLC that exceeded our max "+
"fee exposure")
return
}
// We just received an add request from an upstream peer, so we
// add it to our state machine, then add the HTLC to our
// "settle" list in the event that we know the preimage.
index, err := l.channel.ReceiveHTLC(msg)
if err != nil {
l.failf(LinkFailureError{code: ErrInvalidUpdate},
"unable to handle upstream add HTLC: %v", err)
return
}
l.log.Tracef("receive upstream htlc with payment hash(%x), "+
"assigning index: %v", msg.PaymentHash[:], index)
case *lnwire.UpdateFulfillHTLC:
pre := msg.PaymentPreimage
idx := msg.ID
// Before we pipeline the settle, we'll check the set of active
// htlc's to see if the related UpdateAddHTLC has been fully
// locked-in.
var lockedin bool
htlcs := l.channel.ActiveHtlcs()
for _, add := range htlcs {
// The HTLC will be outgoing and match idx.
if !add.Incoming && add.HtlcIndex == idx {
lockedin = true
break
}
}
if !lockedin {
l.failf(
LinkFailureError{code: ErrInvalidUpdate},
"unable to handle upstream settle",
)
return
}
if err := l.channel.ReceiveHTLCSettle(pre, idx); err != nil {
l.failf(
LinkFailureError{
code: ErrInvalidUpdate,
FailureAction: LinkFailureForceClose,
},
"unable to handle upstream settle HTLC: %v", err,
)
return
}
settlePacket := &htlcPacket{
outgoingChanID: l.ShortChanID(),
outgoingHTLCID: idx,
htlc: &lnwire.UpdateFulfillHTLC{
PaymentPreimage: pre,
},
}
// Add the newly discovered preimage to our growing list of
// uncommitted preimages. These will be written to the witness
// cache just before accepting the next commitment signature
// from the remote peer.
l.uncommittedPreimages = append(l.uncommittedPreimages, pre)
// Pipeline this settle, send it to the switch.
go l.forwardBatch(false, settlePacket)
case *lnwire.UpdateFailMalformedHTLC:
// Convert the failure type encoded within the HTLC fail
// message to the proper generic lnwire error code.
var failure lnwire.FailureMessage
switch msg.FailureCode {
case lnwire.CodeInvalidOnionVersion:
failure = &lnwire.FailInvalidOnionVersion{
OnionSHA256: msg.ShaOnionBlob,
}
case lnwire.CodeInvalidOnionHmac:
failure = &lnwire.FailInvalidOnionHmac{
OnionSHA256: msg.ShaOnionBlob,
}
case lnwire.CodeInvalidOnionKey:
failure = &lnwire.FailInvalidOnionKey{
OnionSHA256: msg.ShaOnionBlob,
}
// Handle malformed errors that are part of a blinded route.
// This case is slightly different, because we expect every
// relaying node in the blinded portion of the route to send
// malformed errors. If we're also a relaying node, we're
// likely going to switch this error out anyway for our own
// malformed error, but we handle the case here for
// completeness.
case lnwire.CodeInvalidBlinding:
failure = &lnwire.FailInvalidBlinding{
OnionSHA256: msg.ShaOnionBlob,
}
default:
l.log.Warnf("unexpected failure code received in "+
"UpdateFailMailformedHTLC: %v", msg.FailureCode)
// We don't just pass back the error we received from
// our successor. Otherwise we might report a failure
// that penalizes us more than needed. If the onion that
// we forwarded was correct, the node should have been
// able to send back its own failure. The node did not
// send back its own failure, so we assume there was a
// problem with the onion and report that back. We reuse
// the invalid onion key failure because there is no
// specific error for this case.
failure = &lnwire.FailInvalidOnionKey{
OnionSHA256: msg.ShaOnionBlob,
}
}
// With the error parsed, we'll convert it into its opaque
// form.
var b bytes.Buffer
if err := lnwire.EncodeFailure(&b, failure, 0); err != nil {
l.log.Errorf("unable to encode malformed error: %v", err)
return
}
// If the remote side was unable to parse the onion blob we
// sent it, then we transform the malformed HTLC message into
// the usual HTLC fail message.
err := l.channel.ReceiveFailHTLC(msg.ID, b.Bytes())
if err != nil {
l.failf(LinkFailureError{code: ErrInvalidUpdate},
"unable to handle upstream fail HTLC: %v", err)
return
}
case *lnwire.UpdateFailHTLC:
// Verify that the failure reason is at least 256 bytes plus
// overhead: a 2-byte message length, a 2-byte pad length, and
// the 32-byte HMAC, i.e. 292 bytes in total.
const minimumFailReasonLength = lnwire.FailureMessageLength +
2 + 2 + 32
if len(msg.Reason) < minimumFailReasonLength {
// We've received a reason with a non-compliant length.
// Older nodes happily relay back these failures that
// may originate from a node further downstream.
// Therefore we can't just fail the channel.
//
// We want to be compliant ourselves, so we also can't
// pass back the reason unmodified. And we must make
// sure that we don't hit the magic length check of 260
// bytes in processRemoteSettleFails either.
//
// Because the reason is unreadable for the payer
// anyway, we just replace it by a compliant-length
// series of random bytes.
msg.Reason = make([]byte, minimumFailReasonLength)
_, err := crand.Read(msg.Reason[:])
if err != nil {
l.log.Errorf("Random generation error: %v", err)
return
}
}
// Add fail to the update log.
idx := msg.ID
err := l.channel.ReceiveFailHTLC(idx, msg.Reason[:])
if err != nil {
l.failf(LinkFailureError{code: ErrInvalidUpdate},
"unable to handle upstream fail HTLC: %v", err)
return
}
case *lnwire.CommitSig:
// Since we may have learned new preimages for the first time,
// we'll add them to our preimage cache. By doing this, we
// ensure any contested contracts watched by any on-chain
// arbitrators can now sweep this HTLC on-chain. We delay
// committing the preimages until just before accepting the new
// remote commitment, as afterwards the peer won't resend the
// Settle messages on the next channel reestablishment. Doing so
// allows us to more effectively batch this operation, instead
// of doing a single write per preimage.
err := l.cfg.PreimageCache.AddPreimages(
l.uncommittedPreimages...,
)
if err != nil {
l.failf(
LinkFailureError{code: ErrInternalError},
"unable to add preimages=%v to cache: %v",
l.uncommittedPreimages, err,
)
return
}
// Instead of truncating the slice to conserve memory
// allocations, we simply set the uncommitted preimage slice to
// nil so that a new one will be initialized if any more
// witnesses are discovered. We do this because the maximum size
// that the slice can occupy is 15KB, and we want to ensure we
// release that memory back to the runtime.
l.uncommittedPreimages = nil
// We just received a new update to our local commitment
// chain. Validate this new commitment, closing the link if
// it is invalid.
auxSigBlob, err := msg.CustomRecords.Serialize()
if err != nil {
l.failf(
LinkFailureError{code: ErrInvalidCommitment},
"unable to serialize custom records: %v", err,
)
return
}
err = l.channel.ReceiveNewCommitment(&lnwallet.CommitSigs{
CommitSig: msg.CommitSig,
HtlcSigs: msg.HtlcSigs,
PartialSig: msg.PartialSig,
AuxSigBlob: auxSigBlob,
})
if err != nil {
// If we were unable to reconstruct their proposed
// commitment, then we'll examine the type of error. If
// it's an InvalidCommitSigError, then we'll send a
// direct error.
var sendData []byte
switch err.(type) {
case *lnwallet.InvalidCommitSigError:
sendData = []byte(err.Error())
case *lnwallet.InvalidHtlcSigError:
sendData = []byte(err.Error())
}
l.failf(
LinkFailureError{
code: ErrInvalidCommitment,
FailureAction: LinkFailureForceClose,
SendData: sendData,
},
"ChannelPoint(%v): unable to accept new "+
"commitment: %v",
l.channel.ChannelPoint(), err,
)
return
}
// As we've just accepted a new state, we'll now
// immediately send the remote peer a revocation for our prior
// state.
nextRevocation, currentHtlcs, finalHTLCs, err :=
l.channel.RevokeCurrentCommitment()
if err != nil {
l.log.Errorf("unable to revoke commitment: %v", err)
// We need to fail the channel in case revoking our
// local commitment does not succeed. We might have
// already advanced our channel state which would lead
// us to proceed with an unclean state.
//
// NOTE: We do not trigger a force close because this
// could resolve itself in case our db was just busy
// not accepting new transactions.
l.failf(
LinkFailureError{
code: ErrInternalError,
Warning: true,
FailureAction: LinkFailureDisconnect,
},
"ChannelPoint(%v): unable to accept new "+
"commitment: %v",
l.channel.ChannelPoint(), err,
)
return
}
// As soon as we are ready to send our next revocation, we can
// invoke the incoming commit hooks.
l.RWMutex.Lock()
l.incomingCommitHooks.invoke()
l.RWMutex.Unlock()
l.cfg.Peer.SendMessage(false, nextRevocation)
// Notify the htlc notifier of the final resolutions that were
// just locked in for the incoming htlcs.
for id, settled := range finalHTLCs {
l.cfg.HtlcNotifier.NotifyFinalHtlcEvent(
models.CircuitKey{
ChanID: l.ShortChanID(),
HtlcID: id,
},
channeldb.FinalHtlcInfo{
Settled: settled,
Offchain: true,
},
)
}
// Since we just revoked our commitment, we may have a new set
// of HTLC's on our commitment, so we'll send them using our
// function closure NotifyContractUpdate.
newUpdate := &contractcourt.ContractUpdate{
HtlcKey: contractcourt.LocalHtlcSet,
Htlcs: currentHtlcs,
}
err = l.cfg.NotifyContractUpdate(newUpdate)
if err != nil {
l.log.Errorf("unable to notify contract update: %v",
err)
return
}
select {
case <-l.cg.Done():
return
default:
}
// If the remote party initiated the state transition,
// we'll reply with a signature to provide them with their
// version of the latest commitment. Otherwise, if both
// commitment chains are fully synced from our PoV, we don't
// need to reply with a signature as both sides already have
// a commitment with the latest accepted state.
if l.channel.OweCommitment() {
if !l.updateCommitTxOrFail(ctx) {
return
}
}
// If we need to send out an Stfu, this would be the time to do
// so.
if l.noDanglingUpdates(lntypes.Local) {
err = l.quiescer.SendOwedStfu()
if err != nil {
l.stfuFailf("sendOwedStfu: %v", err.Error())
}
}
// Now that we have finished processing the incoming CommitSig
// and sent out our RevokeAndAck, we invoke the flushHooks if
// the channel state is clean.
l.RWMutex.Lock()
if l.channel.IsChannelClean() {
l.flushHooks.invoke()
}
l.RWMutex.Unlock()
case *lnwire.RevokeAndAck:
// We've received a revocation from the remote peer. If valid,
// this moves the remote chain forward and expands our
// revocation window.
//
// We now process the message and advance our remote commit
// chain.
fwdPkg, remoteHTLCs, err := l.channel.ReceiveRevocation(msg)
if err != nil {
// TODO(halseth): force close?
l.failf(
LinkFailureError{
code: ErrInvalidRevocation,
FailureAction: LinkFailureDisconnect,
},
"unable to accept revocation: %v", err,
)
return
}
// The remote party now has a new primary commitment, so we'll
// update the contract court to be aware of this new set
// (previously the pending remote commitment).
newUpdate := &contractcourt.ContractUpdate{
HtlcKey: contractcourt.RemoteHtlcSet,
Htlcs: remoteHTLCs,
}
err = l.cfg.NotifyContractUpdate(newUpdate)
if err != nil {
l.log.Errorf("unable to notify contract update: %v",
err)
return
}
select {
case <-l.cg.Done():
return
default:
}
// If we have a tower client for this channel type, we'll
// create a backup for the current state.
if l.cfg.TowerClient != nil {
state := l.channel.State()
chanID := l.ChanID()
err = l.cfg.TowerClient.BackupState(
&chanID, state.RemoteCommitment.CommitHeight-1,
)
if err != nil {
l.failf(LinkFailureError{
code: ErrInternalError,
}, "unable to queue breach backup: %v", err)
return
}
}
// If we can send updates then we can process adds in case we
// are the exit hop and need to send back resolutions, or in
// case there are validity issues with the packets. Otherwise
// we defer the action until resume.
//
// We are free to process the settles and fails without this
// check since processing those can't result in further updates
// to this channel link.
if l.quiescer.CanSendUpdates() {
l.processRemoteAdds(fwdPkg)
} else {
l.quiescer.OnResume(func() {
l.processRemoteAdds(fwdPkg)
})
}
l.processRemoteSettleFails(fwdPkg)
// If the link failed during processing the adds, we must
// return to ensure we won't attempt to update the state
// further.
if l.failed {
return
}
// The revocation window opened up. If there are pending local
// updates, try to update the commit tx. Pending updates could
// already have been present because of a previously failed
// update to the commit tx or freshly added in by
// processRemoteAdds. Also in case there are no local updates,
// but there are still remote updates that are not in the remote
// commit tx yet, send out an update.
if l.channel.OweCommitment() {
if !l.updateCommitTxOrFail(ctx) {
return
}
}
// Now that we have finished processing the RevokeAndAck, we
// can invoke the flushHooks if the channel state is clean.
l.RWMutex.Lock()
if l.channel.IsChannelClean() {
l.flushHooks.invoke()
}
l.RWMutex.Unlock()
case *lnwire.UpdateFee:
// Check and see if their proposed fee-rate would make us
// exceed the fee threshold.
fee := chainfee.SatPerKWeight(msg.FeePerKw)
exceeded, err := l.exceedsFeeExposureLimit(fee)
if err != nil {
// This shouldn't typically happen. If it does, it
// indicates something is wrong with our channel state.
l.log.Errorf("Unable to determine if fee threshold "+
"exceeded: %v", err)
l.failf(LinkFailureError{code: ErrInternalError},
"error calculating fee exposure: %v", err)
return
}
if exceeded {
// The proposed fee-rate makes us exceed the fee
// threshold.
l.failf(LinkFailureError{code: ErrInternalError},
"fee threshold exceeded with fee-rate: %v", fee)
return
}
// We received a fee update from the peer. If we are the initiator
// we will fail the channel; if not, we will apply the update.
if err := l.channel.ReceiveUpdateFee(fee); err != nil {
l.failf(LinkFailureError{code: ErrInvalidUpdate},
"error receiving fee update: %v", err)
return
}
// Update the mailbox's feerate as well.
l.mailBox.SetFeeRate(fee)
case *lnwire.Stfu:
err := l.handleStfu(msg)
if err != nil {
l.stfuFailf("handleStfu: %v", err.Error())
}
// In the case where we receive a warning message from our peer, just
// log it and move on. We choose not to disconnect from our peer,
// although we "MAY" do so according to the specification.
case *lnwire.Warning:
l.log.Warnf("received warning message from peer: %v",
msg.Warning())
case *lnwire.Error:
// Error received from remote, MUST fail channel, but should
// only print the contents of the error message if all
// characters are printable ASCII.
l.failf(
LinkFailureError{
code: ErrRemoteError,
// TODO(halseth): we currently don't fail the
// channel permanently, as there are some sync
// issues with other implementations that will
// lead to them sending an error message, but
// we can recover from on next connection. See
// https://github.com/ElementsProject/lightning/issues/4212
PermanentFailure: false,
},
"ChannelPoint(%v): received error from peer: %v",
l.channel.ChannelPoint(), msg.Error(),
)
default:
l.log.Warnf("received unknown message of type %T", msg)
}
}
// handleStfu implements the top-level logic for handling the Stfu message from
// our peer.
func (l *channelLink) handleStfu(stfu *lnwire.Stfu) error {
if !l.noDanglingUpdates(lntypes.Remote) {
return ErrPendingRemoteUpdates
}
err := l.quiescer.RecvStfu(*stfu)
if err != nil {
return err
}
// If we can immediately send an Stfu response back, we will.
if l.noDanglingUpdates(lntypes.Local) {
return l.quiescer.SendOwedStfu()
}
return nil
}
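// A minimal happy-path trace of the exchange above (illustrative only):
//
//	remote -> Stfu   (accepted: no remote-issued updates are pending)
//	local  -> Stfu   (sent immediately: no local-issued updates pending)
//
// Once both messages have been exchanged, the channel is quiescent and only
// non-update messages may flow until the link resumes.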
// stfuFailf fails the link in the case where the requirements of the quiescence
// protocol are violated. In all cases we opt to drop the connection as only
// link state (as opposed to channel state) is affected.
func (l *channelLink) stfuFailf(format string, args ...interface{}) {
l.failf(LinkFailureError{
code: ErrStfuViolation,
FailureAction: LinkFailureDisconnect,
PermanentFailure: false,
Warning: true,
}, format, args...)
}
// noDanglingUpdates returns true when the given party has no updates that it
// originally issued pending on either the Local or Remote commitment
// transaction.
func (l *channelLink) noDanglingUpdates(whose lntypes.ChannelParty) bool {
pendingOnLocal := l.channel.NumPendingUpdates(
whose, lntypes.Local,
)
pendingOnRemote := l.channel.NumPendingUpdates(
whose, lntypes.Remote,
)
return pendingOnLocal == 0 && pendingOnRemote == 0
}
// ackDownStreamPackets is responsible for removing htlcs from a link's mailbox
// for packets delivered from the switch, and cleaning up any circuits closed by
// signing a previous commitment txn. This method ensures that the circuits are
// removed from the circuit map before removing them from the link's mailbox,
// otherwise it could be possible for some circuit to be missed if this link
// flaps.
func (l *channelLink) ackDownStreamPackets() error {
// First, remove the downstream Add packets that were included in the
// previous commitment signature. This will prevent the Adds from being
// replayed if this link disconnects.
for _, inKey := range l.openedCircuits {
// In order to test the sphinx replay logic of the remote
// party, unsafe replay does not acknowledge the packets from
// the mailbox. We can then force a replay of any Add packets
// held in memory by disconnecting and reconnecting the link.
if l.cfg.UnsafeReplay {
continue
}
l.log.Debugf("removing Add packet %s from mailbox", inKey)
l.mailBox.AckPacket(inKey)
}
// Now, we will delete all circuits closed by the previous commitment
// signature, which is the result of downstream Settle/Fail packets. We
// batch them here to ensure circuits are closed atomically and for
// performance.
err := l.cfg.Circuits.DeleteCircuits(l.closedCircuits...)
if err != nil {
l.log.Errorf("unable to delete %d circuits: %v",
len(l.closedCircuits), err)
return err
}
// With the circuits removed from memory and disk, we now ack any
// Settle/Fails in the mailbox to ensure they do not get redelivered
// after startup. If forgive is enabled and we've reached this point,
// the circuits must have been removed at some point, so it is now safe
// to un-queue the corresponding Settle/Fails.
for _, inKey := range l.closedCircuits {
l.log.Debugf("removing Fail/Settle packet %s from mailbox",
inKey)
l.mailBox.AckPacket(inKey)
}
// Lastly, reset our buffers to be empty while keeping any acquired
// growth in the backing array.
l.openedCircuits = l.openedCircuits[:0]
l.closedCircuits = l.closedCircuits[:0]
return nil
}
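// Ordering intuition (illustrative): if the mailbox acks happened before the
// circuit deletion and the node restarted in between, the Settle/Fail
// packets would never be redelivered, leaving the closed circuits stranded
// in the circuit map. Deleting circuits first means a crash can at worst
// cause a redelivery, which is handled by the spurious-response cleanup
// path above.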
// updateCommitTxOrFail updates the commitment tx and if that fails, it fails
// the link.
func (l *channelLink) updateCommitTxOrFail(ctx context.Context) bool {
err := l.updateCommitTx(ctx)
switch err {
// No error encountered, success.
case nil:
// A duplicate keystone error should be resolved and is not fatal, so
// we won't send an Error message to the peer.
case ErrDuplicateKeystone:
l.failf(LinkFailureError{code: ErrCircuitError},
"temporary circuit error: %v", err)
return false
// Any other error results in an Error message being sent to
// the peer.
default:
l.failf(LinkFailureError{code: ErrInternalError},
"unable to update commitment: %v", err)
return false
}
return true
}
// updateCommitTx signs, then sends an update to the remote peer adding a new
// commitment to their commitment chain which includes all the latest updates
// we've received+processed up to this point.
func (l *channelLink) updateCommitTx(ctx context.Context) error {
// Preemptively write all pending keystones to disk, just in case the
// HTLCs we have in memory are included in the subsequent attempt to
// sign a commitment state.
err := l.cfg.Circuits.OpenCircuits(l.keystoneBatch...)
if err != nil {
// If ErrDuplicateKeystone is returned, the caller will catch
// it.
return err
}
// Reset the batch, but keep the backing buffer to avoid reallocating.
l.keystoneBatch = l.keystoneBatch[:0]
// If hodl.Commit mode is active, we will refrain from attempting to
// commit any in-memory modifications to the channel state. Exiting here
// permits testing of either the switch or link's ability to trim
// circuits that have been opened, but unsuccessfully committed.
if l.cfg.HodlMask.Active(hodl.Commit) {
l.log.Warnf(hodl.Commit.Warning())
return nil
}
ctx, done := l.cg.Create(ctx)
defer done()
newCommit, err := l.channel.SignNextCommitment(ctx)
if err == lnwallet.ErrNoWindow {
l.cfg.PendingCommitTicker.Resume()
l.log.Trace("PendingCommitTicker resumed")
n := l.channel.NumPendingUpdates(lntypes.Local, lntypes.Remote)
l.log.Tracef("revocation window exhausted, unable to send: "+
"%v, pend_updates=%v, dangling_closes%v", n,
lnutils.SpewLogClosure(l.openedCircuits),
lnutils.SpewLogClosure(l.closedCircuits))
return nil
} else if err != nil {
return err
}
if err := l.ackDownStreamPackets(); err != nil {
return err
}
l.cfg.PendingCommitTicker.Pause()
l.log.Trace("PendingCommitTicker paused after ackDownStreamPackets")
// The remote party now has a new pending commitment, so we'll update
// the contract court to be aware of this new set (superseding the
// prior remote pending set).
newUpdate := &contractcourt.ContractUpdate{
HtlcKey: contractcourt.RemotePendingHtlcSet,
Htlcs: newCommit.PendingHTLCs,
}
err = l.cfg.NotifyContractUpdate(newUpdate)
if err != nil {
l.log.Errorf("unable to notify contract update: %v", err)
return err
}
select {
case <-l.cg.Done():
return ErrLinkShuttingDown
default:
}
auxBlobRecords, err := lnwire.ParseCustomRecords(newCommit.AuxSigBlob)
if err != nil {
return fmt.Errorf("error parsing aux sigs: %w", err)
}
commitSig := &lnwire.CommitSig{
ChanID: l.ChanID(),
CommitSig: newCommit.CommitSig,
HtlcSigs: newCommit.HtlcSigs,
PartialSig: newCommit.PartialSig,
CustomRecords: auxBlobRecords,
}
l.cfg.Peer.SendMessage(false, commitSig)
// Now that we have sent out a new CommitSig, we invoke the outgoing set
// of commit hooks.
l.RWMutex.Lock()
l.outgoingCommitHooks.invoke()
l.RWMutex.Unlock()
return nil
}
// PeerPubKey returns the public key of the remote peer with which we have the
// channel link opened.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) PeerPubKey() [33]byte {
return l.cfg.Peer.PubKey()
}
// ChannelPoint returns the channel outpoint for the channel link.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) ChannelPoint() wire.OutPoint {
return l.channel.ChannelPoint()
}
// ShortChanID returns the short channel ID for the channel link. The short
// channel ID encodes the exact location in the main chain that the original
// funding output can be found.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) ShortChanID() lnwire.ShortChannelID {
l.RLock()
defer l.RUnlock()
return l.channel.ShortChanID()
}
// UpdateShortChanID updates the short channel ID for a link. This may be
// required in the event that a link is created before the short chan ID for it
// is known, or a re-org occurs, and the funding transaction changes location
// within the chain.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) {
chanID := l.ChanID()
// Refresh the channel state's short channel ID by loading it from disk.
// This ensures that the channel state accurately reflects the updated
// short channel ID.
err := l.channel.State().Refresh()
if err != nil {
l.log.Errorf("unable to refresh short_chan_id for chan_id=%v: "+
"%v", chanID, err)
return hop.Source, err
}
return hop.Source, nil
}
// ChanID returns the channel ID for the channel link. The channel ID is a more
// compact representation of a channel's full outpoint.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) ChanID() lnwire.ChannelID {
return lnwire.NewChanIDFromOutPoint(l.channel.ChannelPoint())
}
// Bandwidth returns the total amount that can flow through the channel link at
// this given instance. The value returned is expressed in millisatoshi and can
// be used by callers when making forwarding decisions to determine if a link
// can accept an HTLC.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) Bandwidth() lnwire.MilliSatoshi {
// Get the balance available on the channel for new HTLCs. This takes
// the channel reserve into account so HTLCs up to this value won't
// violate it.
return l.channel.AvailableBalance()
}
// MayAddOutgoingHtlc indicates whether we can add an outgoing htlc with the
// amount provided to the link. This check does not reserve a space, since
// forwards or other payments may use the available slot, so it should be
// considered best-effort.
func (l *channelLink) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error {
return l.channel.MayAddOutgoingHtlc(amt)
}
// getDustSum is a wrapper method that calls the underlying channel's dust sum
// method.
//
// NOTE: Part of the dustHandler interface.
func (l *channelLink) getDustSum(whoseCommit lntypes.ChannelParty,
dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi {
return l.channel.GetDustSum(whoseCommit, dryRunFee)
}
// getFeeRate is a wrapper method that retrieves the underlying channel's
// feerate.
//
// NOTE: Part of the dustHandler interface.
func (l *channelLink) getFeeRate() chainfee.SatPerKWeight {
return l.channel.CommitFeeRate()
}
// getDustClosure returns a closure that can be used by the switch or mailbox
// to evaluate whether a given HTLC is dust.
//
// NOTE: Part of the dustHandler interface.
func (l *channelLink) getDustClosure() dustClosure {
localDustLimit := l.channel.State().LocalChanCfg.DustLimit
remoteDustLimit := l.channel.State().RemoteChanCfg.DustLimit
chanType := l.channel.State().ChanType
return dustHelper(chanType, localDustLimit, remoteDustLimit)
}
// getCommitFee returns either the local or remote CommitFee in satoshis. This
// is used so that the Switch can have access to the commitment fee without
// needing to have a *LightningChannel. This doesn't include dust.
//
// NOTE: Part of the dustHandler interface.
func (l *channelLink) getCommitFee(remote bool) btcutil.Amount {
if remote {
return l.channel.State().RemoteCommitment.CommitFee
}
return l.channel.State().LocalCommitment.CommitFee
}
// exceedsFeeExposureLimit returns whether or not the new proposed fee-rate
// increases the total dust and fees within the channel past the configured
// fee threshold. It first calculates the dust sum over every update in the
// update log with the proposed fee-rate and taking into account both the local
// and remote dust limits. It uses every update in the update log instead of
// what is actually on the local and remote commitments because it is assumed
// that in a worst-case scenario, every update in the update log could
// theoretically be on either commitment transaction and this needs to be
// accounted for with this fee-rate. It then calculates the local and remote
// commitment fees given the proposed fee-rate. Finally, it tallies the results
// and determines if the fee threshold has been exceeded.
func (l *channelLink) exceedsFeeExposureLimit(
feePerKw chainfee.SatPerKWeight) (bool, error) {
dryRunFee := fn.Some[chainfee.SatPerKWeight](feePerKw)
// Get the sum of dust for both the local and remote commitments using
// this "dry-run" fee.
localDustSum := l.getDustSum(lntypes.Local, dryRunFee)
remoteDustSum := l.getDustSum(lntypes.Remote, dryRunFee)
// Calculate the local and remote commitment fees using this dry-run
// fee.
localFee, remoteFee, err := l.channel.CommitFeeTotalAt(feePerKw)
if err != nil {
return false, err
}
// Finally, check whether the max fee exposure was exceeded on either
// future commitment transaction with the fee-rate.
totalLocalDust := localDustSum + lnwire.NewMSatFromSatoshis(localFee)
if totalLocalDust > l.cfg.MaxFeeExposure {
l.log.Debugf("ChannelLink(%v): exceeds fee exposure limit: "+
"local dust: %v, local fee: %v", l.ShortChanID(),
totalLocalDust, localFee)
return true, nil
}
totalRemoteDust := remoteDustSum + lnwire.NewMSatFromSatoshis(
remoteFee,
)
if totalRemoteDust > l.cfg.MaxFeeExposure {
l.log.Debugf("ChannelLink(%v): exceeds fee exposure limit: "+
"remote dust: %v, remote fee: %v", l.ShortChanID(),
totalRemoteDust, remoteFee)
return true, nil
}
return false, nil
}
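// A worked example with made-up numbers: if the dry-run local dust sum is
// 400,000 msat and the local commitment fee at the proposed fee-rate is
// 2,500 sat (2,500,000 msat), the total local exposure is 2,900,000 msat.
// The function reports true iff that total, or the remote equivalent,
// exceeds cfg.MaxFeeExposure.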
// isOverexposedWithHtlc calculates whether the proposed HTLC will make the
// channel exceed the fee threshold. It first fetches the largest fee-rate that
// may be on any unrevoked commitment transaction. Then, using this fee-rate,
// determines if the to-be-added HTLC is dust. If the HTLC is dust, it adds to
// the overall dust sum. If it is not dust, it contributes to weight, which
// also adds to the overall dust sum by an increase in fees. If the dust sum on
// either commitment exceeds the configured fee threshold, this function
// returns true.
func (l *channelLink) isOverexposedWithHtlc(htlc *lnwire.UpdateAddHTLC,
incoming bool) bool {
dustClosure := l.getDustClosure()
feeRate := l.channel.WorstCaseFeeRate()
amount := htlc.Amount.ToSatoshis()
// See if this HTLC is dust on both the local and remote commitments.
isLocalDust := dustClosure(feeRate, incoming, lntypes.Local, amount)
isRemoteDust := dustClosure(feeRate, incoming, lntypes.Remote, amount)
// Calculate the dust sum for the local and remote commitments.
localDustSum := l.getDustSum(
lntypes.Local, fn.None[chainfee.SatPerKWeight](),
)
remoteDustSum := l.getDustSum(
lntypes.Remote, fn.None[chainfee.SatPerKWeight](),
)
// Grab the larger of the local and remote commitment fees w/o dust.
commitFee := l.getCommitFee(false)
if l.getCommitFee(true) > commitFee {
commitFee = l.getCommitFee(true)
}
commitFeeMSat := lnwire.NewMSatFromSatoshis(commitFee)
localDustSum += commitFeeMSat
remoteDustSum += commitFeeMSat
// Calculate the additional fee increase if this is a non-dust HTLC.
weight := lntypes.WeightUnit(input.HTLCWeight)
additional := lnwire.NewMSatFromSatoshis(
feeRate.FeeForWeight(weight),
)
if isLocalDust {
// If this is dust, it doesn't contribute to weight but does
// contribute to the overall dust sum.
localDustSum += lnwire.NewMSatFromSatoshis(amount)
} else {
// Account for the fee increase that comes with an increase in
// weight.
localDustSum += additional
}
if localDustSum > l.cfg.MaxFeeExposure {
// The max fee exposure was exceeded.
l.log.Debugf("ChannelLink(%v): HTLC %v makes the channel "+
"overexposed, total local dust: %v (current commit "+
"fee: %v)", l.ShortChanID(), htlc, localDustSum)
return true
}
if isRemoteDust {
// If this is dust, it doesn't contribute to weight but does
// contribute to the overall dust sum.
remoteDustSum += lnwire.NewMSatFromSatoshis(amount)
} else {
// Account for the fee increase that comes with an increase in
// weight.
remoteDustSum += additional
}
if remoteDustSum > l.cfg.MaxFeeExposure {
// The max fee exposure was exceeded.
l.log.Debugf("ChannelLink(%v): HTLC %v makes the channel "+
"overexposed, total remote dust: %v (current commit "+
"fee: %v)", l.ShortChanID(), htlc, remoteDustSum)
return true
}
return false
}
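// A worked example with made-up numbers: at a worst-case fee-rate of
// 5,000 sat/kw, a non-dust HTLC adds fees for input.HTLCWeight (172 WU),
// i.e. 5,000 * 172 / 1,000 = 860 sat (860,000 msat), to each commitment's
// exposure, whereas a 1,000 sat dust HTLC would instead add its full
// 1,000,000 msat amount. Either increment is compared against
// cfg.MaxFeeExposure on top of the existing dust sum and commit fee.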
// dustClosure is a function that evaluates whether an HTLC is dust. It returns
// true if the HTLC is dust. It takes in a feerate, a boolean denoting whether
// the HTLC is incoming (i.e. one that the remote sent), a boolean denoting
// whether to evaluate on the local or remote commit, and finally an HTLC
// amount to test.
type dustClosure func(feerate chainfee.SatPerKWeight, incoming bool,
whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool
// dustHelper is used to construct the dustClosure.
func dustHelper(chantype channeldb.ChannelType, localDustLimit,
remoteDustLimit btcutil.Amount) dustClosure {
isDust := func(feerate chainfee.SatPerKWeight, incoming bool,
whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool {
var dustLimit btcutil.Amount
if whoseCommit.IsLocal() {
dustLimit = localDustLimit
} else {
dustLimit = remoteDustLimit
}
return lnwallet.HtlcIsDust(
chantype, incoming, whoseCommit, feerate, amt,
dustLimit,
)
}
return isDust
}
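// For illustration, a hypothetical caller might capture and apply the
// closure like so (the dust limits and amount are made up):
//
//	isDust := dustHelper(chanType, 354, 546)
//	// Evaluate an incoming 500 sat HTLC on our own commitment:
//	if isDust(feeRate, true, lntypes.Local, 500) {
//		// Dust: contributes its amount, not weight, to fee exposure.
//	}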
// zeroConfConfirmed returns whether or not the zero-conf channel has
// confirmed on-chain.
//
// Part of the scidAliasHandler interface.
func (l *channelLink) zeroConfConfirmed() bool {
return l.channel.State().ZeroConfConfirmed()
}
// confirmedScid returns the confirmed SCID for a zero-conf channel. This
// should not be called for non-zero-conf channels.
//
// Part of the scidAliasHandler interface.
func (l *channelLink) confirmedScid() lnwire.ShortChannelID {
return l.channel.State().ZeroConfRealScid()
}
// isZeroConf returns whether or not the underlying channel is a zero-conf
// channel.
//
// Part of the scidAliasHandler interface.
func (l *channelLink) isZeroConf() bool {
return l.channel.State().IsZeroConf()
}
// negotiatedAliasFeature returns whether or not the underlying channel has
// negotiated the option-scid-alias feature bit. This will be true for both
// option-scid-alias and zero-conf channel-types. It will also be true for
// channels with the feature bit but without the above channel-types.
//
// Part of the scidAliasFeature interface.
func (l *channelLink) negotiatedAliasFeature() bool {
return l.channel.State().NegotiatedAliasFeature()
}
// getAliases returns the set of aliases for the underlying channel.
//
// Part of the scidAliasHandler interface.
func (l *channelLink) getAliases() []lnwire.ShortChannelID {
return l.cfg.GetAliases(l.ShortChanID())
}
// attachFailAliasUpdate sets the link's FailAliasUpdate function.
//
// Part of the scidAliasHandler interface.
func (l *channelLink) attachFailAliasUpdate(closure func(
sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) {
l.Lock()
l.cfg.FailAliasUpdate = closure
l.Unlock()
}
// AttachMailBox updates the current mailbox used by this link, and hooks up
// the mailbox's message and packet outboxes to the link's upstream and
// downstream chans, respectively.
func (l *channelLink) AttachMailBox(mailbox MailBox) {
l.Lock()
l.mailBox = mailbox
l.upstream = mailbox.MessageOutBox()
l.downstream = mailbox.PacketOutBox()
l.Unlock()
// Set the mailbox's fee rate. This may be refreshing a feerate that was
// never committed.
l.mailBox.SetFeeRate(l.getFeeRate())
// Also set the mailbox's dust closure so that it can query whether HTLC's
// are dust given the current feerate.
l.mailBox.SetDustClosure(l.getDustClosure())
}
// UpdateForwardingPolicy updates the forwarding policy for the target
// ChannelLink. Once updated, the link will use the new forwarding policy to
// govern whether an incoming HTLC should be forwarded or not. We assume that
// fields that are zero are intentionally set to zero, so we'll use newPolicy
// to update all of the link's FwrdingPolicy values.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) UpdateForwardingPolicy(
newPolicy models.ForwardingPolicy) {
l.Lock()
defer l.Unlock()
l.cfg.FwrdingPolicy = newPolicy
}
// CheckHtlcForward should return a nil error if the passed HTLC details
// satisfy the current forwarding policy of the target link. Otherwise, a
// LinkError with a valid protocol failure message should be returned in
// order to signal the policy consistency issue to the source of the HTLC.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt,
amtToForward lnwire.MilliSatoshi, incomingTimeout,
outgoingTimeout uint32, inboundFee models.InboundFee,
heightNow uint32, originalScid lnwire.ShortChannelID,
customRecords lnwire.CustomRecords) *LinkError {
l.RLock()
policy := l.cfg.FwrdingPolicy
l.RUnlock()
// Using the outgoing HTLC amount, we'll calculate the outgoing
// fee this incoming HTLC must carry in order to satisfy the constraints
// of the outgoing link.
outFee := ExpectedFee(policy, amtToForward)
// Then calculate the inbound fee that we charge based on the sum of
// outgoing HTLC amount and outgoing fee.
inFee := inboundFee.CalcFee(amtToForward + outFee)
// Add up both fee components. It is important to calculate both fees
// separately. An alternative way of calculating is to first determine
// an aggregate fee and apply that to the outgoing HTLC amount. However,
// rounding may cause the result to be slightly higher than in the case
// of separately rounded fee components. This potentially causes failed
// forwards for senders and is something to be avoided.
expectedFee := inFee + int64(outFee)
// If the actual fee is less than our expected fee, then we'll reject
// this HTLC as it didn't provide a sufficient amount of fees, the
// values have been tampered with, or the sender used incorrect or
// outdated information to construct the forwarding information for
// this hop. In any case, we'll cancel this HTLC.
actualFee := int64(incomingHtlcAmt) - int64(amtToForward)
if incomingHtlcAmt < amtToForward || actualFee < expectedFee {
l.log.Warnf("outgoing htlc(%x) has insufficient fee: "+
"expected %v, got %v: incoming=%v, outgoing=%v, "+
"inboundFee=%v",
payHash[:], expectedFee, actualFee,
incomingHtlcAmt, amtToForward, inboundFee,
)
// As part of the returned error, we'll send our latest routing
// policy so the sending node obtains the most up to date data.
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
return lnwire.NewFeeInsufficient(amtToForward, *upd)
}
failure := l.createFailureWithUpdate(false, originalScid, cb)
return NewLinkError(failure)
}
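// As a worked example with made-up values: if incomingHtlcAmt is
// 1,001,600 msat and amtToForward is 1,000,000 msat, then actualFee is
// 1,600 msat; with outFee = 1,100 msat and inFee = 400 msat, expectedFee
// is 1,500 msat, so the HTLC would pass this check.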
// Check whether the outgoing htlc satisfies the channel policy.
err := l.canSendHtlc(
policy, payHash, amtToForward, outgoingTimeout, heightNow,
originalScid, customRecords,
)
if err != nil {
return err
}
// Finally, we'll ensure that the time-lock on the outgoing HTLC meets
// the following constraint: the incoming time-lock minus our time-lock
// delta should be at least the outgoing time lock. Otherwise, either
// the sender messed up, or an intermediate node tampered with the HTLC.
// For example, with a 40-block delta, an HTLC with an outgoing timeout
// of 800,000 must carry an incoming timeout of at least 800,040.
timeDelta := policy.TimeLockDelta
if incomingTimeout < outgoingTimeout+timeDelta {
l.log.Warnf("incoming htlc(%x) has incorrect time-lock value: "+
"expected at least %v block delta, got %v block delta",
payHash[:], timeDelta, incomingTimeout-outgoingTimeout)
// Grab the latest routing policy so the sending node is up to
// date with our current policy.
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
return lnwire.NewIncorrectCltvExpiry(
incomingTimeout, *upd,
)
}
failure := l.createFailureWithUpdate(false, originalScid, cb)
return NewLinkError(failure)
}
return nil
}
// CheckHtlcTransit should return a nil error if the passed HTLC details
// satisfy the current channel policy. Otherwise, a LinkError with a
// valid protocol failure message should be returned in order to signal
// the violation. This call is intended to be used for locally initiated
// payments for which there is no corresponding incoming htlc.
func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32,
customRecords lnwire.CustomRecords) *LinkError {
l.RLock()
policy := l.cfg.FwrdingPolicy
l.RUnlock()
// We pass in hop.Source here as this is only used in the Switch when
// trying to send over a local link. This causes the fallback mechanism
// to occur.
return l.canSendHtlc(
policy, payHash, amt, timeout, heightNow, hop.Source,
customRecords,
)
}
// canSendHtlc checks whether the given htlc parameters satisfy
// the channel's amount and time lock constraints.
func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32,
heightNow uint32, originalScid lnwire.ShortChannelID,
customRecords lnwire.CustomRecords) *LinkError {
// As our first sanity check, we'll ensure that the passed HTLC isn't
// too small for the next hop. If so, then we'll cancel the HTLC
// directly.
if amt < policy.MinHTLCOut {
l.log.Warnf("outgoing htlc(%x) is too small: min_htlc=%v, "+
"htlc_value=%v", payHash[:], policy.MinHTLCOut,
amt)
// As part of the returned error, we'll send our latest routing
// policy so the sending node obtains the most up to date data.
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
return lnwire.NewAmountBelowMinimum(amt, *upd)
}
failure := l.createFailureWithUpdate(false, originalScid, cb)
return NewLinkError(failure)
}
// Next, ensure that the passed HTLC isn't too large. If so, we'll
// cancel the HTLC directly.
if policy.MaxHTLC != 0 && amt > policy.MaxHTLC {
l.log.Warnf("outgoing htlc(%x) is too large: max_htlc=%v, "+
"htlc_value=%v", payHash[:], policy.MaxHTLC, amt)
// As part of the returned error, we'll send our latest routing
// policy so the sending node obtains the most up-to-date data.
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
return lnwire.NewTemporaryChannelFailure(upd)
}
failure := l.createFailureWithUpdate(false, originalScid, cb)
return NewDetailedLinkError(failure, OutgoingFailureHTLCExceedsMax)
}
// We want to avoid offering an HTLC which will expire in the near
// future, so we'll reject an HTLC if the outgoing expiration time is
// too close to the current height.
if timeout <= heightNow+l.cfg.OutgoingCltvRejectDelta {
l.log.Warnf("htlc(%x) has an expiry that's too soon: "+
"outgoing_expiry=%v, best_height=%v", payHash[:],
timeout, heightNow)
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
return lnwire.NewExpiryTooSoon(*upd)
}
failure := l.createFailureWithUpdate(false, originalScid, cb)
return NewLinkError(failure)
}
// Check absolute max delta.
if timeout > l.cfg.MaxOutgoingCltvExpiry+heightNow {
l.log.Warnf("outgoing htlc(%x) has a time lock too far in "+
"the future: got %v, but maximum is %v", payHash[:],
timeout-heightNow, l.cfg.MaxOutgoingCltvExpiry)
return NewLinkError(&lnwire.FailExpiryTooFar{})
}
// We now check the available bandwidth to see if this HTLC can be
// forwarded.
availableBandwidth := l.Bandwidth()
auxBandwidth, err := fn.MapOptionZ(
l.cfg.AuxTrafficShaper,
func(ts AuxTrafficShaper) fn.Result[OptionalBandwidth] {
var htlcBlob fn.Option[tlv.Blob]
blob, err := customRecords.Serialize()
if err != nil {
return fn.Err[OptionalBandwidth](
fmt.Errorf("unable to serialize "+
"custom records: %w", err))
}
if len(blob) > 0 {
htlcBlob = fn.Some(blob)
}
return l.AuxBandwidth(amt, originalScid, htlcBlob, ts)
},
).Unpack()
if err != nil {
l.log.Errorf("Unable to determine aux bandwidth: %v", err)
return NewLinkError(&lnwire.FailTemporaryNodeFailure{})
}
if auxBandwidth.IsHandled && auxBandwidth.Bandwidth.IsSome() {
auxBandwidth.Bandwidth.WhenSome(
func(bandwidth lnwire.MilliSatoshi) {
availableBandwidth = bandwidth
},
)
}
// Check to see if there is enough balance in this channel.
if amt > availableBandwidth {
l.log.Warnf("insufficient bandwidth to route htlc: %v is "+
"larger than %v", amt, availableBandwidth)
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage {
return lnwire.NewTemporaryChannelFailure(upd)
}
failure := l.createFailureWithUpdate(false, originalScid, cb)
return NewDetailedLinkError(
failure, OutgoingFailureInsufficientBalance,
)
}
return nil
}
// AuxBandwidth returns the bandwidth that can be used for a channel,
// expressed in milli-satoshi. This might be different from the regular BTC
// bandwidth for custom channels. For a regular (non-custom) channel, the
// returned OptionalBandwidth will report that the traffic was not handled.
func (l *channelLink) AuxBandwidth(amount lnwire.MilliSatoshi,
cid lnwire.ShortChannelID, htlcBlob fn.Option[tlv.Blob],
ts AuxTrafficShaper) fn.Result[OptionalBandwidth] {
fundingBlob := l.FundingCustomBlob()
shouldHandle, err := ts.ShouldHandleTraffic(cid, fundingBlob, htlcBlob)
if err != nil {
return fn.Err[OptionalBandwidth](fmt.Errorf("traffic shaper "+
"failed to decide whether to handle traffic: %w", err))
}
log.Debugf("ShortChannelID=%v: aux traffic shaper is handling "+
"traffic: %v", cid, shouldHandle)
// If this channel isn't handled by the aux traffic shaper, we'll return
// early.
if !shouldHandle {
return fn.Ok(OptionalBandwidth{
IsHandled: false,
})
}
// Ask for a specific bandwidth to be used for the channel.
commitmentBlob := l.CommitmentCustomBlob()
auxBandwidth, err := ts.PaymentBandwidth(
htlcBlob, commitmentBlob, l.Bandwidth(), amount,
l.channel.FetchLatestAuxHTLCView(),
)
if err != nil {
return fn.Err[OptionalBandwidth](fmt.Errorf("failed to get "+
"bandwidth from external traffic shaper: %w", err))
}
log.Debugf("ShortChannelID=%v: aux traffic shaper reported available "+
"bandwidth: %v", cid, auxBandwidth)
return fn.Ok(OptionalBandwidth{
IsHandled: true,
Bandwidth: fn.Some(auxBandwidth),
})
}
// Stats returns the statistics of the channel link: the commit height, the
// total milli-satoshis sent, and the total milli-satoshis received.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) {
snapshot := l.channel.StateSnapshot()
return snapshot.ChannelCommitment.CommitHeight,
snapshot.TotalMSatSent,
snapshot.TotalMSatReceived
}
// String returns the string representation of the channel link.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) String() string {
return l.channel.ChannelPoint().String()
}
// handleSwitchPacket handles switch packets. These packets might have been
// forwarded to us from another channel link if the htlc update came from
// another peer, or they may have been created locally by the user.
//
// NOTE: Part of the packetHandler interface.
func (l *channelLink) handleSwitchPacket(pkt *htlcPacket) error {
l.log.Tracef("received switch packet inkey=%v, outkey=%v",
pkt.inKey(), pkt.outKey())
return l.mailBox.AddPacket(pkt)
}
// HandleChannelUpdate handles htlc requests (settle/add/fail) sent to us from
// the remote peer we have a channel with.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) HandleChannelUpdate(message lnwire.Message) {
select {
case <-l.cg.Done():
// Return early if the link is already in the process of
// quitting. It doesn't make sense to hand the message to the
// mailbox here.
return
default:
}
err := l.mailBox.AddMessage(message)
if err != nil {
l.log.Errorf("failed to add Message to mailbox: %v", err)
}
}
// updateChannelFee updates the commitment fee-per-kw on this channel by
// committing to an update_fee message.
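// The update proceeds in a fixed order: eligibility check, fee-exposure
// check, local commitment fee update, mailbox feerate refresh, update_fee
// send, and finally a commitment update to lock the new rate in.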
func (l *channelLink) updateChannelFee(ctx context.Context,
feePerKw chainfee.SatPerKWeight) error {
l.log.Infof("updating commit fee to %v", feePerKw)
// We skip sending the UpdateFee message if the channel is not
// currently eligible to forward messages.
if !l.eligibleToUpdate() {
l.log.Debugf("skipping fee update for inactive channel")
return nil
}
// Check and see if our proposed fee-rate would make us exceed the fee
// threshold.
thresholdExceeded, err := l.exceedsFeeExposureLimit(feePerKw)
if err != nil {
// This shouldn't typically happen. If it does, it indicates
// something is wrong with our channel state.
return err
}
if thresholdExceeded {
return fmt.Errorf("link fee threshold exceeded")
}
// First, we'll update the local fee on our commitment.
if err := l.channel.UpdateFee(feePerKw); err != nil {
return err
}
// The fee passed the channel's validation checks, so we update the
// mailbox feerate.
l.mailBox.SetFeeRate(feePerKw)
// We'll then attempt to send a new UpdateFee message, and also lock it
// in immediately by triggering a commitment update.
msg := lnwire.NewUpdateFee(l.ChanID(), uint32(feePerKw))
if err := l.cfg.Peer.SendMessage(false, msg); err != nil {
return err
}
return l.updateCommitTx(ctx)
}
// processRemoteSettleFails accepts a batch of settle/fail payment descriptors
// after receiving a revocation from the remote party, and reprocesses them in
// the context of the provided forwarding package. Any settles or fails that
// have already been acknowledged in the forwarding package will not be sent to
// the switch.
func (l *channelLink) processRemoteSettleFails(fwdPkg *channeldb.FwdPkg) {
if len(fwdPkg.SettleFails) == 0 {
return
}
l.log.Debugf("settle-fail-filter: %v", fwdPkg.SettleFailFilter)
var switchPackets []*htlcPacket
for i, update := range fwdPkg.SettleFails {
destRef := fwdPkg.DestRef(uint16(i))
// Skip any settles or fails that have already been
// acknowledged by the incoming link that originated the
// forwarded Add.
if fwdPkg.SettleFailFilter.Contains(uint16(i)) {
continue
}
// TODO(roasbeef): rework log entries to a shared
// interface.
switch msg := update.UpdateMsg.(type) {
// A settle for an HTLC we previously forwarded has been
// received. So we'll forward the HTLC to the switch which will
// handle propagating the settle to the prior hop.
case *lnwire.UpdateFulfillHTLC:
// If hodl.SettleIncoming is requested, we will not
// forward the SETTLE to the switch and will not signal
// a free slot on the commitment transaction.
if l.cfg.HodlMask.Active(hodl.SettleIncoming) {
l.log.Warnf(hodl.SettleIncoming.Warning())
continue
}
settlePacket := &htlcPacket{
outgoingChanID: l.ShortChanID(),
outgoingHTLCID: msg.ID,
destRef: &destRef,
htlc: msg,
}
// Add the packet to the batch to be forwarded, and
// notify the overflow queue that a spare spot has been
// freed up within the commitment state.
switchPackets = append(switchPackets, settlePacket)
// A failure message for a previously forwarded HTLC has
// been received. As a result a new slot will be freed up in
// our commitment state, so we'll forward this to the switch so
// the backwards undo can continue.
case *lnwire.UpdateFailHTLC:
// If hodl.FailIncoming is requested, we will not
// forward the FAIL to the switch and will not signal a
// free slot on the commitment transaction.
if l.cfg.HodlMask.Active(hodl.FailIncoming) {
l.log.Warnf(hodl.FailIncoming.Warning())
continue
}
// Fetch the reason the HTLC was canceled so we can
// continue to propagate it. This failure originated
// from another node, so the linkFailure field is not
// set on the packet.
failPacket := &htlcPacket{
outgoingChanID: l.ShortChanID(),
outgoingHTLCID: msg.ID,
destRef: &destRef,
htlc: msg,
}
l.log.Debugf("Failed to send HTLC with ID=%d", msg.ID)
// If the failure message lacks an HMAC (but includes
// the 4 bytes for encoding the message and padding
// lengths), then this means that we received it as an
// UpdateFailMalformedHTLC. As a result, we'll signal
// that we need to convert this error within the switch
// to an actual error, by encrypting it as if we were
// the originating hop.
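// In concrete terms, a converted reason is
// lnwire.FailureMessageLength (256) + 4 = 260 bytes,
// whereas a normally encrypted reason also carries a
// 32-byte HMAC.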
convertedErrorSize := lnwire.FailureMessageLength + 4
if len(msg.Reason) == convertedErrorSize {
failPacket.convertedError = true
}
// Add the packet to the batch to be forwarded, and
// notify the overflow queue that a spare spot has been
// freed up within the commitment state.
switchPackets = append(switchPackets, failPacket)
}
}
// Only spawn the task to forward packets if we have a non-zero number of
// them.
if len(switchPackets) > 0 {
go l.forwardBatch(false, switchPackets...)
}
}
// processRemoteAdds serially processes each of the Add payment descriptors
// which have been "locked-in" by receiving a revocation from the remote party.
// The forwarding package provided instructs how to process this batch,
// indicating whether this is the first time these Adds are being processed, or
// whether we are reprocessing as a result of a failure or restart. Adds that
// have already been acknowledged in the forwarding package will be ignored.
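//
// The processing pipeline for each add is, in order: batched Sphinx decode,
// payload parsing, error-encrypter extraction, and then either exit-hop
// handling or construction of an outgoing htlcPacket for the switch.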
//
//nolint:funlen
func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) {
l.log.Tracef("processing %d remote adds for height %d",
len(fwdPkg.Adds), fwdPkg.Height)
decodeReqs := make(
[]hop.DecodeHopIteratorRequest, 0, len(fwdPkg.Adds),
)
for _, update := range fwdPkg.Adds {
if msg, ok := update.UpdateMsg.(*lnwire.UpdateAddHTLC); ok {
// Before adding the new htlc to the state machine,
// parse the onion object in order to obtain the
// routing information with the DecodeHopIterator
// function, which processes the Sphinx packet.
onionReader := bytes.NewReader(msg.OnionBlob[:])
req := hop.DecodeHopIteratorRequest{
OnionReader: onionReader,
RHash: msg.PaymentHash[:],
IncomingCltv: msg.Expiry,
IncomingAmount: msg.Amount,
BlindingPoint: msg.BlindingPoint,
}
decodeReqs = append(decodeReqs, req)
}
}
// Atomically decode the incoming htlcs, simultaneously checking for
// replay attempts. A particular index in the returned, sparse list of
// channel iterators should only be used if the failure code at the
// same index is lnwire.FailCodeNone.
decodeResps, sphinxErr := l.cfg.DecodeHopIterators(
fwdPkg.ID(), decodeReqs,
)
if sphinxErr != nil {
l.failf(LinkFailureError{code: ErrInternalError},
"unable to decode hop iterators: %v", sphinxErr)
return
}
var switchPackets []*htlcPacket
for i, update := range fwdPkg.Adds {
idx := uint16(i)
//nolint:forcetypeassert
add := *update.UpdateMsg.(*lnwire.UpdateAddHTLC)
sourceRef := fwdPkg.SourceRef(idx)
if fwdPkg.State == channeldb.FwdStateProcessed &&
fwdPkg.AckFilter.Contains(idx) {
// If this index is already found in the ack filter,
// the response to this forwarding decision has already
// been committed by one of our commitment txns. ADDs
// in this state are waiting for the rest of the fwding
// package to get acked before being garbage collected.
continue
}
// An incoming HTLC add has been fully locked in. As a result we
// can now examine the forwarding details of the HTLC, and the
// HTLC itself to decide if: we should forward it, cancel it,
// or are able to settle it (and it adheres to our fee related
// constraints).
// Fetch the result of the earlier batched decode for this
// add: a hop iterator carrying the routing information, or a
// failure code if the Sphinx packet could not be processed.
chanIterator, failureCode := decodeResps[i].Result()
if failureCode != lnwire.CodeNone {
// If we're unable to process the onion blob then we
// should send the malformed htlc error to the payment
// sender.
l.sendMalformedHTLCError(
add.ID, failureCode, add.OnionBlob, &sourceRef,
)
l.log.Errorf("unable to decode onion hop "+
"iterator: %v", failureCode)
continue
}
heightNow := l.cfg.BestHeight()
pld, routeRole, pldErr := chanIterator.HopPayload()
if pldErr != nil {
// If we're unable to process the onion payload, or we
// received invalid onion payload failure, then we
// should send an error back to the caller so the HTLC
// can be canceled.
var failedType uint64
// We need to get the underlying error value, so we
// can't use errors.As as suggested by the linter.
//nolint:errorlint
if e, ok := pldErr.(hop.ErrInvalidPayload); ok {
failedType = uint64(e.Type)
}
// If we couldn't parse the payload, make our best
// effort at creating an error encrypter that knows
// what blinding type we were. Since the payload
// didn't parse, we have no way of knowing for sure
// whether we were the introduction node or not.
//
//nolint:ll
obfuscator, failCode := chanIterator.ExtractErrorEncrypter(
l.cfg.ExtractErrorEncrypter,
// We need our route role here because we
// couldn't parse or validate the payload.
routeRole == hop.RouteRoleIntroduction,
)
if failCode != lnwire.CodeNone {
l.log.Errorf("could not extract error "+
"encrypter: %v", pldErr)
// We can't process this htlc, send back
// malformed.
l.sendMalformedHTLCError(
add.ID, failCode, add.OnionBlob,
&sourceRef,
)
continue
}
// TODO: currently none of the test unit infrastructure
// is setup to handle TLV payloads, so testing this
// would require implementing a separate mock iterator
// for TLV payloads that also supports injecting invalid
// payloads. Deferring this non-trivial effort until a
// later date.
failure := lnwire.NewInvalidOnionPayload(failedType, 0)
l.sendHTLCError(
add, sourceRef, NewLinkError(failure),
obfuscator, false,
)
l.log.Errorf("unable to decode forwarding "+
"instructions: %v", pldErr)
continue
}
// Retrieve onion obfuscator from onion blob in order to
// produce initial obfuscation of the onion failureCode.
obfuscator, failureCode := chanIterator.ExtractErrorEncrypter(
l.cfg.ExtractErrorEncrypter,
routeRole == hop.RouteRoleIntroduction,
)
if failureCode != lnwire.CodeNone {
// If we're unable to process the onion blob then we
// should send the malformed htlc error to the payment
// sender.
l.sendMalformedHTLCError(
add.ID, failureCode, add.OnionBlob,
&sourceRef,
)
l.log.Errorf("unable to decode onion "+
"obfuscator: %v", failureCode)
continue
}
fwdInfo := pld.ForwardingInfo()
// Check whether the payload we've just processed uses our
// node as the introduction point (gave us a blinding key in
// the payload itself) and fail it back if we don't support
// route blinding.
if fwdInfo.NextBlinding.IsSome() &&
l.cfg.DisallowRouteBlinding {
failure := lnwire.NewInvalidBlinding(
fn.Some(add.OnionBlob),
)
l.sendHTLCError(
add, sourceRef, NewLinkError(failure),
obfuscator, false,
)
l.log.Error("rejected htlc that uses use as an " +
"introduction point when we do not support " +
"route blinding")
continue
}
switch fwdInfo.NextHop {
case hop.Exit:
err := l.processExitHop(
add, sourceRef, obfuscator, fwdInfo,
heightNow, pld,
)
if err != nil {
l.failf(LinkFailureError{
code: ErrInternalError,
}, err.Error()) //nolint
return
}
// There are additional channels left within this route. So
// we'll simply do some forwarding package book-keeping.
default:
// If hodl.AddIncoming is requested, we will not
// validate the forwarded ADD, nor will we send the
// packet to the htlc switch.
if l.cfg.HodlMask.Active(hodl.AddIncoming) {
l.log.Warnf(hodl.AddIncoming.Warning())
continue
}
endorseValue := l.experimentalEndorsement(
record.CustomSet(add.CustomRecords),
)
endorseType := uint64(
lnwire.ExperimentalEndorsementType,
)
switch fwdPkg.State {
case channeldb.FwdStateProcessed:
// This add was not forwarded on the previous
// processing phase, run it through our
// validation pipeline to reproduce an error.
// This may trigger a different error due to
// expiring timelocks, but we expect that an
// error will be reproduced.
if !fwdPkg.FwdFilter.Contains(idx) {
break
}
// Otherwise, it was already processed, so we can
// collect it and continue.
outgoingAdd := &lnwire.UpdateAddHTLC{
Expiry: fwdInfo.OutgoingCTLV,
Amount: fwdInfo.AmountToForward,
PaymentHash: add.PaymentHash,
BlindingPoint: fwdInfo.NextBlinding,
}
endorseValue.WhenSome(func(e byte) {
custRecords := map[uint64][]byte{
endorseType: {e},
}
outgoingAdd.CustomRecords = custRecords
})
// Finally, we'll encode the onion packet for
// the _next_ hop using the hop iterator
// decoded for the current hop.
buf := bytes.NewBuffer(
outgoingAdd.OnionBlob[0:0],
)
// We know this cannot fail, as this ADD
// was marked forwarded in a previous
// round of processing.
chanIterator.EncodeNextHop(buf)
inboundFee := l.cfg.FwrdingPolicy.InboundFee
//nolint:ll
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: add.ID,
outgoingChanID: fwdInfo.NextHop,
sourceRef: &sourceRef,
incomingAmount: add.Amount,
amount: outgoingAdd.Amount,
htlc: outgoingAdd,
obfuscator: obfuscator,
incomingTimeout: add.Expiry,
outgoingTimeout: fwdInfo.OutgoingCTLV,
inOnionCustomRecords: pld.CustomRecords(),
inboundFee: inboundFee,
inWireCustomRecords: add.CustomRecords.Copy(),
}
switchPackets = append(
switchPackets, updatePacket,
)
continue
}
// TODO(roasbeef): ensure don't accept outrageous
// timeout for htlc
// With all our forwarding constraints met, we'll
// create the outgoing HTLC using the parameters as
// specified in the forwarding info.
addMsg := &lnwire.UpdateAddHTLC{
Expiry: fwdInfo.OutgoingCTLV,
Amount: fwdInfo.AmountToForward,
PaymentHash: add.PaymentHash,
BlindingPoint: fwdInfo.NextBlinding,
}
endorseValue.WhenSome(func(e byte) {
addMsg.CustomRecords = map[uint64][]byte{
endorseType: {e},
}
})
// Finally, we'll encode the onion packet for the
// _next_ hop using the hop iterator decoded for the
// current hop.
buf := bytes.NewBuffer(addMsg.OnionBlob[0:0])
err := chanIterator.EncodeNextHop(buf)
if err != nil {
l.log.Errorf("unable to encode the "+
"remaining route %v", err)
cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { //nolint:ll
return lnwire.NewTemporaryChannelFailure(upd)
}
failure := l.createFailureWithUpdate(
true, hop.Source, cb,
)
l.sendHTLCError(
add, sourceRef, NewLinkError(failure),
obfuscator, false,
)
continue
}
// Now that this add has been reprocessed, only append
// it to our list of packets to forward to the switch
// if this is the first time processing the add. If the
// fwd pkg has already been processed, then we entered
// the above section to recreate a previous error. If
// the packet had previously been forwarded, it would
// have been added to switchPackets at the top of this
// section.
if fwdPkg.State == channeldb.FwdStateLockedIn {
inboundFee := l.cfg.FwrdingPolicy.InboundFee
//nolint:ll
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: add.ID,
outgoingChanID: fwdInfo.NextHop,
sourceRef: &sourceRef,
incomingAmount: add.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
incomingTimeout: add.Expiry,
outgoingTimeout: fwdInfo.OutgoingCTLV,
inOnionCustomRecords: pld.CustomRecords(),
inboundFee: inboundFee,
inWireCustomRecords: add.CustomRecords.Copy(),
}
fwdPkg.FwdFilter.Set(idx)
switchPackets = append(switchPackets,
updatePacket)
}
}
}
// Commit the htlcs we are intending to forward if this package has not
// been fully processed.
if fwdPkg.State == channeldb.FwdStateLockedIn {
err := l.channel.SetFwdFilter(fwdPkg.Height, fwdPkg.FwdFilter)
if err != nil {
l.failf(LinkFailureError{code: ErrInternalError},
"unable to set fwd filter: %v", err)
return
}
}
if len(switchPackets) == 0 {
return
}
replay := fwdPkg.State != channeldb.FwdStateLockedIn
l.log.Debugf("forwarding %d packets to switch: replay=%v",
len(switchPackets), replay)
// NOTE: This call is made synchronous so that we ensure all circuits
// are committed in the exact order that they are processed in the link.
// Failing to do this could cause reorderings/gaps in the range of
// opened circuits, which violates assumptions made by the circuit
// trimming.
l.forwardBatch(replay, switchPackets...)
}
// experimentalEndorsement returns the optional value to set for our outgoing
// experimental endorsement field. fn.None indicates that the field should not
// be populated on the outgoing htlc.
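//
// A hedged summary of the mapping implemented below:
//
//	experiment period over       -> fn.None[byte]()
//	no/empty endorsement record  -> fn.Some(lnwire.ExperimentalUnendorsed)
//	incoming value == Endorsed   -> fn.Some(lnwire.ExperimentalEndorsed)
//	any other (invalid) value    -> fn.Some(lnwire.ExperimentalUnendorsed)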
func (l *channelLink) experimentalEndorsement(
customUpdateAdd record.CustomSet) fn.Option[byte] {
// Only relay experimental signal if we are within the experiment
// period.
if !l.cfg.ShouldFwdExpEndorsement() {
return fn.None[byte]()
}
// If we don't have any custom records or the experimental field is
// not set, just forward a zero value.
if len(customUpdateAdd) == 0 {
return fn.Some[byte](lnwire.ExperimentalUnendorsed)
}
t := uint64(lnwire.ExperimentalEndorsementType)
value, set := customUpdateAdd[t]
if !set {
return fn.Some[byte](lnwire.ExperimentalUnendorsed)
}
// We expect at least one byte for this field, consider it invalid if
// it has no data and just forward a zero value.
if len(value) == 0 {
return fn.Some[byte](lnwire.ExperimentalUnendorsed)
}
// Only forward endorsed if the incoming link is endorsed.
if value[0] == lnwire.ExperimentalEndorsed {
return fn.Some[byte](lnwire.ExperimentalEndorsed)
}
// Forward as unendorsed otherwise, including cases where we've
// received an invalid value that uses more than 3 bits of information.
return fn.Some[byte](lnwire.ExperimentalUnendorsed)
}
// processExitHop handles an htlc for which this link is the exit hop. It
// returns a non-nil error if the htlc could not be processed, in which case
// the caller fails the link.
func (l *channelLink) processExitHop(add lnwire.UpdateAddHTLC,
sourceRef channeldb.AddRef, obfuscator hop.ErrorEncrypter,
fwdInfo hop.ForwardingInfo, heightNow uint32,
payload invoices.Payload) error {
// If hodl.ExitSettle is requested, we will not validate the final hop's
// ADD, nor will we settle the corresponding invoice or respond with the
// preimage.
if l.cfg.HodlMask.Active(hodl.ExitSettle) {
l.log.Warnf("%s for htlc(rhash=%x,htlcIndex=%v)",
hodl.ExitSettle.Warning(), add.PaymentHash, add.ID)
return nil
}
// In case the traffic shaper is active, we'll check if the HTLC has
// custom records and skip the amount check in the onion payload below.
isCustomHTLC := fn.MapOptionZ(
l.cfg.AuxTrafficShaper,
func(ts AuxTrafficShaper) bool {
return ts.IsCustomHTLC(add.CustomRecords)
},
)
// As we're the exit hop, we'll double check the hop-payload included in
// the HTLC to ensure that it was crafted correctly by the sender and
// is compatible with the HTLC we were extended. If an external
// validator is active we might bypass the amount check.
if !isCustomHTLC && add.Amount < fwdInfo.AmountToForward {
l.log.Errorf("onion payload of incoming htlc(%x) has "+
"incompatible value: expected <=%v, got %v",
add.PaymentHash, add.Amount, fwdInfo.AmountToForward)
failure := NewLinkError(
lnwire.NewFinalIncorrectHtlcAmount(add.Amount),
)
l.sendHTLCError(add, sourceRef, failure, obfuscator, true)
return nil
}
// We'll also ensure that our time-lock value has been computed
// correctly.
if add.Expiry < fwdInfo.OutgoingCTLV {
l.log.Errorf("onion payload of incoming htlc(%x) has "+
"incompatible time-lock: expected <=%v, got %v",
add.PaymentHash, add.Expiry, fwdInfo.OutgoingCTLV)
failure := NewLinkError(
lnwire.NewFinalIncorrectCltvExpiry(add.Expiry),
)
l.sendHTLCError(add, sourceRef, failure, obfuscator, true)
return nil
}
// Notify the invoiceRegistry of the exit hop htlc. If we crash right
// after this, this code will be re-executed after restart. We will
// receive back a resolution event.
invoiceHash := lntypes.Hash(add.PaymentHash)
circuitKey := models.CircuitKey{
ChanID: l.ShortChanID(),
HtlcID: add.ID,
}
event, err := l.cfg.Registry.NotifyExitHopHtlc(
invoiceHash, add.Amount, add.Expiry, int32(heightNow),
circuitKey, l.hodlQueue.ChanIn(), add.CustomRecords, payload,
)
if err != nil {
return err
}
// Create a hodlHtlc struct and decide whether to resolve it now or later.
htlc := hodlHtlc{
add: add,
sourceRef: sourceRef,
obfuscator: obfuscator,
}
// If the event is nil, the invoice is being held, so we save the payment
// descriptor for future reference.
if event == nil {
l.hodlMap[circuitKey] = htlc
return nil
}
// Process the received resolution.
return l.processHtlcResolution(event, htlc)
}
// settleHTLC settles the HTLC on the channel.
func (l *channelLink) settleHTLC(preimage lntypes.Preimage,
htlcIndex uint64, sourceRef channeldb.AddRef) error {
hash := preimage.Hash()
l.log.Infof("settling htlc %v as exit hop", hash)
err := l.channel.SettleHTLC(
preimage, htlcIndex, &sourceRef, nil, nil,
)
if err != nil {
return fmt.Errorf("unable to settle htlc: %w", err)
}
// If the link is in hodl.BogusSettle mode, replace the preimage with a
// fake one before sending it to the peer.
if l.cfg.HodlMask.Active(hodl.BogusSettle) {
l.log.Warnf(hodl.BogusSettle.Warning())
preimage = [32]byte{}
copy(preimage[:], bytes.Repeat([]byte{2}, 32))
}
// The HTLC was successfully settled locally, so send a notification about
// it to the remote peer.
l.cfg.Peer.SendMessage(false, &lnwire.UpdateFulfillHTLC{
ChanID: l.ChanID(),
ID: htlcIndex,
PaymentPreimage: preimage,
})
// Once we have successfully settled the htlc, notify a settle event.
l.cfg.HtlcNotifier.NotifySettleEvent(
HtlcKey{
IncomingCircuit: models.CircuitKey{
ChanID: l.ShortChanID(),
HtlcID: htlcIndex,
},
},
preimage,
HtlcEventTypeReceive,
)
return nil
}
// forwardBatch forwards the given htlcPackets to the switch, and waits on the
// err chan for the individual responses. This method is intended to be spawned
// as a goroutine so the responses can be handled in the background.
func (l *channelLink) forwardBatch(replay bool, packets ...*htlcPacket) {
// Don't forward packets for which we already have a response in our
// mailbox. This could happen if a packet fails and is buffered in the
// mailbox, and the incoming link flaps.
var filteredPkts = make([]*htlcPacket, 0, len(packets))
for _, pkt := range packets {
if l.mailBox.HasPacket(pkt.inKey()) {
continue
}
filteredPkts = append(filteredPkts, pkt)
}
err := l.cfg.ForwardPackets(l.cg.Done(), replay, filteredPkts...)
if err != nil {
log.Errorf("Unhandled error while reforwarding htlc "+
"settle/fail over htlcswitch: %v", err)
}
}
// sendHTLCError cancels the HTLC and sends a cancel message back to the
// peer from which the HTLC was received.
func (l *channelLink) sendHTLCError(add lnwire.UpdateAddHTLC,
sourceRef channeldb.AddRef, failure *LinkError,
e hop.ErrorEncrypter, isReceive bool) {
reason, err := e.EncryptFirstHop(failure.WireMessage())
if err != nil {
l.log.Errorf("unable to obfuscate error: %v", err)
return
}
err = l.channel.FailHTLC(add.ID, reason, &sourceRef, nil, nil)
if err != nil {
l.log.Errorf("unable cancel htlc: %v", err)
return
}
// Send the appropriate failure message depending on whether we're
// in a blinded route or not.
if err := l.sendIncomingHTLCFailureMsg(
add.ID, e, reason,
); err != nil {
l.log.Errorf("unable to send HTLC failure: %v", err)
return
}
// Notify a link failure on our incoming link. Outgoing htlc information
// is not available at this point, because we have not decrypted the
// onion, so it is excluded.
var eventType HtlcEventType
if isReceive {
eventType = HtlcEventTypeReceive
} else {
eventType = HtlcEventTypeForward
}
l.cfg.HtlcNotifier.NotifyLinkFailEvent(
HtlcKey{
IncomingCircuit: models.CircuitKey{
ChanID: l.ShortChanID(),
HtlcID: add.ID,
},
},
HtlcInfo{
IncomingTimeLock: add.Expiry,
IncomingAmt: add.Amount,
},
eventType,
failure,
true,
)
}
// sendIncomingHTLCFailureMsg handles sending an HTLC failure message back to the
// peer from which the HTLC was received. This function is primarily used to
// handle the special requirements of route blinding, specifically:
// - Forwarding nodes must switch out any errors with MalformedFailHTLC
// - Introduction nodes should return regular HTLC failure messages.
//
// It accepts the original opaque failure, which will be used in the case
// that we're not part of a blinded route, and an error encrypter that'll be
// used if we are the introduction node and need to present an error as if
// we're the failing party.
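//
// A summary of the cases handled below:
//
//	nil encrypter (local send) -> UpdateFailHTLC with the original reason
//	cleartext (non-blinded)    -> UpdateFailHTLC with the original reason
//	introduction node          -> UpdateFailHTLC carrying invalid_blinding
//	relaying blinded node      -> UpdateFailMalformedHTLC (CodeInvalidBlinding)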
func (l *channelLink) sendIncomingHTLCFailureMsg(htlcIndex uint64,
e hop.ErrorEncrypter,
originalFailure lnwire.OpaqueReason) error {
var msg lnwire.Message
switch {
// Our circuit's error encrypter will be nil if this was a locally
// initiated payment. We can only hit a blinded error for a locally
// initiated payment if we allow ourselves to be picked as the
// introduction node for our own payments and in that case we
// shouldn't reach this code. To prevent the HTLC from getting
// stuck, we fail it back and log an error.
case e == nil:
msg = &lnwire.UpdateFailHTLC{
ChanID: l.ChanID(),
ID: htlcIndex,
Reason: originalFailure,
}
l.log.Errorf("Unexpected blinded failure when "+
"we are the sending node, incoming htlc: %v(%v)",
l.ShortChanID(), htlcIndex)
// For cleartext hops (i.e. non-blinded/normal) we don't need any
// transformation on the error message and can just send the original.
case !e.Type().IsBlinded():
msg = &lnwire.UpdateFailHTLC{
ChanID: l.ChanID(),
ID: htlcIndex,
Reason: originalFailure,
}
// When we're the introduction node, we need to convert the error to
// an UpdateFailHTLC.
case e.Type() == hop.EncrypterTypeIntroduction:
l.log.Debugf("Introduction blinded node switching out failure "+
"error: %v", htlcIndex)
// The specification does not require that we set the onion
// blob.
failureMsg := lnwire.NewInvalidBlinding(
fn.None[[lnwire.OnionPacketSize]byte](),
)
reason, err := e.EncryptFirstHop(failureMsg)
if err != nil {
return err
}
msg = &lnwire.UpdateFailHTLC{
ChanID: l.ChanID(),
ID: htlcIndex,
Reason: reason,
}
// If we are a relaying node, we need to switch out any error that
// we've received to a malformed HTLC error.
case e.Type() == hop.EncrypterTypeRelaying:
l.log.Debugf("Relaying blinded node switching out malformed "+
"error: %v", htlcIndex)
msg = &lnwire.UpdateFailMalformedHTLC{
ChanID: l.ChanID(),
ID: htlcIndex,
FailureCode: lnwire.CodeInvalidBlinding,
}
default:
return fmt.Errorf("unexpected encrypter: %d", e)
}
if err := l.cfg.Peer.SendMessage(false, msg); err != nil {
l.log.Warnf("Send update fail failed: %v", err)
}
return nil
}
// sendMalformedHTLCError is a helper function which sends a malformed HTLC
// update to the payment sender.
func (l *channelLink) sendMalformedHTLCError(htlcIndex uint64,
code lnwire.FailCode, onionBlob [lnwire.OnionPacketSize]byte,
sourceRef *channeldb.AddRef) {
shaOnionBlob := sha256.Sum256(onionBlob[:])
err := l.channel.MalformedFailHTLC(htlcIndex, code, shaOnionBlob, sourceRef)
if err != nil {
l.log.Errorf("unable cancel htlc: %v", err)
return
}
l.cfg.Peer.SendMessage(false, &lnwire.UpdateFailMalformedHTLC{
ChanID: l.ChanID(),
ID: htlcIndex,
ShaOnionBlob: shaOnionBlob,
FailureCode: code,
})
}
// failf is a function which is used to encapsulate the action necessary for
// properly failing the link. It takes a LinkFailureError, which will be passed
// to the OnChannelFailure closure, in order for it to determine if we should
// force close the channel, and if we should send an error message to the
// remote peer.
func (l *channelLink) failf(linkErr LinkFailureError, format string,
a ...interface{}) {
reason := fmt.Errorf(format, a...)
// Return if we have already notified about a failure.
if l.failed {
l.log.Warnf("ignoring link failure (%v), as link already "+
"failed", reason)
return
}
l.log.Errorf("failing link: %s with error: %v", reason, linkErr)
// Set failed, such that we won't process any more updates, and notify
// the peer about the failure.
l.failed = true
l.cfg.OnChannelFailure(l.ChanID(), l.ShortChanID(), linkErr)
}
// FundingCustomBlob returns the custom funding blob of the channel that this
// link is associated with. The funding blob represents static information about
// the channel that was created at channel funding time.
func (l *channelLink) FundingCustomBlob() fn.Option[tlv.Blob] {
if l.channel == nil {
return fn.None[tlv.Blob]()
}
if l.channel.State() == nil {
return fn.None[tlv.Blob]()
}
return l.channel.State().CustomBlob
}
// CommitmentCustomBlob returns the custom blob of the current local commitment
// of the channel that this link is associated with.
func (l *channelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] {
if l.channel == nil {
return fn.None[tlv.Blob]()
}
return l.channel.LocalCommitmentBlob()
}
package htlcswitch
import "github.com/go-errors/errors"
var (
// ErrLinkShuttingDown signals that the link is shutting down.
ErrLinkShuttingDown = errors.New("link shutting down")
// ErrLinkFailedShutdown signals that a requested shutdown failed.
ErrLinkFailedShutdown = errors.New("link failed to shutdown")
)
// errorCode encodes the possible types of errors that will make us fail the
// current link.
type errorCode uint8
const (
// ErrInternalError indicates that something internal in the link
// failed. In this case we will send a generic error to our peer.
ErrInternalError errorCode = iota
// ErrRemoteError indicates that our peer sent an error, prompting us
// to fail the link.
ErrRemoteError
// ErrRemoteUnresponsive indicates that our peer took too long to
// complete a commitment dance.
ErrRemoteUnresponsive
// ErrSyncError indicates that we failed synchronizing the state of the
// channel with our peer.
ErrSyncError
// ErrInvalidUpdate indicates that the peer sent us an invalid update.
ErrInvalidUpdate
// ErrInvalidCommitment indicates that the remote peer sent us an
// invalid commitment signature.
ErrInvalidCommitment
// ErrInvalidRevocation indicates that the remote peer sent us an
// invalid revocation message.
ErrInvalidRevocation
// ErrRecoveryError indicates that the channel was unable to be resumed;
// we need the remote party to force close the channel out on chain now
// as a result.
ErrRecoveryError
// ErrCircuitError indicates a duplicate keystone error was hit in the
// circuit map. This is non-fatal and will resolve itself (usually
// within several minutes).
ErrCircuitError
// ErrStfuViolation indicates that the quiescence protocol has been
// violated, either because Stfu has been sent/received at an invalid
// time, or that an update has been sent/received while the channel is
// quiesced.
ErrStfuViolation
)
// LinkFailureAction is an enum-like type that describes the action that should
// be taken in response to a link failure.
type LinkFailureAction uint8
const (
// LinkFailureForceNone indicates no action is to be taken.
LinkFailureForceNone LinkFailureAction = iota
// LinkFailureForceClose indicates that the channel should be force
// closed.
LinkFailureForceClose
// LinkFailureDisconnect indicates that we should disconnect in an
// attempt to recycle the connection. This can be useful if we think a
// TCP connection or state machine is stalled.
LinkFailureDisconnect
)
// LinkFailureError encapsulates an error that will make us fail the current
// link. It contains the necessary information needed to determine if we should
// force close the channel in the process, and if any error data should be sent
// to the peer.
type LinkFailureError struct {
// code is the type of error this LinkFailureError encapsulates.
code errorCode
// FailureAction describes what we should do to fail the channel.
FailureAction LinkFailureAction
// PermanentFailure indicates whether this failure is permanent, and
// that the channel should not be loaded again.
PermanentFailure bool
// Warning denotes if this is a non-terminal error that doesn't warrant
// failing the channel altogether.
Warning bool
// SendData is a byte slice that will be sent to the peer. If nil a
// generic error will be sent.
SendData []byte
}
// A compile time check to ensure LinkFailureError implements the error
// interface.
var _ error = (*LinkFailureError)(nil)
// Error returns a generic error for the LinkFailureError.
//
// NOTE: Part of the error interface.
func (e LinkFailureError) Error() string {
switch e.code {
case ErrInternalError:
return "internal error"
case ErrRemoteError:
return "remote error"
case ErrRemoteUnresponsive:
return "remote unresponsive"
case ErrSyncError:
return "sync error"
case ErrInvalidUpdate:
return "invalid update"
case ErrInvalidCommitment:
return "invalid commitment"
case ErrInvalidRevocation:
return "invalid revocation"
case ErrRecoveryError:
return "unable to resume channel, recovery required"
case ErrCircuitError:
return "non-fatal circuit map error"
case ErrStfuViolation:
return "quiescence protocol executed improperly"
default:
return "unknown error"
}
}
// ShouldSendToPeer indicates whether we should send an error to the peer if
// the link fails with this LinkFailureError.
func (e LinkFailureError) ShouldSendToPeer() bool {
switch e.code {
// Since sending an error can lead some nodes to force close the
// channel, create a whitelist of the failures we want to send so that
// newly added error codes aren't automatically sent to the remote peer.
case
ErrInternalError,
ErrRemoteError,
ErrSyncError,
ErrInvalidUpdate,
ErrInvalidCommitment,
ErrInvalidRevocation,
ErrRecoveryError:
return true
// In all other cases we will not attempt to send our peer an error.
default:
return false
}
}
package htlcswitch
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
)
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
// The default amount of logging is none.
func init() {
logger := build.NewSubLogger("HSWC", nil)
UseLogger(logger)
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
hop.UseLogger(logger)
}
package htlcswitch
import (
"bytes"
"container/list"
"errors"
"fmt"
"sync"
"time"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrMailBoxShuttingDown is returned when the mailbox is interrupted by
// a shutdown request.
ErrMailBoxShuttingDown = errors.New("mailbox is shutting down")
// ErrPacketAlreadyExists signals that an attempt to add a packet failed
// because it already exists in the mailbox.
ErrPacketAlreadyExists = errors.New("mailbox already has packet")
)
// MailBox is an interface which represents a concurrent-safe, in-order
// delivery queue for messages from the network and also from the main switch.
// It serves as a buffer between incoming messages, and messages to be
// handled by the link. Each of the mutating methods within this interface
// should be implemented in a non-blocking manner.
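//
// A hedged lifecycle sketch (hypothetical caller; cfg and msg are
// placeholders):
//
//	mb := newMemoryMailBox(cfg)
//	mb.Start()
//	defer mb.Stop()
//
//	_ = mb.AddMessage(msg)
//	nextMsg := <-mb.MessageOutBox()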
type MailBox interface {
// AddMessage appends a new message to the end of the message queue.
AddMessage(msg lnwire.Message) error
// AddPacket appends a new message to the end of the packet queue.
AddPacket(pkt *htlcPacket) error
// HasPacket queries the packets for a circuit key; this is used to drop
// packets bound for the switch that already have a queued response.
HasPacket(CircuitKey) bool
// AckPacket removes a packet from the mailbox's in-memory replay
// buffer. This will prevent a packet from being delivered after a link
// restarts if the switch has remained online. The returned boolean
// indicates whether or not a packet with the passed incoming circuit
// key was removed.
AckPacket(CircuitKey) bool
// FailAdd fails an UpdateAddHTLC that exists within the mailbox,
// removing it from the in-memory replay buffer. This will prevent the
// packet from being delivered after the link restarts if the switch has
// remained online. The generated LinkError will show an
// OutgoingFailureDownstreamHtlcAdd FailureDetail.
FailAdd(pkt *htlcPacket)
// MessageOutBox returns a channel that any new messages ready for
// delivery will be sent on.
MessageOutBox() chan lnwire.Message
// PacketOutBox returns a channel that any new packets ready for
// delivery will be sent on.
PacketOutBox() chan *htlcPacket
// ResetMessages clears any pending wire messages from the inbox.
ResetMessages() error
// ResetPackets resets the packet head to point at the first element in
// the list.
ResetPackets() error
// SetDustClosure takes in a closure that is used to evaluate whether
// mailbox HTLCs are dust.
SetDustClosure(isDust dustClosure)
// SetFeeRate sets the feerate to be used when evaluating dust.
SetFeeRate(feerate chainfee.SatPerKWeight)
// DustPackets returns the dust sum for Adds in the mailbox for the
// local and remote commitments.
DustPackets() (lnwire.MilliSatoshi, lnwire.MilliSatoshi)
// Start starts the mailbox and any goroutines it needs to operate
// properly.
Start()
// Stop signals the mailbox and its goroutines for a graceful shutdown.
Stop()
}
type mailBoxConfig struct {
// shortChanID is the short channel id of the channel this mailbox
// belongs to.
shortChanID lnwire.ShortChannelID
// forwardPackets sends a variadic number of htlcPackets to the switch to
// be routed. A quit channel should be provided so that the call can
// properly exit during shutdown.
forwardPackets func(<-chan struct{}, ...*htlcPacket) error
// clock is a time source for the mailbox.
clock clock.Clock
// expiry is the interval after which Adds will be cancelled if they
// have not yet been delivered. The computed deadline will expire
// this long after the Adds are added via AddPacket.
expiry time.Duration
// failMailboxUpdate is used to fail an expired HTLC and use the
// correct SCID if the underlying channel uses aliases.
failMailboxUpdate func(outScid,
mailboxScid lnwire.ShortChannelID) lnwire.FailureMessage
}
// memoryMailBox is an implementation of the MailBox interface backed by purely
// in-memory queues.
//
// TODO(morehouse): use typed lists instead of list.Lists to avoid type asserts.
type memoryMailBox struct {
started sync.Once
stopped sync.Once
cfg *mailBoxConfig
wireMessages *list.List
wireMtx sync.Mutex
wireCond *sync.Cond
messageOutbox chan lnwire.Message
msgReset chan chan struct{}
// repPkts is a queue for reply packets, e.g. Settles and Fails.
repPkts *list.List
repIndex map[CircuitKey]*list.Element
repHead *list.Element
// addPkts is a dedicated queue for Adds.
addPkts *list.List
addIndex map[CircuitKey]*list.Element
addHead *list.Element
pktMtx sync.Mutex
pktCond *sync.Cond
pktOutbox chan *htlcPacket
pktReset chan chan struct{}
wireShutdown chan struct{}
pktShutdown chan struct{}
quit chan struct{}
// feeRate is set when the link receives or sends out fee updates. It
// is refreshed when AttachMailBox is called in case a fee update did
// not get committed. In some cases it may be out of sync with the
// channel's feerate, but it should eventually get back in sync.
feeRate chainfee.SatPerKWeight
// isDust is set when AttachMailBox is called and serves to evaluate
// the outstanding dust in the memoryMailBox given the current set
// feeRate.
isDust dustClosure
}
// newMemoryMailBox creates a new instance of the memoryMailBox.
func newMemoryMailBox(cfg *mailBoxConfig) *memoryMailBox {
box := &memoryMailBox{
cfg: cfg,
wireMessages: list.New(),
repPkts: list.New(),
addPkts: list.New(),
messageOutbox: make(chan lnwire.Message),
pktOutbox: make(chan *htlcPacket),
msgReset: make(chan chan struct{}, 1),
pktReset: make(chan chan struct{}, 1),
repIndex: make(map[CircuitKey]*list.Element),
addIndex: make(map[CircuitKey]*list.Element),
wireShutdown: make(chan struct{}),
pktShutdown: make(chan struct{}),
quit: make(chan struct{}),
}
box.wireCond = sync.NewCond(&box.wireMtx)
box.pktCond = sync.NewCond(&box.pktMtx)
return box
}
// A compile time assertion to ensure that memoryMailBox meets the MailBox
// interface.
var _ MailBox = (*memoryMailBox)(nil)
// courierType is an enum that reflects the distinct types of messages a
// MailBox can handle. Each type will be placed in an isolated mail box and
// will have a dedicated goroutine for delivering the messages.
type courierType uint8
const (
// wireCourier is a type of courier that handles wire messages.
wireCourier courierType = iota
// pktCourier is a type of courier that handles htlc packets.
pktCourier
)
// Start starts the mailbox and any goroutines it needs to operate properly.
//
// NOTE: This method is part of the MailBox interface.
func (m *memoryMailBox) Start() {
m.started.Do(func() {
go m.wireMailCourier()
go m.pktMailCourier()
})
}
// ResetMessages blocks until all buffered wire messages are cleared.
func (m *memoryMailBox) ResetMessages() error {
msgDone := make(chan struct{})
select {
case m.msgReset <- msgDone:
return m.signalUntilReset(wireCourier, msgDone)
case <-m.quit:
return ErrMailBoxShuttingDown
}
}
// ResetPackets blocks until the head of packets buffer is reset, causing the
// packets to be redelivered in order.
func (m *memoryMailBox) ResetPackets() error {
pktDone := make(chan struct{})
select {
case m.pktReset <- pktDone:
return m.signalUntilReset(pktCourier, pktDone)
case <-m.quit:
return ErrMailBoxShuttingDown
}
}
// signalUntilReset strobes the condition variable for the specified inbox type
// until receiving a response that the mailbox has processed a reset.
func (m *memoryMailBox) signalUntilReset(cType courierType,
done chan struct{}) error {
for {
switch cType {
case wireCourier:
m.wireCond.Signal()
case pktCourier:
m.pktCond.Signal()
}
select {
case <-time.After(time.Millisecond):
continue
case <-done:
return nil
case <-m.quit:
return ErrMailBoxShuttingDown
}
}
}
// AckPacket removes the packet identified by its incoming circuit key from the
// queue of packets to be delivered. The returned boolean indicates whether or
// not a packet with the passed incoming circuit key was removed.
//
// NOTE: It is safe to call this method multiple times for the same circuit key.
func (m *memoryMailBox) AckPacket(inKey CircuitKey) bool {
m.pktCond.L.Lock()
defer m.pktCond.L.Unlock()
if entry, ok := m.repIndex[inKey]; ok {
// Check whether we are removing the head of the queue. If so,
// we must advance the head to the next packet before removing.
// It's possible that the courier has already advanced the
// repHead, so this check prevents the repHead from getting
// desynchronized.
if entry == m.repHead {
m.repHead = entry.Next()
}
m.repPkts.Remove(entry)
delete(m.repIndex, inKey)
return true
}
if entry, ok := m.addIndex[inKey]; ok {
// Check whether we are removing the head of the queue. If so,
// we must advance the head to the next add before removing.
// It's possible that the courier has already advanced the
// addHead, so this check prevents the addHead from getting
// desynchronized.
//
// NOTE: While this event is rare for Settles or Fails, it could
// be very common for Adds since the mailbox has the ability to
// cancel Adds before they are delivered. When that occurs, the
// head of addPkts has only been peeked and we expect to be
// removing the head of the queue.
if entry == m.addHead {
m.addHead = entry.Next()
}
m.addPkts.Remove(entry)
delete(m.addIndex, inKey)
return true
}
return false
}
// HasPacket queries the packets for a circuit key; this is used to drop packets
// bound for the switch that already have a queued response.
func (m *memoryMailBox) HasPacket(inKey CircuitKey) bool {
m.pktCond.L.Lock()
_, ok := m.repIndex[inKey]
m.pktCond.L.Unlock()
return ok
}
// Stop signals the mailbox and its goroutines for a graceful shutdown.
//
// NOTE: This method is part of the MailBox interface.
func (m *memoryMailBox) Stop() {
m.stopped.Do(func() {
close(m.quit)
m.signalUntilShutdown(wireCourier)
m.signalUntilShutdown(pktCourier)
})
}
// signalUntilShutdown strobes the condition variable of the passed courier
// type, blocking until the worker has exited.
func (m *memoryMailBox) signalUntilShutdown(cType courierType) {
var (
cond *sync.Cond
shutdown chan struct{}
)
switch cType {
case wireCourier:
cond = m.wireCond
shutdown = m.wireShutdown
case pktCourier:
cond = m.pktCond
shutdown = m.pktShutdown
}
for {
select {
case <-time.After(time.Millisecond):
cond.Signal()
case <-shutdown:
return
}
}
}
// pktWithExpiry wraps an incoming packet and records the time at which it
// should be canceled from the mailbox. This is used to detect if the packet
// gets stuck in the mailbox and to determine when to cancel it back.
type pktWithExpiry struct {
pkt *htlcPacket
expiry time.Time
}
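// deadline returns a channel that fires once the packet's expiry time has
// passed, as measured by the given clock.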
func (p *pktWithExpiry) deadline(clock clock.Clock) <-chan time.Time {
return clock.TickAfter(p.expiry.Sub(clock.Now()))
}
// wireMailCourier is a dedicated goroutine whose job is to reliably deliver
// wire messages.
func (m *memoryMailBox) wireMailCourier() {
defer close(m.wireShutdown)
for {
// First, we'll check our condition. If our mailbox is empty,
// then we'll wait until a new item is added.
m.wireCond.L.Lock()
for m.wireMessages.Front() == nil {
m.wireCond.Wait()
select {
case msgDone := <-m.msgReset:
m.wireMessages.Init()
close(msgDone)
case <-m.quit:
m.wireCond.L.Unlock()
return
default:
}
}
// Grab the datum off the front of the queue, removing it
// from the list in the process.
entry := m.wireMessages.Front()
//nolint:forcetypeassert
nextMsg := m.wireMessages.Remove(entry).(lnwire.Message)
// Now that we're done with the condition, we can unlock it to
// allow any callers to append to the end of our target queue.
m.wireCond.L.Unlock()
// With the next message obtained, we'll now select to attempt
// to deliver the message. If we receive a kill signal, then
// we'll bail out.
select {
case m.messageOutbox <- nextMsg:
case msgDone := <-m.msgReset:
m.wireCond.L.Lock()
m.wireMessages.Init()
m.wireCond.L.Unlock()
close(msgDone)
case <-m.quit:
return
}
}
}
// pktMailCourier is a dedicated goroutine whose job is to reliably deliver
// packet messages.
func (m *memoryMailBox) pktMailCourier() {
defer close(m.pktShutdown)
for {
// First, we'll check our condition. If our mailbox is empty,
// then we'll wait until a new item is added.
m.pktCond.L.Lock()
for m.repHead == nil && m.addHead == nil {
m.pktCond.Wait()
select {
// Resetting the packet queue means just moving our
// pointer to the front. This ensures that any un-ACK'd
// messages are re-delivered upon reconnect.
case pktDone := <-m.pktReset:
m.repHead = m.repPkts.Front()
m.addHead = m.addPkts.Front()
close(pktDone)
case <-m.quit:
m.pktCond.L.Unlock()
return
default:
}
}
var (
nextRep *htlcPacket
nextRepEl *list.Element
nextAdd *pktWithExpiry
nextAddEl *list.Element
)
// For packets, we actually never remove an item until it has
// been ACK'd by the link. This ensures that if a read packet
// doesn't make it into a commitment, then it'll be
// re-delivered once the link comes back online.
// Peek at the head of the Settle/Fails and Add queues. We peek
// at both even if there is a Settle/Fail present because we need
// to set a deadline for the next pending Add if it's present.
// Due to clock monotonicity, we know that the head of the Adds
// is the next to expire.
if m.repHead != nil {
//nolint:forcetypeassert
nextRep = m.repHead.Value.(*htlcPacket)
nextRepEl = m.repHead
}
if m.addHead != nil {
//nolint:forcetypeassert
nextAdd = m.addHead.Value.(*pktWithExpiry)
nextAddEl = m.addHead
}
// Now that we're done with the condition, we can unlock it to
// allow any callers to append to the end of our target queue.
m.pktCond.L.Unlock()
var (
pktOutbox chan *htlcPacket
addOutbox chan *htlcPacket
add *htlcPacket
deadline <-chan time.Time
)
// Prioritize delivery of Settle/Fail packets over Adds. This
// ensures that we actively clear the commitment of existing
// HTLCs before trying to add new ones. This can help to improve
// forwarding performance since the time to sign a commitment is
// linear in the number of HTLCs manifested on the commitments.
//
// NOTE: Both types are eventually delivered over the same
// channel, but we can control which is delivered by exclusively
// making one nil and the other non-nil. We know from our loop
// condition that at least one of nextRep and nextAdd is non-nil.
if nextRep != nil {
pktOutbox = m.pktOutbox
} else {
addOutbox = m.pktOutbox
}
// If we have a pending Add, we'll also construct the deadline
// so we can fail it back if we are unable to deliver any
// message in time. We also dereference the nextAdd's packet,
// since we will need access to it in the case we are delivering
// it and/or if the deadline expires.
//
// NOTE: It's possible after this point for add to be nil, but
// this can only occur when addOutbox is also nil, hence we
// won't accidentally deliver a nil packet.
if nextAdd != nil {
add = nextAdd.pkt
deadline = nextAdd.deadline(m.cfg.clock)
}
select {
case pktOutbox <- nextRep:
m.pktCond.L.Lock()
// Only advance the repHead if this Settle or Fail is
// still at the head of the queue.
if m.repHead != nil && m.repHead == nextRepEl {
m.repHead = m.repHead.Next()
}
m.pktCond.L.Unlock()
case addOutbox <- add:
m.pktCond.L.Lock()
// Only advance the addHead if this Add is still at the
// head of the queue.
if m.addHead != nil && m.addHead == nextAddEl {
m.addHead = m.addHead.Next()
}
m.pktCond.L.Unlock()
case <-deadline:
log.Debugf("Expiring add htlc with "+
"keystone=%v", add.keystone())
m.FailAdd(add)
case pktDone := <-m.pktReset:
m.pktCond.L.Lock()
m.repHead = m.repPkts.Front()
m.addHead = m.addPkts.Front()
m.pktCond.L.Unlock()
close(pktDone)
case <-m.quit:
return
}
}
}
// AddMessage appends a new message to the end of the message queue.
//
// NOTE: This method is safe for concurrent use and part of the MailBox
// interface.
func (m *memoryMailBox) AddMessage(msg lnwire.Message) error {
// First, we'll lock the condition, and add the message to the end of
// the wire message inbox.
m.wireCond.L.Lock()
m.wireMessages.PushBack(msg)
m.wireCond.L.Unlock()
// With the message added, we signal to the mailCourier that there are
// additional messages to deliver.
m.wireCond.Signal()
return nil
}
// AddPacket appends a new message to the end of the packet queue.
//
// NOTE: This method is safe for concurrent use and part of the MailBox
// interface.
func (m *memoryMailBox) AddPacket(pkt *htlcPacket) error {
m.pktCond.L.Lock()
switch htlc := pkt.htlc.(type) {
// Split off Settle/Fail packets into the repPkts queue.
case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC:
if _, ok := m.repIndex[pkt.inKey()]; ok {
m.pktCond.L.Unlock()
return ErrPacketAlreadyExists
}
entry := m.repPkts.PushBack(pkt)
m.repIndex[pkt.inKey()] = entry
if m.repHead == nil {
m.repHead = entry
}
// Split off Add packets into the addPkts queue.
case *lnwire.UpdateAddHTLC:
if _, ok := m.addIndex[pkt.inKey()]; ok {
m.pktCond.L.Unlock()
return ErrPacketAlreadyExists
}
entry := m.addPkts.PushBack(&pktWithExpiry{
pkt: pkt,
expiry: m.cfg.clock.Now().Add(m.cfg.expiry),
})
m.addIndex[pkt.inKey()] = entry
if m.addHead == nil {
m.addHead = entry
}
default:
m.pktCond.L.Unlock()
return fmt.Errorf("unknown htlc type: %T", htlc)
}
m.pktCond.L.Unlock()
// With the packet added, we signal to the mailCourier that there are
// additional packets to consume.
m.pktCond.Signal()
return nil
}
// SetFeeRate sets the memoryMailBox's feerate for use in DustPackets.
func (m *memoryMailBox) SetFeeRate(feeRate chainfee.SatPerKWeight) {
m.pktCond.L.Lock()
defer m.pktCond.L.Unlock()
m.feeRate = feeRate
}
// SetDustClosure sets the memoryMailBox's dustClosure for use in DustPackets.
func (m *memoryMailBox) SetDustClosure(isDust dustClosure) {
m.pktCond.L.Lock()
defer m.pktCond.L.Unlock()
m.isDust = isDust
}
// DustPackets returns the dust sum for add packets in the mailbox. The first
// return value is the local dust sum and the second is the remote dust sum.
// This will keep track of a given dust HTLC from the time it is added via
// AddPacket until it is removed via AckPacket.
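//
// A hedged sketch of a dustClosure (signature inferred from the calls below;
// the threshold is hypothetical):
//
//	isDust := func(fee chainfee.SatPerKWeight, incoming bool,
//		whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool {
//
//		return amt < 354
//	}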
func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi,
lnwire.MilliSatoshi) {
m.pktCond.L.Lock()
defer m.pktCond.L.Unlock()
var (
localDustSum lnwire.MilliSatoshi
remoteDustSum lnwire.MilliSatoshi
)
// Run through the map of HTLCs and determine the dust sum with calls
// to the memoryMailBox's isDust closure. Note that all mailbox packets
// are outgoing so the second argument to isDust will be false.
for _, e := range m.addIndex {
addPkt := e.Value.(*pktWithExpiry).pkt
// Evaluate whether this HTLC is dust on the local commitment.
if m.isDust(
m.feeRate, false, lntypes.Local,
addPkt.amount.ToSatoshis(),
) {
localDustSum += addPkt.amount
}
// Evaluate whether this HTLC is dust on the remote commitment.
if m.isDust(
m.feeRate, false, lntypes.Remote,
addPkt.amount.ToSatoshis(),
) {
remoteDustSum += addPkt.amount
}
}
return localDustSum, remoteDustSum
}
// FailAdd fails an UpdateAddHTLC that exists within the mailbox, removing it
// from the in-memory replay buffer. This will prevent the packet from being
// delivered after the link restarts if the switch has remained online. The
// generated LinkError will show an OutgoingFailureDownstreamHtlcAdd
// FailureDetail.
func (m *memoryMailBox) FailAdd(pkt *htlcPacket) {
// First, remove the packet from mailbox. If we didn't find the packet
// because it has already been acked, we'll exit early to avoid sending
// a duplicate fail message through the switch.
if !m.AckPacket(pkt.inKey()) {
return
}
var (
localFailure = false
reason lnwire.OpaqueReason
)
// Create a temporary channel failure which we will send back to our
// peer if this is a forward, or report to the user if the failed
// payment was locally initiated.
failure := m.cfg.failMailboxUpdate(
pkt.originalOutgoingChanID, m.cfg.shortChanID,
)
// If the payment was locally initiated (which is indicated by a nil
// obfuscator), we do not need to encrypt it back to the sender.
if pkt.obfuscator == nil {
var b bytes.Buffer
err := lnwire.EncodeFailure(&b, failure, 0)
if err != nil {
log.Errorf("Unable to encode failure: %v", err)
return
}
reason = lnwire.OpaqueReason(b.Bytes())
localFailure = true
} else {
// If the packet is part of a forward (identified by a non-nil
// obfuscator), we need to encrypt the error back to the source.
var err error
reason, err = pkt.obfuscator.EncryptFirstHop(failure)
if err != nil {
log.Errorf("Unable to obfuscate error: %v", err)
return
}
}
// Create a link error containing the temporary channel failure and a
// detail which indicates that we failed to add the htlc.
linkError := NewDetailedLinkError(
failure, OutgoingFailureDownstreamHtlcAdd,
)
failPkt := &htlcPacket{
incomingChanID: pkt.incomingChanID,
incomingHTLCID: pkt.incomingHTLCID,
circuit: pkt.circuit,
sourceRef: pkt.sourceRef,
hasSource: true,
localFailure: localFailure,
obfuscator: pkt.obfuscator,
linkFailure: linkError,
htlc: &lnwire.UpdateFailHTLC{
Reason: reason,
},
}
if err := m.cfg.forwardPackets(m.quit, failPkt); err != nil {
log.Errorf("Unhandled error while reforwarding packets "+
"settle/fail over htlcswitch: %v", err)
}
}
// MessageOutBox returns a channel that any new messages ready for delivery
// will be sent on.
//
// NOTE: This method is part of the MailBox interface.
func (m *memoryMailBox) MessageOutBox() chan lnwire.Message {
return m.messageOutbox
}
// PacketOutBox returns a channel that any new packets ready for delivery will
// be sent on.
//
// NOTE: This method is part of the MailBox interface.
func (m *memoryMailBox) PacketOutBox() chan *htlcPacket {
return m.pktOutbox
}
// mailOrchestrator is responsible for coordinating the creation and lifecycle
// of mailboxes used within the switch. It supports the ability to create
// mailboxes, reassign their short channel ids, deliver htlc packets, and
// queue packets for mailboxes that have not been created due to a link's late
// registration.
type mailOrchestrator struct {
mu sync.RWMutex
cfg *mailOrchConfig
// mailboxes caches exactly one mailbox for each known channel.
mailboxes map[lnwire.ChannelID]MailBox
// liveIndex maps a live short chan id to the primary mailbox key.
// An entry is only added to the liveIndex map under two conditions:
// 1. A link has a non-zero short channel id at the time of AddLink.
// 2. A link receives a non-zero short channel id via UpdateShortChanID.
liveIndex map[lnwire.ShortChannelID]lnwire.ChannelID
// TODO(conner): add another pair of indexes:
// chan_id -> short_chan_id
// short_chan_id -> mailbox
// so that Deliver can lookup mailbox directly once live,
// but still queryable by channel_id.
// unclaimedPackets maps a live short chan id to queue of packets if no
// mailbox has been created.
unclaimedPackets map[lnwire.ShortChannelID][]*htlcPacket
}
type mailOrchConfig struct {
// forwardPackets sends a variadic number of htlcPackets to the switch
// to be routed. A quit channel should be provided so that the call can
// properly exit during shutdown.
forwardPackets func(<-chan struct{}, ...*htlcPacket) error
// clock is a time source for the generated mailboxes.
clock clock.Clock
// expiry is the interval after which Adds will be cancelled if they
// have not yet been delivered. The computed deadline will expire this
// long after the Adds are added to a mailbox via AddPacket.
expiry time.Duration
// failMailboxUpdate is used to fail an expired HTLC and use the
// correct SCID if the underlying channel uses aliases.
failMailboxUpdate func(outScid,
mailboxScid lnwire.ShortChannelID) lnwire.FailureMessage
}
// newMailOrchestrator initializes a fresh mailOrchestrator.
func newMailOrchestrator(cfg *mailOrchConfig) *mailOrchestrator {
return &mailOrchestrator{
cfg: cfg,
mailboxes: make(map[lnwire.ChannelID]MailBox),
liveIndex: make(map[lnwire.ShortChannelID]lnwire.ChannelID),
unclaimedPackets: make(map[lnwire.ShortChannelID][]*htlcPacket),
}
}
// Stop instructs the orchestrator to stop all active mailboxes.
func (mo *mailOrchestrator) Stop() {
for _, mailbox := range mo.mailboxes {
mailbox.Stop()
}
}
// GetOrCreateMailBox returns an existing mailbox belonging to `chanID`, or
// creates and returns a new mailbox if none is found.
func (mo *mailOrchestrator) GetOrCreateMailBox(chanID lnwire.ChannelID,
shortChanID lnwire.ShortChannelID) MailBox {
// First, try to look up the mailbox directly using only the shared mutex.
mo.mu.RLock()
mailbox, ok := mo.mailboxes[chanID]
if ok {
mo.mu.RUnlock()
return mailbox
}
mo.mu.RUnlock()
// Otherwise, we will try again with the exclusive lock, creating a
// mailbox if one still has not been created.
mo.mu.Lock()
mailbox = mo.exclusiveGetOrCreateMailBox(chanID, shortChanID)
mo.mu.Unlock()
return mailbox
}
// exclusiveGetOrCreateMailBox checks for the existence of a mailbox for the
// given channel id. If none is found, a new one is created, started, and
// recorded.
//
// NOTE: This method MUST be invoked with the mailOrchestrator's exclusive lock.
func (mo *mailOrchestrator) exclusiveGetOrCreateMailBox(
chanID lnwire.ChannelID, shortChanID lnwire.ShortChannelID) MailBox {
mailbox, ok := mo.mailboxes[chanID]
if !ok {
mailbox = newMemoryMailBox(&mailBoxConfig{
shortChanID: shortChanID,
forwardPackets: mo.cfg.forwardPackets,
clock: mo.cfg.clock,
expiry: mo.cfg.expiry,
failMailboxUpdate: mo.cfg.failMailboxUpdate,
})
mailbox.Start()
mo.mailboxes[chanID] = mailbox
}
return mailbox
}
// BindLiveShortChanID registers that messages bound for a particular short
// channel id should be forwarded to the mailbox corresponding to the given
// channel id. This method also checks to see if there are any unclaimed
// packets for this short_chan_id. If any are found, they are delivered to the
// mailbox and removed (marked as claimed).
func (mo *mailOrchestrator) BindLiveShortChanID(mailbox MailBox,
cid lnwire.ChannelID, sid lnwire.ShortChannelID) {
mo.mu.Lock()
// Update the mapping from short channel id to mailbox's channel id.
mo.liveIndex[sid] = cid
// Retrieve any unclaimed packets destined for this mailbox.
pkts := mo.unclaimedPackets[sid]
delete(mo.unclaimedPackets, sid)
mo.mu.Unlock()
// Deliver the unclaimed packets.
for _, pkt := range pkts {
mailbox.AddPacket(pkt)
}
}
// Deliver looks up the target mailbox using the live index from
// short_chan_id to channel_id. If the mailbox is found, the message is
// delivered directly. Otherwise the packet is recorded as unclaimed, and
// will be delivered to the mailbox upon the subsequent call to
// BindLiveShortChanID.
func (mo *mailOrchestrator) Deliver(
sid lnwire.ShortChannelID, pkt *htlcPacket) error {
var (
mailbox MailBox
found bool
)
// First, try to find the channel id for the target short_chan_id. If
// the link is live, we will also look up the created mailbox.
mo.mu.RLock()
chanID, isLive := mo.liveIndex[sid]
if isLive {
mailbox, found = mo.mailboxes[chanID]
}
mo.mu.RUnlock()
// If the link is live and the target mailbox was found, deliver the
// packet immediately.
if isLive && found {
return mailbox.AddPacket(pkt)
}
// If we detected that the link has not been made live, we will acquire
// the exclusive lock preemptively in order to queue this packet in the
// list of unclaimed packets.
mo.mu.Lock()
// Double check to see if the mailbox has not been made live since the
// release of the shared lock.
//
// NOTE: Checking again with the exclusive lock held prevents a race
// condition where BindLiveShortChanID is interleaved between the
// release of the shared lock, and acquiring the exclusive lock. The
// result would be stuck packets, as they wouldn't be redelivered until
// the next call to BindLiveShortChanID, which is expected to occur
// infrequently.
chanID, isLive = mo.liveIndex[sid]
if isLive {
// Reaching this point indicates the mailbox is actually live.
// We'll try to load the mailbox using the fresh channel id.
//
// NOTE: This should never create a new mailbox, as the live
// index should only be set if the mailbox had been initialized
// beforehand. However, this does ensure that this case is
// handled properly in the event that it could happen.
mailbox = mo.exclusiveGetOrCreateMailBox(chanID, sid)
mo.mu.Unlock()
// Deliver the packet to the mailbox if it was found or created.
return mailbox.AddPacket(pkt)
}
// Finally, if the channel id is still not found in the live index,
// we'll add this to the list of unclaimed packets. These will be
// delivered upon the next call to BindLiveShortChanID.
mo.unclaimedPackets[sid] = append(mo.unclaimedPackets[sid], pkt)
mo.mu.Unlock()
return nil
}
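// The intended interaction, sketched for exposition with hypothetical values
// (cfg, pkt, and chanID are assumed to be in scope; this is not a test in
// this package): a packet delivered before the link goes live is queued as
// unclaimed, and is handed to the mailbox on the subsequent bind.
//
//	orch := newMailOrchestrator(cfg)
//	sid := lnwire.NewShortChanIDFromInt(1)
//
//	// No live binding for sid yet: the packet is queued as unclaimed.
//	_ = orch.Deliver(sid, pkt)
//
//	// Once the link registers, the queued packet is flushed to its
//	// mailbox.
//	mailbox := orch.GetOrCreateMailBox(chanID, sid)
//	orch.BindLiveShortChanID(mailbox, chanID, sid)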
package htlcswitch
import (
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
"net"
"path/filepath"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/go-errors/errors"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lntest/mock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/tlv"
)
func isAlias(scid lnwire.ShortChannelID) bool {
return scid.BlockHeight >= 16_000_000 && scid.BlockHeight < 16_250_000
}
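// For instance, an SCID whose BlockHeight falls inside the half-open range
// [16_000_000, 16_250_000) satisfies the predicate, while an ordinary
// confirmed SCID does not (illustrative values):
//
//	isAlias(lnwire.ShortChannelID{BlockHeight: 16_000_000}) // true
//	isAlias(lnwire.ShortChannelID{BlockHeight: 800_000})    // false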
type mockPreimageCache struct {
sync.Mutex
preimageMap map[lntypes.Hash]lntypes.Preimage
}
func newMockPreimageCache() *mockPreimageCache {
return &mockPreimageCache{
preimageMap: make(map[lntypes.Hash]lntypes.Preimage),
}
}
func (m *mockPreimageCache) LookupPreimage(
hash lntypes.Hash) (lntypes.Preimage, bool) {
m.Lock()
defer m.Unlock()
p, ok := m.preimageMap[hash]
return p, ok
}
func (m *mockPreimageCache) AddPreimages(preimages ...lntypes.Preimage) error {
m.Lock()
defer m.Unlock()
for _, preimage := range preimages {
m.preimageMap[preimage.Hash()] = preimage
}
return nil
}
func (m *mockPreimageCache) SubscribeUpdates(
chanID lnwire.ShortChannelID, htlc *channeldb.HTLC,
payload *hop.Payload,
nextHopOnionBlob []byte) (*contractcourt.WitnessSubscription, error) {
return nil, nil
}
// TODO(yy): replace it with chainfee.MockEstimator.
type mockFeeEstimator struct {
byteFeeIn chan chainfee.SatPerKWeight
relayFee chan chainfee.SatPerKWeight
quit chan struct{}
}
func newMockFeeEstimator() *mockFeeEstimator {
return &mockFeeEstimator{
byteFeeIn: make(chan chainfee.SatPerKWeight),
relayFee: make(chan chainfee.SatPerKWeight),
quit: make(chan struct{}),
}
}
func (m *mockFeeEstimator) EstimateFeePerKW(
numBlocks uint32) (chainfee.SatPerKWeight, error) {
select {
case feeRate := <-m.byteFeeIn:
return feeRate, nil
case <-m.quit:
return 0, fmt.Errorf("exiting")
}
}
func (m *mockFeeEstimator) RelayFeePerKW() chainfee.SatPerKWeight {
select {
case feeRate := <-m.relayFee:
return feeRate
case <-m.quit:
return 0
}
}
func (m *mockFeeEstimator) Start() error {
return nil
}
func (m *mockFeeEstimator) Stop() error {
close(m.quit)
return nil
}
var _ chainfee.Estimator = (*mockFeeEstimator)(nil)
type mockForwardingLog struct {
sync.Mutex
events map[time.Time]channeldb.ForwardingEvent
}
func (m *mockForwardingLog) AddForwardingEvents(events []channeldb.ForwardingEvent) error {
m.Lock()
defer m.Unlock()
for _, event := range events {
m.events[event.Timestamp] = event
}
return nil
}
type mockServer struct {
started int32 // To be used atomically.
shutdown int32 // To be used atomically.
wg sync.WaitGroup
quit chan struct{}
t testing.TB
name string
messages chan lnwire.Message
protocolTraceMtx sync.Mutex
protocolTrace []lnwire.Message
id [33]byte
htlcSwitch *Switch
registry *mockInvoiceRegistry
pCache *mockPreimageCache
interceptorFuncs []messageInterceptor
}
var _ lnpeer.Peer = (*mockServer)(nil)
func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) {
signAliasUpdate := func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature,
error) {
return testSig, nil
}
cfg := Config{
DB: db,
FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
FetchAllChannels: db.ChannelStateDB().FetchAllChannels,
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
SwitchPackager: channeldb.NewSwitchPackager(),
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
FetchLastChannelUpdate: func(scid lnwire.ShortChannelID) (
*lnwire.ChannelUpdate1, error) {
return &lnwire.ChannelUpdate1{
ShortChannelID: scid,
}, nil
},
Notifier: &mock.ChainNotifier{
SpendChan: make(chan *chainntnfs.SpendDetail),
EpochChan: make(chan *chainntnfs.BlockEpoch),
ConfChan: make(chan *chainntnfs.TxConfirmation),
},
FwdEventTicker: ticker.NewForce(
DefaultFwdEventInterval,
),
LogEventTicker: ticker.NewForce(DefaultLogInterval),
AckEventTicker: ticker.NewForce(DefaultAckInterval),
HtlcNotifier: &mockHTLCNotifier{},
Clock: clock.NewDefaultClock(),
MailboxDeliveryTimeout: time.Hour,
MaxFeeExposure: DefaultMaxFeeExposure,
SignAliasUpdate: signAliasUpdate,
IsAlias: isAlias,
}
return New(cfg, startingHeight)
}
func initSwitchWithTempDB(t testing.TB, startingHeight uint32) (*Switch,
error) {
tempPath := filepath.Join(t.TempDir(), "switchdb")
db := channeldb.OpenForTesting(t, tempPath)
s, err := initSwitchWithDB(startingHeight, db)
if err != nil {
return nil, err
}
return s, nil
}
func newMockServer(t testing.TB, name string, startingHeight uint32,
db *channeldb.DB, defaultDelta uint32) (*mockServer, error) {
var id [33]byte
h := sha256.Sum256([]byte(name))
copy(id[:], h[:])
pCache := newMockPreimageCache()
var (
htlcSwitch *Switch
err error
)
if db == nil {
htlcSwitch, err = initSwitchWithTempDB(t, startingHeight)
} else {
htlcSwitch, err = initSwitchWithDB(startingHeight, db)
}
if err != nil {
return nil, err
}
t.Cleanup(func() { _ = htlcSwitch.Stop() })
registry := newMockRegistry(t)
return &mockServer{
t: t,
id: id,
name: name,
messages: make(chan lnwire.Message, 3000),
quit: make(chan struct{}),
registry: registry,
htlcSwitch: htlcSwitch,
pCache: pCache,
interceptorFuncs: make([]messageInterceptor, 0),
}, nil
}
func (s *mockServer) Start() error {
if !atomic.CompareAndSwapInt32(&s.started, 0, 1) {
return errors.New("mock server already started")
}
if err := s.htlcSwitch.Start(); err != nil {
return err
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() {
s.htlcSwitch.Stop()
}()
for {
select {
case msg := <-s.messages:
s.protocolTraceMtx.Lock()
s.protocolTrace = append(s.protocolTrace, msg)
s.protocolTraceMtx.Unlock()
var shouldSkip bool
for _, interceptor := range s.interceptorFuncs {
skip, err := interceptor(msg)
if err != nil {
s.t.Fatalf("%v: error in the "+
"interceptor: %v", s.name, err)
return
}
shouldSkip = shouldSkip || skip
}
if shouldSkip {
continue
}
if err := s.readHandler(msg); err != nil {
s.t.Fatal(err)
return
}
case <-s.quit:
return
}
}
}()
return nil
}
func (s *mockServer) QuitSignal() <-chan struct{} {
return s.quit
}
// mockHopIterator represents the test version of the hop iterator which,
// instead of encrypting the path in the onion blob, just stores the path as a
// list of hops.
type mockHopIterator struct {
hops []*hop.Payload
}
func newMockHopIterator(hops ...*hop.Payload) hop.Iterator {
return &mockHopIterator{hops: hops}
}
func (r *mockHopIterator) HopPayload() (*hop.Payload, hop.RouteRole, error) {
h := r.hops[0]
r.hops = r.hops[1:]
return h, hop.RouteRoleCleartext, nil
}
func (r *mockHopIterator) ExtraOnionBlob() []byte {
return nil
}
func (r *mockHopIterator) ExtractErrorEncrypter(
extracter hop.ErrorEncrypterExtracter, _ bool) (hop.ErrorEncrypter,
lnwire.FailCode) {
return extracter(nil)
}
func (r *mockHopIterator) EncodeNextHop(w io.Writer) error {
var hopLength [4]byte
binary.BigEndian.PutUint32(hopLength[:], uint32(len(r.hops)))
if _, err := w.Write(hopLength[:]); err != nil {
return err
}
for _, hop := range r.hops {
fwdInfo := hop.ForwardingInfo()
if err := encodeFwdInfo(w, &fwdInfo); err != nil {
return err
}
}
return nil
}
func encodeFwdInfo(w io.Writer, f *hop.ForwardingInfo) error {
if err := binary.Write(w, binary.BigEndian, f.NextHop); err != nil {
return err
}
if err := binary.Write(w, binary.BigEndian, f.AmountToForward); err != nil {
return err
}
if err := binary.Write(w, binary.BigEndian, f.OutgoingCTLV); err != nil {
return err
}
return nil
}
var _ hop.Iterator = (*mockHopIterator)(nil)
// mockObfuscator is a mock implementation of the failure obfuscator which
// only encodes the failure and does not perform any onion obfuscation.
type mockObfuscator struct {
ogPacket *sphinx.OnionPacket
failure lnwire.FailureMessage
}
// NewMockObfuscator initializes a dummy mockObfuscator used for testing.
func NewMockObfuscator() hop.ErrorEncrypter {
return &mockObfuscator{}
}
func (o *mockObfuscator) OnionPacket() *sphinx.OnionPacket {
return o.ogPacket
}
func (o *mockObfuscator) Type() hop.EncrypterType {
return hop.EncrypterTypeMock
}
func (o *mockObfuscator) Encode(w io.Writer) error {
return nil
}
func (o *mockObfuscator) Decode(r io.Reader) error {
return nil
}
func (o *mockObfuscator) Reextract(
extracter hop.ErrorEncrypterExtracter) error {
return nil
}
var fakeHmac = []byte("hmachmachmachmachmachmachmachmac")
func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) (
lnwire.OpaqueReason, error) {
o.failure = failure
var b bytes.Buffer
b.Write(fakeHmac)
if err := lnwire.EncodeFailure(&b, failure, 0); err != nil {
return nil, err
}
return b.Bytes(), nil
}
func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire.OpaqueReason {
return reason
}
func (o *mockObfuscator) EncryptMalformedError(reason lnwire.OpaqueReason) lnwire.OpaqueReason {
var b bytes.Buffer
b.Write(fakeHmac)
b.Write(reason)
return b.Bytes()
}
// mockDeobfuscator is a mock implementation of the failure deobfuscator which
// only decodes the failure and does not perform any onion deobfuscation.
type mockDeobfuscator struct{}
func newMockDeobfuscator() ErrorDecrypter {
return &mockDeobfuscator{}
}
func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) (
*ForwardingError, error) {
if !bytes.Equal(reason[:32], fakeHmac) {
return nil, errors.New("fake decryption error")
}
reason = reason[32:]
r := bytes.NewReader(reason)
failure, err := lnwire.DecodeFailure(r, 0)
if err != nil {
return nil, err
}
return NewForwardingError(failure, 1), nil
}
var _ ErrorDecrypter = (*mockDeobfuscator)(nil)
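// Together the mock obfuscator and deobfuscator form a trivial error
// pipeline: EncryptFirstHop prepends the 32-byte fakeHmac to the encoded
// failure, and DecryptError strips and checks it before decoding. An
// illustrative sketch (not a test in this package):
//
//	enc := NewMockObfuscator()
//	reason, _ := enc.EncryptFirstHop(&lnwire.FailTemporaryNodeFailure{})
//
//	fwdErr, _ := newMockDeobfuscator().DecryptError(reason)
//	// fwdErr wraps the decoded lnwire failure message.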
// mockIteratorDecoder is the test version of the hop iterator decoder which
// decodes the encoded array of hops.
type mockIteratorDecoder struct {
mu sync.RWMutex
responses map[[32]byte][]hop.DecodeHopIteratorResponse
decodeFail bool
}
func newMockIteratorDecoder() *mockIteratorDecoder {
return &mockIteratorDecoder{
responses: make(map[[32]byte][]hop.DecodeHopIteratorResponse),
}
}
func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte,
cltv uint32) (hop.Iterator, lnwire.FailCode) {
var b [4]byte
_, err := r.Read(b[:])
if err != nil {
return nil, lnwire.CodeTemporaryChannelFailure
}
hopLength := binary.BigEndian.Uint32(b[:])
hops := make([]*hop.Payload, hopLength)
for i := uint32(0); i < hopLength; i++ {
var f hop.ForwardingInfo
if err := decodeFwdInfo(r, &f); err != nil {
return nil, lnwire.CodeTemporaryChannelFailure
}
var nextHopBytes [8]byte
binary.BigEndian.PutUint64(nextHopBytes[:], f.NextHop.ToUint64())
hops[i] = hop.NewLegacyPayload(&sphinx.HopData{
Realm: [1]byte{}, // hop.BitcoinNetwork
NextAddress: nextHopBytes,
ForwardAmount: uint64(f.AmountToForward),
OutgoingCltv: f.OutgoingCTLV,
})
}
return newMockHopIterator(hops...), lnwire.CodeNone
}
func (p *mockIteratorDecoder) DecodeHopIterators(id []byte,
reqs []hop.DecodeHopIteratorRequest) (
[]hop.DecodeHopIteratorResponse, error) {
idHash := sha256.Sum256(id)
p.mu.RLock()
if resps, ok := p.responses[idHash]; ok {
p.mu.RUnlock()
return resps, nil
}
p.mu.RUnlock()
batchSize := len(reqs)
resps := make([]hop.DecodeHopIteratorResponse, 0, batchSize)
for _, req := range reqs {
iterator, failcode := p.DecodeHopIterator(
req.OnionReader, req.RHash, req.IncomingCltv,
)
if p.decodeFail {
failcode = lnwire.CodeTemporaryChannelFailure
}
resp := hop.DecodeHopIteratorResponse{
HopIterator: iterator,
FailCode: failcode,
}
resps = append(resps, resp)
}
p.mu.Lock()
p.responses[idHash] = resps
p.mu.Unlock()
return resps, nil
}
func decodeFwdInfo(r io.Reader, f *hop.ForwardingInfo) error {
if err := binary.Read(r, binary.BigEndian, &f.NextHop); err != nil {
return err
}
if err := binary.Read(r, binary.BigEndian, &f.AmountToForward); err != nil {
return err
}
if err := binary.Read(r, binary.BigEndian, &f.OutgoingCTLV); err != nil {
return err
}
return nil
}
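// encodeFwdInfo and decodeFwdInfo are symmetric, so a round trip through a
// buffer recovers the original struct. An illustrative sketch with assumed
// values:
//
//	in := hop.ForwardingInfo{
//		NextHop:         lnwire.NewShortChanIDFromInt(7),
//		AmountToForward: 1000,
//		OutgoingCTLV:    144,
//	}
//
//	var buf bytes.Buffer
//	_ = encodeFwdInfo(&buf, &in)
//
//	var out hop.ForwardingInfo
//	_ = decodeFwdInfo(&buf, &out)
//	// out now mirrors in.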
// messageInterceptor is a function that handles incoming peer messages and
// may decide whether the peer should skip the message or not.
type messageInterceptor func(m lnwire.Message) (bool, error)
// intersect registers an interceptor function which will be triggered
// whenever a new lnwire message is received.
func (s *mockServer) intersect(f messageInterceptor) {
s.interceptorFuncs = append(s.interceptorFuncs, f)
}
func (s *mockServer) SendMessage(sync bool, msgs ...lnwire.Message) error {
for _, msg := range msgs {
select {
case s.messages <- msg:
case <-s.quit:
return errors.New("server is stopped")
}
}
return nil
}
func (s *mockServer) SendMessageLazy(sync bool, msgs ...lnwire.Message) error {
panic("not implemented")
}
func (s *mockServer) readHandler(message lnwire.Message) error {
var targetChan lnwire.ChannelID
switch msg := message.(type) {
case *lnwire.UpdateAddHTLC:
targetChan = msg.ChanID
case *lnwire.UpdateFulfillHTLC:
targetChan = msg.ChanID
case *lnwire.UpdateFailHTLC:
targetChan = msg.ChanID
case *lnwire.UpdateFailMalformedHTLC:
targetChan = msg.ChanID
case *lnwire.RevokeAndAck:
targetChan = msg.ChanID
case *lnwire.CommitSig:
targetChan = msg.ChanID
case *lnwire.ChannelReady:
// Ignore
return nil
case *lnwire.ChannelReestablish:
targetChan = msg.ChanID
case *lnwire.UpdateFee:
targetChan = msg.ChanID
case *lnwire.Stfu:
targetChan = msg.ChanID
default:
return fmt.Errorf("unknown message type: %T", msg)
}
// Dispatch the commitment update message to the proper channel link
// dedicated to this channel. If the link is not found, we will discard
// the message.
link, err := s.htlcSwitch.GetLink(targetChan)
if err != nil {
return nil
}
// Dispatch the update to the link. This could be run in its own
// goroutine in order to be able to properly stop the server if the
// handler gets stuck (server unavailable).
link.HandleChannelUpdate(message)
return nil
}
func (s *mockServer) PubKey() [33]byte {
return s.id
}
func (s *mockServer) IdentityKey() *btcec.PublicKey {
pubkey, _ := btcec.ParsePubKey(s.id[:])
return pubkey
}
func (s *mockServer) Address() net.Addr {
return nil
}
func (s *mockServer) AddNewChannel(channel *lnpeer.NewChannel,
cancel <-chan struct{}) error {
return nil
}
func (s *mockServer) AddPendingChannel(_ lnwire.ChannelID,
cancel <-chan struct{}) error {
return nil
}
func (s *mockServer) RemovePendingChannel(_ lnwire.ChannelID) error {
return nil
}
func (s *mockServer) WipeChannel(*wire.OutPoint) {}
func (s *mockServer) LocalFeatures() *lnwire.FeatureVector {
return nil
}
func (s *mockServer) RemoteFeatures() *lnwire.FeatureVector {
return nil
}
func (s *mockServer) Disconnect(err error) {}
func (s *mockServer) Stop() error {
if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) {
return nil
}
close(s.quit)
s.wg.Wait()
return nil
}
func (s *mockServer) String() string {
return s.name
}
type mockChannelLink struct {
htlcSwitch *Switch
shortChanID lnwire.ShortChannelID
// Only used for zero-conf channels.
realScid lnwire.ShortChannelID
aliases []lnwire.ShortChannelID
chanID lnwire.ChannelID
peer lnpeer.Peer
mailBox MailBox
packets chan *htlcPacket
eligible bool
unadvertised bool
zeroConf bool
optionFeature bool
htlcID uint64
checkHtlcTransitResult *LinkError
checkHtlcForwardResult *LinkError
failAliasUpdate func(sid lnwire.ShortChannelID,
incoming bool) *lnwire.ChannelUpdate1
confirmedZC bool
}
// completeCircuit is a helper method for adding the finalized payment circuit
// to the switch's circuit map. In testing, this should be executed after
// receiving an htlc from the downstream packets channel.
func (f *mockChannelLink) completeCircuit(pkt *htlcPacket) error {
switch htlc := pkt.htlc.(type) {
case *lnwire.UpdateAddHTLC:
pkt.outgoingChanID = f.shortChanID
pkt.outgoingHTLCID = f.htlcID
htlc.ID = f.htlcID
keystone := Keystone{pkt.inKey(), pkt.outKey()}
err := f.htlcSwitch.circuits.OpenCircuits(keystone)
if err != nil {
return err
}
f.htlcID++
case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC:
if pkt.circuit != nil {
err := f.htlcSwitch.teardownCircuit(pkt)
if err != nil {
return err
}
}
}
f.mailBox.AckPacket(pkt.inKey())
return nil
}
func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error {
return f.htlcSwitch.circuits.DeleteCircuits(pkt.inKey())
}
func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID,
shortChanID, realScid lnwire.ShortChannelID, peer lnpeer.Peer,
eligible, unadvertised, zeroConf, optionFeature bool,
) *mockChannelLink {
aliases := make([]lnwire.ShortChannelID, 0)
var realConfirmed bool
if zeroConf {
aliases = append(aliases, shortChanID)
}
if realScid != hop.Source {
realConfirmed = true
}
return &mockChannelLink{
htlcSwitch: htlcSwitch,
chanID: chanID,
shortChanID: shortChanID,
realScid: realScid,
peer: peer,
eligible: eligible,
unadvertised: unadvertised,
zeroConf: zeroConf,
optionFeature: optionFeature,
aliases: aliases,
confirmedZC: realConfirmed,
}
}
// addAlias is not part of any interface method.
func (f *mockChannelLink) addAlias(alias lnwire.ShortChannelID) {
f.aliases = append(f.aliases, alias)
}
func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error {
f.mailBox.AddPacket(pkt)
return nil
}
func (f *mockChannelLink) getDustSum(whoseCommit lntypes.ChannelParty,
dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi {
return 0
}
func (f *mockChannelLink) getFeeRate() chainfee.SatPerKWeight {
return 0
}
func (f *mockChannelLink) getDustClosure() dustClosure {
dustLimit := btcutil.Amount(400)
return dustHelper(
channeldb.SingleFunderTweaklessBit, dustLimit, dustLimit,
)
}
func (f *mockChannelLink) getCommitFee(remote bool) btcutil.Amount {
return 0
}
func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) {
}
func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) {
}
func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi,
lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32,
lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError {
return f.checkHtlcForwardResult
}
func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte,
amt lnwire.MilliSatoshi, timeout uint32,
heightNow uint32, _ lnwire.CustomRecords) *LinkError {
return f.checkHtlcTransitResult
}
func (f *mockChannelLink) Stats() (
uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) {
return 0, 0, 0
}
func (f *mockChannelLink) AttachMailBox(mailBox MailBox) {
f.mailBox = mailBox
f.packets = mailBox.PacketOutBox()
mailBox.SetDustClosure(f.getDustClosure())
}
func (f *mockChannelLink) attachFailAliasUpdate(closure func(
sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) {
f.failAliasUpdate = closure
}
func (f *mockChannelLink) getAliases() []lnwire.ShortChannelID {
return f.aliases
}
func (f *mockChannelLink) isZeroConf() bool {
return f.zeroConf
}
func (f *mockChannelLink) negotiatedAliasFeature() bool {
return f.optionFeature
}
func (f *mockChannelLink) confirmedScid() lnwire.ShortChannelID {
return f.realScid
}
func (f *mockChannelLink) zeroConfConfirmed() bool {
return f.confirmedZC
}
func (f *mockChannelLink) Start() error {
f.mailBox.ResetMessages()
f.mailBox.ResetPackets()
return nil
}
func (f *mockChannelLink) ChanID() lnwire.ChannelID {
return f.chanID
}
func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID {
return f.shortChanID
}
func (f *mockChannelLink) Bandwidth() lnwire.MilliSatoshi {
return 99999999
}
func (f *mockChannelLink) PeerPubKey() [33]byte {
return f.peer.PubKey()
}
func (f *mockChannelLink) ChannelPoint() wire.OutPoint {
return wire.OutPoint{}
}
func (f *mockChannelLink) Stop() {}
func (f *mockChannelLink) EligibleToForward() bool { return f.eligible }
func (f *mockChannelLink) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil }
func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid }
func (f *mockChannelLink) IsUnadvertised() bool { return f.unadvertised }
func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) {
f.eligible = true
return f.shortChanID, nil
}
func (f *mockChannelLink) EnableAdds(linkDirection LinkDirection) bool {
// TODO(proofofkeags): Implement
return true
}
func (f *mockChannelLink) DisableAdds(linkDirection LinkDirection) bool {
// TODO(proofofkeags): Implement
return true
}
func (f *mockChannelLink) IsFlushing(linkDirection LinkDirection) bool {
// TODO(proofofkeags): Implement
return false
}
func (f *mockChannelLink) OnFlushedOnce(func()) {
// TODO(proofofkeags): Implement
}
func (f *mockChannelLink) OnCommitOnce(LinkDirection, func()) {
// TODO(proofofkeags): Implement
}
func (f *mockChannelLink) InitStfu() <-chan fn.Result[lntypes.ChannelParty] {
// TODO(proofofkeags): Implement
c := make(chan fn.Result[lntypes.ChannelParty], 1)
c <- fn.Errf[lntypes.ChannelParty]("InitStfu not implemented")
return c
}
func (f *mockChannelLink) FundingCustomBlob() fn.Option[tlv.Blob] {
return fn.None[tlv.Blob]()
}
func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] {
return fn.None[tlv.Blob]()
}
// AuxBandwidth returns the bandwidth that can be used for a channel,
// expressed in milli-satoshi. This might be different from the regular
// BTC bandwidth for custom channels. This will always return fn.None()
// for a regular (non-custom) channel.
func (f *mockChannelLink) AuxBandwidth(lnwire.MilliSatoshi,
lnwire.ShortChannelID,
fn.Option[tlv.Blob], AuxTrafficShaper) fn.Result[OptionalBandwidth] {
return fn.Ok(OptionalBandwidth{})
}
var _ ChannelLink = (*mockChannelLink)(nil)
const testInvoiceCltvExpiry = 6
type mockInvoiceRegistry struct {
settleChan chan lntypes.Hash
registry *invoices.InvoiceRegistry
}
type mockChainNotifier struct {
chainntnfs.ChainNotifier
}
// RegisterBlockEpochNtfn mocks a successful call to register block
// notifications.
func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) (
*chainntnfs.BlockEpochEvent, error) {
return &chainntnfs.BlockEpochEvent{
Cancel: func() {},
}, nil
}
func newMockRegistry(t testing.TB) *mockInvoiceRegistry {
cdb := channeldb.OpenForTesting(t, t.TempDir())
modifierMock := &invoices.MockHtlcModifier{}
registry := invoices.NewRegistry(
cdb,
invoices.NewInvoiceExpiryWatcher(
clock.NewDefaultClock(), 0, 0, nil,
&mockChainNotifier{},
),
&invoices.RegistryConfig{
FinalCltvRejectDelta: 5,
HtlcInterceptor: modifierMock,
},
)
registry.Start()
return &mockInvoiceRegistry{
registry: registry,
}
}
func (i *mockInvoiceRegistry) LookupInvoice(ctx context.Context,
rHash lntypes.Hash) (invoices.Invoice, error) {
return i.registry.LookupInvoice(ctx, rHash)
}
func (i *mockInvoiceRegistry) SettleHodlInvoice(
ctx context.Context, preimage lntypes.Preimage) error {
return i.registry.SettleHodlInvoice(ctx, preimage)
}
func (i *mockInvoiceRegistry) NotifyExitHopHtlc(rhash lntypes.Hash,
amt lnwire.MilliSatoshi, expiry uint32, currentHeight int32,
circuitKey models.CircuitKey, hodlChan chan<- interface{},
wireCustomRecords lnwire.CustomRecords,
payload invoices.Payload) (invoices.HtlcResolution, error) {
event, err := i.registry.NotifyExitHopHtlc(
rhash, amt, expiry, currentHeight, circuitKey,
hodlChan, wireCustomRecords, payload,
)
if err != nil {
return nil, err
}
if i.settleChan != nil {
i.settleChan <- rhash
}
return event, nil
}
func (i *mockInvoiceRegistry) CancelInvoice(ctx context.Context,
payHash lntypes.Hash) error {
return i.registry.CancelInvoice(ctx, payHash)
}
func (i *mockInvoiceRegistry) AddInvoice(ctx context.Context,
invoice invoices.Invoice, paymentHash lntypes.Hash) error {
_, err := i.registry.AddInvoice(ctx, &invoice, paymentHash)
return err
}
func (i *mockInvoiceRegistry) HodlUnsubscribeAll(
subscriber chan<- interface{}) {
i.registry.HodlUnsubscribeAll(subscriber)
}
var _ InvoiceDatabase = (*mockInvoiceRegistry)(nil)
type mockCircuitMap struct {
lookup chan *PaymentCircuit
}
var _ CircuitMap = (*mockCircuitMap)(nil)
func (m *mockCircuitMap) OpenCircuits(...Keystone) error {
return nil
}
func (m *mockCircuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
start uint64) error {
return nil
}
func (m *mockCircuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
return nil
}
func (m *mockCircuitMap) CommitCircuits(
circuit ...*PaymentCircuit) (*CircuitFwdActions, error) {
return nil, nil
}
func (m *mockCircuitMap) CloseCircuit(outKey CircuitKey) (*PaymentCircuit,
error) {
return nil, nil
}
func (m *mockCircuitMap) FailCircuit(inKey CircuitKey) (*PaymentCircuit,
error) {
return nil, nil
}
func (m *mockCircuitMap) LookupCircuit(inKey CircuitKey) *PaymentCircuit {
return <-m.lookup
}
func (m *mockCircuitMap) LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit {
return nil
}
func (m *mockCircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit {
return nil
}
func (m *mockCircuitMap) NumPending() int {
return 0
}
func (m *mockCircuitMap) NumOpen() int {
return 0
}
type mockOnionErrorDecryptor struct {
sourceIdx int
message []byte
err error
}
func (m *mockOnionErrorDecryptor) DecryptError(encryptedData []byte) (
*sphinx.DecryptedError, error) {
return &sphinx.DecryptedError{
SenderIdx: m.sourceIdx,
Message: m.message,
}, m.err
}
var _ htlcNotifier = (*mockHTLCNotifier)(nil)
type mockHTLCNotifier struct {
htlcNotifier //nolint:unused
}
func (h *mockHTLCNotifier) NotifyForwardingEvent(key HtlcKey, info HtlcInfo,
eventType HtlcEventType) {
}
func (h *mockHTLCNotifier) NotifyLinkFailEvent(key HtlcKey, info HtlcInfo,
eventType HtlcEventType, linkErr *LinkError,
incoming bool) {
}
func (h *mockHTLCNotifier) NotifyForwardingFailEvent(key HtlcKey,
eventType HtlcEventType) {
}
func (h *mockHTLCNotifier) NotifySettleEvent(key HtlcKey,
preimage lntypes.Preimage, eventType HtlcEventType) {
}
func (h *mockHTLCNotifier) NotifyFinalHtlcEvent(key models.CircuitKey,
info channeldb.FinalHtlcInfo) {
}
package htlcswitch
import (
"fmt"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
)
// htlcPacket is a wrapper around an htlc lnwire update which adds additional
// information that is needed by this package.
type htlcPacket struct {
// incomingChanID is the ID of the channel that we have received an incoming
// HTLC on.
incomingChanID lnwire.ShortChannelID
// outgoingChanID is the ID of the channel that we have offered or will
// offer an outgoing HTLC on.
outgoingChanID lnwire.ShortChannelID
// incomingHTLCID is the ID of the HTLC that we have received from the peer
// on the incoming channel.
incomingHTLCID uint64
// outgoingHTLCID is the ID of the HTLC that we offered to the peer on the
// outgoing channel.
outgoingHTLCID uint64
// sourceRef is used by forwarded htlcPackets to locate incoming Add
// entry in a fwdpkg owned by the incoming link. This value can be nil
// if there is no such entry, e.g. switch initiated payments.
sourceRef *channeldb.AddRef
// destRef is used to locate a settle/fail entry in the outgoing link's
// fwdpkg. If sourceRef is non-nil, this reference should be to a
// settle/fail in response to the sourceRef.
destRef *channeldb.SettleFailRef
// incomingAmount is the value in milli-satoshis that arrived on an
// incoming link.
incomingAmount lnwire.MilliSatoshi
// amount is the value of the HTLC that is being created or modified.
amount lnwire.MilliSatoshi
// htlc is the lnwire message whose type depends on the switch request
// type.
htlc lnwire.Message
// obfuscator contains the necessary state to allow the switch to wrap
// any forwarded errors in an additional layer of encryption.
obfuscator hop.ErrorEncrypter
// localFailure is set to true if an HTLC fails for a local payment before
// the first hop. In this case, the failure reason is simply encoded, not
// encrypted with any shared secret.
localFailure bool
// linkFailure is non-nil for htlcs that fail at our node. This may
// occur for our own payments which fail on the outgoing link,
// or for forwards which fail in the switch or on the outgoing link.
linkFailure *LinkError
// convertedError is set to true if this is an HTLC fail that was
// created using an UpdateFailMalformedHTLC from the remote party. If
// this is true, then when forwarding this failure packet, we'll need
// to wrap it as if we were the first hop if it's a multi-hop HTLC. If
// it's a direct HTLC, then we'll decode the error as no encryption has
// taken place.
convertedError bool
// hasSource is set to true if the incomingChanID and incomingHTLCID
// fields of a forwarded fail packet are already set and do not need to
// be looked up in the circuit map.
hasSource bool
// isResolution is set to true if this packet was actually an incoming
// resolution message from an outside sub-system. We'll treat these as
// if they emanated directly from the switch. As a result, we'll
// encrypt all errors related to this packet as if we were the first
// hop.
isResolution bool
// circuit holds a reference to an Add's circuit which is persisted in
// the switch during successful forwarding.
circuit *PaymentCircuit
// incomingTimeout is the timeout that the incoming HTLC carried. This
// is the timeout of the HTLC applied to the incoming link.
incomingTimeout uint32
// outgoingTimeout is the timeout of the proposed outgoing HTLC. This
// will be extracted from the hop payload received by the incoming
// link.
outgoingTimeout uint32
// inOnionCustomRecords are user-defined records in the custom type
// range that were included in the onion payload.
inOnionCustomRecords record.CustomSet
// inWireCustomRecords are custom type range TLVs that are included
// in the incoming update_add_htlc wire message.
inWireCustomRecords lnwire.CustomRecords
// originalOutgoingChanID is used when sending back failure messages.
// It is only used for forwarded Adds on option_scid_alias channels.
// This is to avoid possible confusion if a payer uses the public SCID
// but receives a channel_update with the alias SCID. Instead, the
// payer should receive a channel_update with the public SCID.
originalOutgoingChanID lnwire.ShortChannelID
// inboundFee is the fee schedule of the incoming channel.
inboundFee models.InboundFee
}
// inKey returns the circuit key used to identify the incoming htlc.
func (p *htlcPacket) inKey() CircuitKey {
return CircuitKey{
ChanID: p.incomingChanID,
HtlcID: p.incomingHTLCID,
}
}
// outKey returns the circuit key used to identify the outgoing, forwarded htlc.
func (p *htlcPacket) outKey() CircuitKey {
return CircuitKey{
ChanID: p.outgoingChanID,
HtlcID: p.outgoingHTLCID,
}
}
// keystone returns a tuple containing the incoming and outgoing circuit keys.
func (p *htlcPacket) keystone() Keystone {
return Keystone{
InKey: p.inKey(),
OutKey: p.outKey(),
}
}
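// For a forwarded Add, the two keys identify the same HTLC on each side of
// the forward (hypothetical values, shown for exposition):
//
//	pkt := &htlcPacket{
//		incomingChanID: lnwire.NewShortChanIDFromInt(1),
//		incomingHTLCID: 42,
//		outgoingChanID: lnwire.NewShortChanIDFromInt(2),
//		outgoingHTLCID: 7,
//	}
//
//	pkt.inKey()    // CircuitKey{ChanID: 1, HtlcID: 42}
//	pkt.outKey()   // CircuitKey{ChanID: 2, HtlcID: 7}
//	pkt.keystone() // pairs both keys for the circuit map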
// String returns a human-readable description of the packet.
func (p *htlcPacket) String() string {
return fmt.Sprintf("keystone=%v, sourceRef=%v, destRef=%v, "+
"incomingAmount=%v, amount=%v, localFailure=%v, hasSource=%v "+
"isResolution=%v", p.keystone(), p.sourceRef, p.destRef,
p.incomingAmount, p.amount, p.localFailure, p.hasSource,
p.isResolution)
}
package htlcswitch
import (
"bytes"
"encoding/binary"
"errors"
"io"
"sync"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/multimutex"
)
var (
// networkResultStoreBucketKey is used for the root level bucket that
// stores the network result for each payment ID.
networkResultStoreBucketKey = []byte("network-result-store-bucket")
// ErrPaymentIDNotFound is an error returned if the given paymentID is
// not found.
ErrPaymentIDNotFound = errors.New("paymentID not found")
// ErrPaymentIDAlreadyExists is returned if we try to write a pending
// payment whose paymentID already exists.
ErrPaymentIDAlreadyExists = errors.New("paymentID already exists")
)
// PaymentResult wraps a decoded result received from the network after a
// payment attempt was made. This is what is eventually handed to the router
// for processing.
type PaymentResult struct {
// Preimage is set by the switch in case a sent HTLC was settled.
Preimage [32]byte
// Error is non-nil in case a HTLC send failed, and the HTLC is now
// irrevocably canceled. If the payment failed during forwarding, this
// error will be a *ForwardingError.
Error error
}
// networkResult is the raw result received from the network after a payment
// attempt has been made. Since the switch doesn't always have the necessary
// data to decode the raw message, we store it together with some metadata,
// and decode it when the router queries for the final result.
type networkResult struct {
// msg is the received result. This should be of type UpdateFulfillHTLC
// or UpdateFailHTLC.
msg lnwire.Message
// unencrypted indicates whether the failure encoded in the message is
// unencrypted, and hence doesn't need to be decrypted.
unencrypted bool
// isResolution indicates whether this is a resolution message, in
// which case the failure reason might not be included.
isResolution bool
}
// serializeNetworkResult serializes the networkResult.
func serializeNetworkResult(w io.Writer, n *networkResult) error {
return channeldb.WriteElements(w, n.msg, n.unencrypted, n.isResolution)
}
// deserializeNetworkResult deserializes the networkResult.
func deserializeNetworkResult(r io.Reader) (*networkResult, error) {
n := &networkResult{}
if err := channeldb.ReadElements(r,
&n.msg, &n.unencrypted, &n.isResolution,
); err != nil {
return nil, err
}
return n, nil
}
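// The two functions are symmetric, so a round trip through a buffer recovers
// the original result. An illustrative sketch with assumed values:
//
//	in := &networkResult{
//		msg:         &lnwire.UpdateFulfillHTLC{},
//		unencrypted: true,
//	}
//
//	var b bytes.Buffer
//	_ = serializeNetworkResult(&b, in)
//
//	out, _ := deserializeNetworkResult(&b)
//	// out mirrors in.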
// networkResultStore is a persistent store that stores any results of HTLCs
// in flight on the network. Since payment results are inherently
// asynchronous, it is used as a common access point for senders of HTLCs to
// know when a result is back. The Switch will checkpoint any received result
// to the store, and the store will keep results and notify the callers about
// them.
type networkResultStore struct {
backend kvdb.Backend
// results is a map from paymentIDs to channels where subscribers to
// payment results will be notified.
results map[uint64][]chan *networkResult
resultsMtx sync.Mutex
// attemptIDMtx is a multimutex used to make sure the database and
// result subscribers map is consistent for each attempt ID in case of
// concurrent callers.
attemptIDMtx *multimutex.Mutex[uint64]
}
func newNetworkResultStore(db kvdb.Backend) *networkResultStore {
return &networkResultStore{
backend: db,
results: make(map[uint64][]chan *networkResult),
attemptIDMtx: multimutex.NewMutex[uint64](),
}
}
// storeResult stores the networkResult for the given attemptID, and notifies
// any subscribers.
func (store *networkResultStore) storeResult(attemptID uint64,
result *networkResult) error {
// We get a mutex for this attempt ID. This is needed to ensure
// consistency between the database state and the subscribers in case
// of concurrent calls.
store.attemptIDMtx.Lock(attemptID)
defer store.attemptIDMtx.Unlock(attemptID)
log.Debugf("Storing result for attemptID=%v", attemptID)
// Serialize the payment result.
var b bytes.Buffer
if err := serializeNetworkResult(&b, result); err != nil {
return err
}
var attemptIDBytes [8]byte
binary.BigEndian.PutUint64(attemptIDBytes[:], attemptID)
err := kvdb.Batch(store.backend, func(tx kvdb.RwTx) error {
networkResults, err := tx.CreateTopLevelBucket(
networkResultStoreBucketKey,
)
if err != nil {
return err
}
return networkResults.Put(attemptIDBytes[:], b.Bytes())
})
if err != nil {
return err
}
// Now that the result is stored in the database, we can notify any
// active subscribers.
store.resultsMtx.Lock()
for _, res := range store.results[attemptID] {
res <- result
}
delete(store.results, attemptID)
store.resultsMtx.Unlock()
return nil
}
// subscribeResult is used to get the HTLC attempt result for the given attempt
// ID. It returns a channel on which the result will be delivered when ready.
func (store *networkResultStore) subscribeResult(attemptID uint64) (
<-chan *networkResult, error) {
// We get a mutex for this attempt ID. This is needed to ensure
// consistency between the database state and the subscribers in case
// of concurrent calls.
store.attemptIDMtx.Lock(attemptID)
defer store.attemptIDMtx.Unlock(attemptID)
log.Debugf("Subscribing to result for attemptID=%v", attemptID)
var (
result *networkResult
resultChan = make(chan *networkResult, 1)
)
err := kvdb.View(store.backend, func(tx kvdb.RTx) error {
var err error
result, err = fetchResult(tx, attemptID)
switch {
// Result not yet available, we will notify once a result is
// available.
case err == ErrPaymentIDNotFound:
return nil
case err != nil:
return err
// The result was found, and will be returned immediately.
default:
return nil
}
}, func() {
result = nil
})
if err != nil {
return nil, err
}
// If the result was found, we can send it on the result channel
// immediately.
if result != nil {
resultChan <- result
return resultChan, nil
}
// Otherwise we store the result channel for when the result is
// available.
store.resultsMtx.Lock()
store.results[attemptID] = append(
store.results[attemptID], resultChan,
)
store.resultsMtx.Unlock()
return resultChan, nil
}
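// A sender can subscribe before or after the result has been stored; in
// either order exactly one result is delivered on the returned channel. An
// illustrative sketch (store and attemptID are assumed to be in scope, and
// quit is an assumed shutdown channel):
//
//	resultChan, err := store.subscribeResult(attemptID)
//	if err != nil {
//		return err
//	}
//
//	select {
//	case result := <-resultChan:
//		// Hand the result to the router for processing.
//	case <-quit:
//	}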
// getResult attempts to immediately fetch the result for the given pid from
// the store. If no result is available, ErrPaymentIDNotFound is returned.
func (store *networkResultStore) getResult(pid uint64) (
*networkResult, error) {
var result *networkResult
err := kvdb.View(store.backend, func(tx kvdb.RTx) error {
var err error
result, err = fetchResult(tx, pid)
return err
}, func() {
result = nil
})
if err != nil {
return nil, err
}
return result, nil
}
func fetchResult(tx kvdb.RTx, pid uint64) (*networkResult, error) {
var attemptIDBytes [8]byte
binary.BigEndian.PutUint64(attemptIDBytes[:], pid)
networkResults := tx.ReadBucket(networkResultStoreBucketKey)
if networkResults == nil {
return nil, ErrPaymentIDNotFound
}
// Check whether a result is already available.
resultBytes := networkResults.Get(attemptIDBytes[:])
if resultBytes == nil {
return nil, ErrPaymentIDNotFound
}
// Decode the result we found.
r := bytes.NewReader(resultBytes)
return deserializeNetworkResult(r)
}
// cleanStore removes all entries from the store, except the payment IDs given.
// NOTE: Since every result not listed in the keep map will be deleted, care
// should be taken to ensure no new payment attempts are being made
// concurrently while this process is ongoing, as its result might end up being
// deleted.
func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error {
return kvdb.Update(store.backend, func(tx kvdb.RwTx) error {
networkResults, err := tx.CreateTopLevelBucket(
networkResultStoreBucketKey,
)
if err != nil {
return err
}
// Iterate through the bucket, deleting all items not in the
// keep map.
var toClean [][]byte
if err := networkResults.ForEach(func(k, _ []byte) error {
pid := binary.BigEndian.Uint64(k)
if _, ok := keep[pid]; ok {
return nil
}
toClean = append(toClean, k)
return nil
}); err != nil {
return err
}
for _, k := range toClean {
err := networkResults.Delete(k)
if err != nil {
return err
}
}
if len(toClean) > 0 {
log.Infof("Removed %d stale entries from network "+
"result store", len(toClean))
}
return nil
}, func() {})
}
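// Callers pass the set of attempt IDs that are still in flight; everything
// else is purged. An illustrative sketch with assumed IDs:
//
//	keep := map[uint64]struct{}{
//		1: {},
//		7: {},
//	}
//	if err := store.cleanStore(keep); err != nil {
//		// Handle the error.
//	}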
package htlcswitch
import (
"fmt"
"sync"
"time"
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrInvalidStfu indicates that the Stfu we have received is invalid.
// This can happen in instances where we have not sent Stfu but we have
// received one with the initiator field set to false.
ErrInvalidStfu = fmt.Errorf("stfu received is invalid")
// ErrStfuAlreadySent indicates that this channel has already sent an
// Stfu message for this negotiation.
ErrStfuAlreadySent = fmt.Errorf("stfu already sent")
// ErrStfuAlreadyRcvd indicates that this channel has already received
// an Stfu message for this negotiation.
ErrStfuAlreadyRcvd = fmt.Errorf("stfu already received")
// ErrNoQuiescenceInitiator indicates that the caller has requested the
// quiescence initiator for a channel that is not yet quiescent.
ErrNoQuiescenceInitiator = fmt.Errorf(
"indeterminate quiescence initiator: channel is not quiescent",
)
// ErrPendingRemoteUpdates indicates that we have received an Stfu while
// the remote party has issued updates that are not yet bilaterally
// committed.
ErrPendingRemoteUpdates = fmt.Errorf(
"stfu received with pending remote updates",
)
// ErrPendingLocalUpdates indicates that we are attempting to send an
// Stfu while we have issued updates that are not yet bilaterally
// committed.
ErrPendingLocalUpdates = fmt.Errorf(
"stfu send attempted with pending local updates",
)
// ErrQuiescenceTimeout indicates that the quiescer has been quiesced
// beyond the allotted time.
ErrQuiescenceTimeout = fmt.Errorf(
"quiescence timeout",
)
)
const defaultQuiescenceTimeout = 30 * time.Second
type StfuReq = fn.Req[fn.Unit, fn.Result[lntypes.ChannelParty]]
// Quiescer is the public interface of the quiescence mechanism. Callers of the
// quiescence API should not need any methods besides the ones detailed here.
type Quiescer interface {
// IsQuiescent returns true if the state machine has been driven all the
// way to completion. If this returns true, processes that depend on
// channel quiescence may proceed.
IsQuiescent() bool
// QuiescenceInitiator determines which ChannelParty is the initiator of
// quiescence for the purposes of downstream protocols. If the channel
// is not currently quiescent, this method will return
// ErrNoQuiescenceInitiator.
QuiescenceInitiator() fn.Result[lntypes.ChannelParty]
// InitStfu instructs the quiescer that we intend to begin a quiescence
// negotiation where we are the initiator. We don't send stfu yet
// because we need to wait for the link to give us a valid opportunity
// to do so.
InitStfu(req StfuReq)
// RecvStfu is called when we receive an Stfu message from the remote.
RecvStfu(stfu lnwire.Stfu) error
// CanRecvUpdates returns true if we haven't yet received an Stfu which
// would mark the end of the remote's ability to send updates.
CanRecvUpdates() bool
// CanSendUpdates returns true if we haven't yet sent an Stfu which
// would mark the end of our ability to send updates.
CanSendUpdates() bool
// SendOwedStfu sends Stfu if it owes one. It returns an error if the
// state machine is in an invalid state.
SendOwedStfu() error
// OnResume accepts a closure with no return value that will run when
// the quiescer is resumed.
OnResume(hook func())
// Resume runs all of the deferred actions that have accumulated while
// the channel has been quiescent and then resets the quiescer state to
// its initial state.
Resume()
}
// QuiescerCfg is a config structure used to initialize a quiescer, giving it
// the appropriate functionality to interact with the channel state that the
// quiescer must synchronize with.
type QuiescerCfg struct {
// chanID marks what channel we are managing the state machine for. This
// is important because the quiescer needs to know the ChannelID to
// construct the Stfu message.
chanID lnwire.ChannelID
// channelInitiator indicates which ChannelParty originally opened the
// channel. This is used to break ties when both sides of the channel
// send Stfu claiming to be the initiator.
channelInitiator lntypes.ChannelParty
// sendMsg is a function that can be used to send an Stfu message over
// the wire.
sendMsg func(lnwire.Stfu) error
// timeoutDuration is the duration that we will wait from the moment the
// channel is considered quiescent before we call the onTimeout function.
timeoutDuration time.Duration
// onTimeout is a function that will be called in the event that the
// Quiescer has not been resumed before the timeout is reached. If
// Quiescer.Resume is called before the timeout has been reached, then
// onTimeout will not be called until the quiescer reaches a quiescent
// state again.
onTimeout func()
}
// QuiescerLive is a state machine that tracks progression through the
// quiescence protocol.
type QuiescerLive struct {
cfg QuiescerCfg
// log is a quiescer-scoped logging instance.
log btclog.Logger
// localInit indicates whether our path through this state machine was
// initiated by our node. This can be true or false independently of
// remoteInit.
localInit bool
// remoteInit indicates whether we received Stfu from our peer where the
// message indicated that the remote node believes it was the initiator.
// This can be true or false independently of localInit.
remoteInit bool
// sent tracks whether or not we have emitted Stfu for sending.
sent bool
// received tracks whether or not we have received Stfu from our peer.
received bool
// activeQuiescenceReq is a possibly-None request that we should
// resolve when we complete quiescence.
activeQuiescenceReq fn.Option[StfuReq]
// resumeQueue is a slice of hooks that will be called when the quiescer
// is resumed. These are actions that needed to be deferred while the
// channel was quiescent.
resumeQueue []func()
// timeoutTimer is a field that is used to hold onto the timeout job
// when we reach quiescence.
timeoutTimer *time.Timer
sync.RWMutex
}
// NewQuiescer creates a new quiescer for the given channel.
func NewQuiescer(cfg QuiescerCfg) Quiescer {
logPrefix := fmt.Sprintf("Quiescer(%v):", cfg.chanID)
return &QuiescerLive{
cfg: cfg,
log: log.WithPrefix(logPrefix),
}
}
// RecvStfu is called when we receive an Stfu message from the remote.
func (q *QuiescerLive) RecvStfu(msg lnwire.Stfu) error {
q.Lock()
defer q.Unlock()
return q.recvStfu(msg)
}
// recvStfu is called when we receive an Stfu message from the remote.
func (q *QuiescerLive) recvStfu(msg lnwire.Stfu) error {
// At the time of this writing, this check that we have already received
// an Stfu is not strictly necessary, according to the specification.
// However, it is fishy if we do and it is unclear how we should handle
// such a case so we will err on the side of caution.
if q.received {
return fmt.Errorf("%w for channel %v", ErrStfuAlreadyRcvd,
q.cfg.chanID)
}
// We need to check that the Stfu we are receiving is valid.
if !q.sent && !msg.Initiator {
return fmt.Errorf("%w for channel %v", ErrInvalidStfu,
q.cfg.chanID)
}
if !q.canRecvStfu() {
return fmt.Errorf("%w for channel %v", ErrPendingRemoteUpdates,
q.cfg.chanID)
}
q.received = true
// If the remote party sets the initiator bit to true then we will
// remember that they are making a claim to the initiator role. This
// does not necessarily mean they will get it, though.
q.remoteInit = msg.Initiator
// Since we just received an Stfu, we may have a newly quiesced state.
// If so, we will try to resolve any outstanding StfuReqs.
q.tryResolveStfuReq()
if q.isQuiescent() {
q.startTimeout()
}
return nil
}
// MakeStfu is called when we are ready to send an Stfu message. It returns the
// Stfu message to be sent.
func (q *QuiescerLive) MakeStfu() fn.Result[lnwire.Stfu] {
q.RLock()
defer q.RUnlock()
return q.makeStfu()
}
// makeStfu is called when we are ready to send an Stfu message. It returns the
// Stfu message to be sent.
func (q *QuiescerLive) makeStfu() fn.Result[lnwire.Stfu] {
if q.sent {
return fn.Errf[lnwire.Stfu]("%w for channel %v",
ErrStfuAlreadySent, q.cfg.chanID)
}
if !q.canSendStfu() {
return fn.Errf[lnwire.Stfu]("%w for channel %v",
ErrPendingLocalUpdates, q.cfg.chanID)
}
stfu := lnwire.Stfu{
ChanID: q.cfg.chanID,
Initiator: q.localInit,
}
return fn.Ok(stfu)
}
// OweStfu returns true if we owe the other party an Stfu. We owe the remote an
// Stfu when we have received but not yet sent an Stfu, or we are the initiator
// but have not yet sent an Stfu.
func (q *QuiescerLive) OweStfu() bool {
q.RLock()
defer q.RUnlock()
return q.oweStfu()
}
// oweStfu returns true if we owe the other party an Stfu. We owe the remote an
// Stfu when we have received but not yet sent an Stfu, or we are the initiator
// but have not yet sent an Stfu.
func (q *QuiescerLive) oweStfu() bool {
return (q.received || q.localInit) && !q.sent
}
// NeedStfu returns true if the remote owes us an Stfu. They owe us an Stfu when
// we have sent but not yet received an Stfu.
func (q *QuiescerLive) NeedStfu() bool {
q.RLock()
defer q.RUnlock()
return q.needStfu()
}
// needStfu returns true if the remote owes us an Stfu. They owe us an Stfu when
// we have sent but not yet received an Stfu.
func (q *QuiescerLive) needStfu() bool {
return q.sent && !q.received
}
// IsQuiescent returns true if the state machine has been driven all the way to
// completion. If this returns true, processes that depend on channel quiescence
// may proceed.
func (q *QuiescerLive) IsQuiescent() bool {
q.RLock()
defer q.RUnlock()
return q.isQuiescent()
}
// isQuiescent returns true if the state machine has been driven all the way to
// completion. If this returns true, processes that depend on channel quiescence
// may proceed.
func (q *QuiescerLive) isQuiescent() bool {
return q.sent && q.received
}
// QuiescenceInitiator determines which ChannelParty is the initiator of
// quiescence for the purposes of downstream protocols. If the channel is not
// currently quiescent, this method will return ErrNoQuiescenceInitiator.
func (q *QuiescerLive) QuiescenceInitiator() fn.Result[lntypes.ChannelParty] {
q.RLock()
defer q.RUnlock()
return q.quiescenceInitiator()
}
// quiescenceInitiator determines which ChannelParty is the initiator of
// quiescence for the purposes of downstream protocols. If the channel is not
// currently quiescent, this method will return ErrNoQuiescenceInitiator.
func (q *QuiescerLive) quiescenceInitiator() fn.Result[lntypes.ChannelParty] {
switch {
case !q.isQuiescent():
return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator)
case q.localInit && q.remoteInit:
// In the case of a tie, the channel initiator wins.
return fn.Ok(q.cfg.channelInitiator)
case q.localInit:
return fn.Ok(lntypes.Local)
case q.remoteInit:
return fn.Ok(lntypes.Remote)
}
// unreachable
return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator)
}
// CanSendUpdates returns true if we haven't yet sent an Stfu and haven't
// initiated quiescence, either of which would mark the end of our ability
// to send updates.
func (q *QuiescerLive) CanSendUpdates() bool {
q.RLock()
defer q.RUnlock()
return q.canSendUpdates()
}
// canSendUpdates returns true if we haven't yet sent an Stfu and haven't
// initiated quiescence, either of which would mark the end of our ability
// to send updates.
func (q *QuiescerLive) canSendUpdates() bool {
return !q.sent && !q.localInit
}
// CanRecvUpdates returns true if we haven't yet received an Stfu which would
// mark the end of the remote's ability to send updates.
func (q *QuiescerLive) CanRecvUpdates() bool {
q.RLock()
defer q.RUnlock()
return q.canRecvUpdates()
}
// canRecvUpdates returns true if we haven't yet received an Stfu which would
// mark the end of the remote's ability to send updates.
func (q *QuiescerLive) canRecvUpdates() bool {
return !q.received
}
// CanSendStfu returns true if we can send an Stfu.
func (q *QuiescerLive) CanSendStfu() bool {
q.RLock()
defer q.RUnlock()
return q.canSendStfu()
}
// canSendStfu returns true if we can send an Stfu.
func (q *QuiescerLive) canSendStfu() bool {
return !q.sent
}
// CanRecvStfu returns true if we can receive an Stfu.
func (q *QuiescerLive) CanRecvStfu() bool {
q.RLock()
defer q.RUnlock()
return q.canRecvStfu()
}
// canRecvStfu returns true if we can receive an Stfu.
func (q *QuiescerLive) canRecvStfu() bool {
return !q.received
}
// SendOwedStfu sends Stfu if it owes one. It returns an error if the state
// machine is in an invalid state.
func (q *QuiescerLive) SendOwedStfu() error {
q.Lock()
defer q.Unlock()
return q.sendOwedStfu()
}
// sendOwedStfu sends Stfu if it owes one. It returns an error if the state
// machine is in an invalid state.
func (q *QuiescerLive) sendOwedStfu() error {
if !q.oweStfu() || !q.canSendStfu() {
return nil
}
err := q.makeStfu().Sink(q.cfg.sendMsg)
if err == nil {
q.sent = true
// Since we just sent an Stfu, we may have a newly quiesced
// state. If so, we will try to resolve any outstanding
// StfuReqs.
q.tryResolveStfuReq()
if q.isQuiescent() {
q.startTimeout()
}
}
return err
}
// TryResolveStfuReq attempts to resolve the active quiescence request if the
// state machine has reached a quiescent state.
func (q *QuiescerLive) TryResolveStfuReq() {
q.Lock()
defer q.Unlock()
q.tryResolveStfuReq()
}
// tryResolveStfuReq attempts to resolve the active quiescence request if the
// state machine has reached a quiescent state.
func (q *QuiescerLive) tryResolveStfuReq() {
q.activeQuiescenceReq.WhenSome(
func(req StfuReq) {
if q.isQuiescent() {
req.Resolve(q.quiescenceInitiator())
q.activeQuiescenceReq = fn.None[StfuReq]()
}
},
)
}
// InitStfu instructs the quiescer that we intend to begin a quiescence
// negotiation where we are the initiator. We don't send stfu yet because
// we need to wait for the link to give us a valid opportunity to do so.
func (q *QuiescerLive) InitStfu(req StfuReq) {
q.Lock()
defer q.Unlock()
q.initStfu(req)
}
// initStfu instructs the quiescer that we intend to begin a quiescence
// negotiation where we are the initiator. We don't send stfu yet because
// we need to wait for the link to give us a valid opportunity to do so.
func (q *QuiescerLive) initStfu(req StfuReq) {
if q.localInit {
req.Resolve(fn.Errf[lntypes.ChannelParty](
"quiescence already requested",
))
return
}
q.localInit = true
q.activeQuiescenceReq = fn.Some(req)
}
// OnResume accepts a no-return closure that will run when the quiescer is
// resumed.
func (q *QuiescerLive) OnResume(hook func()) {
q.Lock()
defer q.Unlock()
q.onResume(hook)
}
// onResume accepts a no-return closure that will run when the quiescer is
// resumed.
func (q *QuiescerLive) onResume(hook func()) {
q.resumeQueue = append(q.resumeQueue, hook)
}
// Resume runs all of the deferred actions that have accumulated while the
// channel has been quiescent and then resets the quiescer state to its initial
// state.
func (q *QuiescerLive) Resume() {
q.Lock()
defer q.Unlock()
q.resume()
}
// resume runs all of the deferred actions that have accumulated while the
// channel has been quiescent and then resets the quiescer state to its initial
// state.
func (q *QuiescerLive) resume() {
q.log.Debug("quiescence terminated, resuming htlc traffic")
// Since we are resuming, we want to cancel the quiescence timeout
// action.
q.cancelTimeout()
for _, hook := range q.resumeQueue {
hook()
}
q.localInit = false
q.remoteInit = false
q.sent = false
q.received = false
q.resumeQueue = nil
}
// startTimeout starts the timeout function that fires if the quiescer remains
// in a quiesced state for too long. If this function is called multiple times
// only the last one will have an effect.
func (q *QuiescerLive) startTimeout() {
if q.cfg.onTimeout == nil {
return
}
old := q.timeoutTimer
q.timeoutTimer = time.AfterFunc(q.cfg.timeoutDuration, q.cfg.onTimeout)
if old != nil {
old.Stop()
}
}
// cancelTimeout cancels the timeout function that would otherwise fire if the
// quiescer remains in a quiesced state too long. If this function is called
// before startTimeout or after another call to cancelTimeout, the effect will
// be a noop.
func (q *QuiescerLive) cancelTimeout() {
if q.timeoutTimer != nil {
q.timeoutTimer.Stop()
q.timeoutTimer = nil
}
}
type quiescerNoop struct{}
var _ Quiescer = (*quiescerNoop)(nil)
func (q *quiescerNoop) InitStfu(req StfuReq) {
req.Resolve(fn.Errf[lntypes.ChannelParty]("quiescence not supported"))
}
func (q *quiescerNoop) RecvStfu(_ lnwire.Stfu) error { return nil }
func (q *quiescerNoop) CanRecvUpdates() bool { return true }
func (q *quiescerNoop) CanSendUpdates() bool { return true }
func (q *quiescerNoop) SendOwedStfu() error { return nil }
func (q *quiescerNoop) IsQuiescent() bool { return false }
func (q *quiescerNoop) OnResume(hook func()) { hook() }
func (q *quiescerNoop) Resume() {}
func (q *quiescerNoop) QuiescenceInitiator() fn.Result[lntypes.ChannelParty] {
return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator)
}
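// To make the stfu handshake concrete, here is a hedged sketch of how a
// link driver might exercise the Quiescer interface. It is illustrative
// only: the real call sites live in the link's message loop, error handling
// is elided, and remoteStfu is a placeholder for a wire message received
// from the peer.
//
//	// The peer sent us an stfu; record it and validate our state.
//	if err := q.RecvStfu(remoteStfu); err != nil {
//		return err
//	}
//
//	// If we now owe the peer an stfu and have no pending local
//	// updates, reply with our own.
//	if err := q.SendOwedStfu(); err != nil {
//		return err
//	}
//
//	// Once both sides have exchanged stfu, downstream protocols may
//	// consult the negotiated initiator.
//	if q.IsQuiescent() {
//		initiator := q.QuiescenceInitiator()
//		_ = initiator
//	}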
package htlcswitch
import (
"bytes"
"io"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// resBucketKey is used for the root level bucket that stores the
// CircuitKey -> ResolutionMsg mapping.
resBucketKey = []byte("resolution-store-bucket-key")
// errResMsgNotFound is used to let callers know that the resolution
// message was not found for the given CircuitKey. This is used in the
// checkResolutionMsg function.
errResMsgNotFound = errors.New("resolution message not found")
)
// resolutionStore contains ResolutionMsgs received from the contractcourt. The
// Switch deletes these from the store when the underlying circuit has been
// removed via DeleteCircuits. If the circuit hasn't been deleted, the Switch
// will dispatch the ResolutionMsg to a link if this was a multi-hop HTLC or to
// itself if the Switch initiated the payment.
type resolutionStore struct {
backend kvdb.Backend
}
func newResolutionStore(db kvdb.Backend) *resolutionStore {
return &resolutionStore{
backend: db,
}
}
// addResolutionMsg persists a ResolutionMsg to the resolutionStore.
func (r *resolutionStore) addResolutionMsg(
resMsg *contractcourt.ResolutionMsg) error {
// The outKey will be the database key.
outKey := &CircuitKey{
ChanID: resMsg.SourceChan,
HtlcID: resMsg.HtlcIndex,
}
var resBuf bytes.Buffer
if err := serializeResolutionMsg(&resBuf, resMsg); err != nil {
return err
}
err := kvdb.Update(r.backend, func(tx kvdb.RwTx) error {
resBucket, err := tx.CreateTopLevelBucket(resBucketKey)
if err != nil {
return err
}
return resBucket.Put(outKey.Bytes(), resBuf.Bytes())
}, func() {})
if err != nil {
return err
}
return nil
}
// checkResolutionMsg returns nil if the resolution message is found in the
// store. It returns an error if no resolution message was found for the
// passed outKey or if a database error occurred.
func (r *resolutionStore) checkResolutionMsg(outKey *CircuitKey) error {
err := kvdb.View(r.backend, func(tx kvdb.RTx) error {
resBucket := tx.ReadBucket(resBucketKey)
if resBucket == nil {
// Return an error if the bucket doesn't exist.
return errResMsgNotFound
}
msg := resBucket.Get(outKey.Bytes())
if msg == nil {
// Return the not found error since no message exists
// for this CircuitKey.
return errResMsgNotFound
}
// Return nil to indicate that the message was found.
return nil
}, func() {})
if err != nil {
return err
}
return nil
}
// fetchAllResolutionMsg returns a slice of all stored ResolutionMsgs. This is
// used by the Switch on start-up.
func (r *resolutionStore) fetchAllResolutionMsg() (
[]*contractcourt.ResolutionMsg, error) {
var msgs []*contractcourt.ResolutionMsg
err := kvdb.View(r.backend, func(tx kvdb.RTx) error {
resBucket := tx.ReadBucket(resBucketKey)
if resBucket == nil {
return nil
}
return resBucket.ForEach(func(k, v []byte) error {
kr := bytes.NewReader(k)
outKey := &CircuitKey{}
if err := outKey.Decode(kr); err != nil {
return err
}
vr := bytes.NewReader(v)
resMsg, err := deserializeResolutionMsg(vr)
if err != nil {
return err
}
// Set the CircuitKey values on the ResolutionMsg.
resMsg.SourceChan = outKey.ChanID
resMsg.HtlcIndex = outKey.HtlcID
msgs = append(msgs, resMsg)
return nil
})
}, func() {
msgs = nil
})
if err != nil {
return nil, err
}
return msgs, nil
}
// deleteResolutionMsg removes a ResolutionMsg with the passed-in CircuitKey.
func (r *resolutionStore) deleteResolutionMsg(outKey *CircuitKey) error {
err := kvdb.Update(r.backend, func(tx kvdb.RwTx) error {
resBucket, err := tx.CreateTopLevelBucket(resBucketKey)
if err != nil {
return err
}
return resBucket.Delete(outKey.Bytes())
}, func() {})
return err
}
// serializeResolutionMsg writes part of a ResolutionMsg to the passed
// io.Writer.
func serializeResolutionMsg(w io.Writer,
resMsg *contractcourt.ResolutionMsg) error {
isFail := resMsg.Failure != nil
if err := channeldb.WriteElement(w, isFail); err != nil {
return err
}
// If this is a failure message, then we're done serializing.
if isFail {
return nil
}
// Else this is a settle message, and we need to write the preimage.
return channeldb.WriteElement(w, *resMsg.PreImage)
}
// deserializeResolutionMsg reads part of a ResolutionMsg from the passed
// io.Reader.
func deserializeResolutionMsg(r io.Reader) (*contractcourt.ResolutionMsg,
error) {
resMsg := &contractcourt.ResolutionMsg{}
var isFail bool
if err := channeldb.ReadElements(r, &isFail); err != nil {
return nil, err
}
// If a failure resolution msg was stored, set the Failure field.
if isFail {
failureMsg := &lnwire.FailPermanentChannelFailure{}
resMsg.Failure = failureMsg
return resMsg, nil
}
var preimage [32]byte
resMsg.PreImage = &preimage
// Else this is a settle resolution msg and we will read the preimage.
if err := channeldb.ReadElement(r, resMsg.PreImage); err != nil {
return nil, err
}
return resMsg, nil
}
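// For reference, the value stored under each CircuitKey is minimal: a
// boolean isFail flag, followed by the 32-byte preimage only when the
// message is a settle. A hedged round-trip sketch, where resMsg is a
// placeholder *contractcourt.ResolutionMsg:
//
//	var buf bytes.Buffer
//	if err := serializeResolutionMsg(&buf, resMsg); err != nil {
//		return err
//	}
//	restored, err := deserializeResolutionMsg(bytes.NewReader(buf.Bytes()))
//	if err != nil {
//		return err
//	}
//	// SourceChan and HtlcIndex are not part of the serialized value;
//	// fetchAllResolutionMsg recovers them from the CircuitKey used as
//	// the database key.
//	_ = restored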
package htlcswitch
import (
"sync"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
)
// defaultSequenceBatchSize specifies the window of sequence numbers that are
// allocated for each write to disk made by the sequencer.
const defaultSequenceBatchSize = 1000
// Sequencer emits sequence numbers for locally initiated HTLCs. These are
// only used internally for tracking pending payments, however they must be
// unique in order to avoid circuit key collision in the circuit map.
type Sequencer interface {
// NextID returns a unique sequence number for each invocation.
NextID() (uint64, error)
}
var (
// nextPaymentIDKey identifies the bucket that will keep track of the
// persistent sequence numbers for payments.
nextPaymentIDKey = []byte("next-payment-id-key")
// ErrSequencerCorrupted signals that the persistence engine was not
// initialized, or has been corrupted since startup.
ErrSequencerCorrupted = errors.New(
"sequencer database has been corrupted")
)
// persistentSequencer is a concrete implementation of Sequencer that uses
// channeldb to allocate sequence numbers.
type persistentSequencer struct {
db *channeldb.DB
mu sync.Mutex
nextID uint64
horizonID uint64
}
// NewPersistentSequencer initializes a new sequencer using a channeldb backend.
func NewPersistentSequencer(db *channeldb.DB) (Sequencer, error) {
g := &persistentSequencer{
db: db,
}
// Ensure the database bucket is created before any updates are
// performed.
if err := g.initDB(); err != nil {
return nil, err
}
return g, nil
}
// NextID returns a unique sequence number for every invocation, persisting the
// assignment to avoid reuse.
func (s *persistentSequencer) NextID() (uint64, error) {
// nextID will be the unique sequence number returned if no errors are
// encountered.
var nextID uint64
// If our sequence batch has not been exhausted, we can allocate the
// next identifier in the range.
s.mu.Lock()
defer s.mu.Unlock()
if s.nextID < s.horizonID {
nextID = s.nextID
s.nextID++
return nextID, nil
}
// Otherwise, our sequence batch has been exhausted. We use the last
// known sequence number on disk to mark the beginning of the next
// sequence batch, and allocate defaultSequenceBatchSize (1000) at a
// time.
//
// NOTE: This also will happen on the first invocation after startup,
// i.e. when nextID and horizonID are both 0. The next sequence batch to be
// allocated will start from the last known tip on disk, which is fine
// as we only require uniqueness of the allocated numbers.
var nextHorizonID uint64
if err := kvdb.Update(s.db, func(tx kvdb.RwTx) error {
nextIDBkt := tx.ReadWriteBucket(nextPaymentIDKey)
if nextIDBkt == nil {
return ErrSequencerCorrupted
}
nextID = nextIDBkt.Sequence()
nextHorizonID = nextID + defaultSequenceBatchSize
// Cannot fail when used in Update.
nextIDBkt.SetSequence(nextHorizonID)
return nil
}, func() {
nextHorizonID = 0
}); err != nil {
return 0, err
}
// Never assign index zero, to avoid collisions with the EmptyKeystone.
if nextID == 0 {
nextID++
}
// If our batch sequence allocation succeeded, update our in-memory
// values so we can continue to allocate sequence numbers without
// hitting disk. The nextID is incremented by one in memory so it can
// be issued directly on the next invocation.
s.nextID = nextID + 1
s.horizonID = nextHorizonID
return nextID, nil
}
// initDB populates the bucket used to generate payment sequence numbers.
func (s *persistentSequencer) initDB() error {
return kvdb.Update(s.db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(nextPaymentIDKey)
return err
}, func() {})
}
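// A hedged usage sketch: the switch hands out one attempt ID per locally
// initiated HTLC. IDs are served from an in-memory window, so only one
// disk write occurs per defaultSequenceBatchSize allocations. Here db is a
// placeholder *channeldb.DB:
//
//	seq, err := NewPersistentSequencer(db)
//	if err != nil {
//		return err
//	}
//	attemptID, err := seq.NextID()
//	if err != nil {
//		return err
//	}
//	// attemptID is unique across restarts. After a restart the sequence
//	// may skip up to a full batch of values, which is acceptable since
//	// only uniqueness matters, not contiguity.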
package htlcswitch
import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker"
)
const (
// DefaultFwdEventInterval is the duration between attempts to flush
// pending forwarding events to disk.
DefaultFwdEventInterval = 15 * time.Second
// DefaultLogInterval is the duration between attempts to log statistics
// about forwarding events.
DefaultLogInterval = 10 * time.Second
// DefaultAckInterval is the duration between attempts to ack any settle
// fails in a forwarding package.
DefaultAckInterval = 15 * time.Second
// DefaultMailboxDeliveryTimeout is the duration after which Adds will
// be cancelled if they could not get added to an outgoing commitment.
DefaultMailboxDeliveryTimeout = time.Minute
)
var (
// ErrChannelLinkNotFound is used when channel link hasn't been found.
ErrChannelLinkNotFound = errors.New("channel link not found")
// ErrDuplicateAdd signals that the ADD htlc was already forwarded
// through the switch and is locked into another commitment txn.
ErrDuplicateAdd = errors.New("duplicate add HTLC detected")
// ErrUnknownErrorDecryptor signals that we were unable to locate the
// error decryptor for this payment. This is likely due to restarting
// the daemon.
ErrUnknownErrorDecryptor = errors.New("unknown error decryptor")
// ErrSwitchExiting is signaled when the switch has received a shutdown
// request.
ErrSwitchExiting = errors.New("htlcswitch shutting down")
// ErrNoLinksFound is an error returned when we attempt to retrieve the
// active links in the switch for a specific destination.
ErrNoLinksFound = errors.New("no channel links found")
// ErrUnreadableFailureMessage is returned when the failure message
// cannot be decrypted.
ErrUnreadableFailureMessage = errors.New("unreadable failure message")
// ErrLocalAddFailed signals that the ADD htlc for a local payment
// failed to be processed.
ErrLocalAddFailed = errors.New("local add HTLC failed")
// errFeeExposureExceeded is only surfaced to callers of SendHTLC and
// signals that sending the HTLC would exceed the outgoing link's fee
// exposure threshold.
errFeeExposureExceeded = errors.New("fee exposure exceeded")
// DefaultMaxFeeExposure is the default threshold after which we'll
// fail payments if they increase our fee exposure. This is currently
// set to 500m msats.
DefaultMaxFeeExposure = lnwire.MilliSatoshi(500_000_000)
)
// plexPacket encapsulates a switch packet and adds an error channel to
// receive errors from the request handler.
type plexPacket struct {
pkt *htlcPacket
err chan error
}
// ChanClose represents a request to close a particular channel specified by
// its id.
type ChanClose struct {
// CloseType is a variable which signals the type of channel closure the
// peer should execute.
CloseType contractcourt.ChannelCloseType
// ChanPoint represents the id of the channel which should be closed.
ChanPoint *wire.OutPoint
// TargetFeePerKw is the ideal fee that was specified by the caller.
// This value is only utilized if the closure type is CloseRegular.
// This will be the starting offered fee when the fee negotiation
// process for the cooperative closure transaction kicks off.
TargetFeePerKw chainfee.SatPerKWeight
// MaxFee is the highest fee the caller is willing to pay.
//
// NOTE: This field is only respected if the caller is the initiator of
// the channel.
MaxFee chainfee.SatPerKWeight
// DeliveryScript is an optional delivery script to pay funds out to.
DeliveryScript lnwire.DeliveryAddress
// Updates is used by the request creator to receive notifications
// about the execution of the close channel request.
Updates chan interface{}
// Err is used by the request creator to receive request execution
// errors.
Err chan error
// Ctx is a context linked to the lifetime of the caller.
Ctx context.Context //nolint:containedctx
}
// Config defines the configuration for the service. ALL elements within the
// configuration MUST be non-nil for the service to carry out its duties.
type Config struct {
// FwdingLog is an interface that will be used by the switch to log
// forwarding events. A forwarding event happens each time a payment
// circuit is successfully completed, i.e. when we forward an HTLC and
// a settle is eventually received.
FwdingLog ForwardingLog
// LocalChannelClose kicks-off the workflow to execute a cooperative or
// forced unilateral closure of the channel initiated by a local
// subsystem.
LocalChannelClose func(pubKey []byte, request *ChanClose)
// DB is the database backend that will be used to back the switch's
// persistent circuit map.
DB kvdb.Backend
// FetchAllOpenChannels is a function that fetches all currently open
// channels from the channel database.
FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error)
// FetchAllChannels is a function that fetches all pending open, open,
// and waiting close channels from the database.
FetchAllChannels func() ([]*channeldb.OpenChannel, error)
// FetchClosedChannels is a function that fetches all closed channels
// from the channel database.
FetchClosedChannels func(
pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error)
// SwitchPackager provides access to the forwarding packages of all
// active channels. This gives the switch the ability to read arbitrary
// forwarding packages, and ack settles and fails contained within them.
SwitchPackager channeldb.FwdOperator
// ExtractErrorEncrypter is an interface allowing the switch to reextract
// error encrypters stored in the circuit map on restarts, since they
// are not stored directly within the database.
ExtractErrorEncrypter hop.ErrorEncrypterExtracter
// FetchLastChannelUpdate retrieves the latest routing policy for a
// target channel. This channel will typically be the outgoing channel
// specified when we receive an incoming HTLC. This will be used to
// provide payment senders our latest policy when sending encrypted
// error messages.
FetchLastChannelUpdate func(lnwire.ShortChannelID) (
*lnwire.ChannelUpdate1, error)
// Notifier is an instance of a chain notifier that we'll use to signal
// the switch when a new block has arrived.
Notifier chainntnfs.ChainNotifier
// HtlcNotifier is an instance of a htlcNotifier which we will pipe htlc
// events through.
HtlcNotifier htlcNotifier
// FwdEventTicker is a signal that instructs the htlcswitch to flush any
// pending forwarding events.
FwdEventTicker ticker.Ticker
// LogEventTicker is a signal instructing the htlcswitch to log
// aggregate stats about its forwarding during the last interval.
LogEventTicker ticker.Ticker
// AckEventTicker is a signal instructing the htlcswitch to ack any settle
// fails in forwarding packages.
AckEventTicker ticker.Ticker
// AllowCircularRoute is true if the user has configured their node to
// allow forwards that arrive and depart our node over the same channel.
AllowCircularRoute bool
// RejectHTLC is a flag that instructs the htlcswitch to reject any
// HTLCs that are not from the source hop.
RejectHTLC bool
// Clock is a time source for the switch.
Clock clock.Clock
// MailboxDeliveryTimeout is the interval after which Adds will be
// cancelled if they have not yet been delivered to a link. The
// computed deadline will expire this long after the Adds are added to
// a mailbox via AddPacket.
MailboxDeliveryTimeout time.Duration
// MaxFeeExposure is the threshold in milli-satoshis after which we'll
// fail incoming or outgoing payments for a particular channel.
MaxFeeExposure lnwire.MilliSatoshi
// SignAliasUpdate is used when sending FailureMessages backwards for
// option_scid_alias channels. This avoids a potential privacy leak by
// replacing the public, confirmed SCID with the alias in the
// ChannelUpdate.
SignAliasUpdate func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature,
error)
// IsAlias returns whether or not a given SCID is an alias.
IsAlias func(scid lnwire.ShortChannelID) bool
}
// Switch is the central messaging bus for all incoming/outgoing HTLCs.
// Connected peers with active channels are treated as named interfaces which
// refer to active channels as links. A link is the switch's message
// communication point with the goroutine that manages an active channel. New
// links are registered each time a channel is created, and unregistered once
// the channel is closed. The switch manages the hand-off process for multi-hop
// HTLCs, forwarding HTLCs initiated from within the daemon, and finally
// notifies local subsystems about the results of their outstanding payment
// requests.
type Switch struct {
started int32 // To be used atomically.
shutdown int32 // To be used atomically.
// bestHeight is the best known height of the main chain. The links will
// use this information to govern decisions based on HTLC timeouts.
// This will be retrieved by the registered links atomically.
bestHeight uint32
wg sync.WaitGroup
quit chan struct{}
// cfg is a copy of the configuration struct that the htlc switch
// service was initialized with.
cfg *Config
// networkResults stores the results of payments initiated by the user.
// The store is used to later look up the payments and notify the
// user of the result when they are complete. Each payment attempt
// should be given a unique integer ID when it is created, otherwise
// results might be overwritten.
networkResults *networkResultStore
// circuits is storage for payment circuits which are used to
// forward the settle/fail htlc updates back to the add htlc initiator.
circuits CircuitMap
// mailOrchestrator manages the lifecycle of mailboxes used throughout
// the switch, and facilitates delayed delivery of packets to links that
// later come online.
mailOrchestrator *mailOrchestrator
// indexMtx is a read/write mutex that protects the set of indexes
// below.
indexMtx sync.RWMutex
// pendingLinkIndex holds links that have not had their final, live
// short_chan_id assigned.
pendingLinkIndex map[lnwire.ChannelID]ChannelLink
// linkIndex is a map from channel id to the channel link which
// manages that channel.
linkIndex map[lnwire.ChannelID]ChannelLink
// forwardingIndex is an index which is consulted by the switch when it
// needs to locate the next hop to forward an incoming/outgoing HTLC
// update to/from.
//
// TODO(roasbeef): eventually add a NetworkHop mapping before the
// ChannelLink
forwardingIndex map[lnwire.ShortChannelID]ChannelLink
// interfaceIndex maps the compressed public key of a peer to all the
// channels that the switch maintains with that peer.
interfaceIndex map[[33]byte]map[lnwire.ChannelID]ChannelLink
// linkStopIndex stores the currently stopping ChannelLinks,
// represented by their ChannelID. The key is the link's ChannelID and
// the value is a chan that is closed when the link has fully stopped.
// This map is only added to if RemoveLink is called and is not added
// to when the Switch is shutting down and calls Stop() on each link.
//
// MUST be used with the indexMtx.
linkStopIndex map[lnwire.ChannelID]chan struct{}
// htlcPlex is the channel which all connected links use to coordinate
// the setup/teardown of Sphinx (onion routing) payment circuits.
// Active links forward any add/settle messages over this channel each
// state transition, sending new adds/settles which are fully locked
// in.
htlcPlex chan *plexPacket
// chanCloseRequests is used to transfer the channel close request to
// the channel close handler.
chanCloseRequests chan *ChanClose
// resolutionMsgs is the channel that all external contract resolution
// messages will be sent over.
resolutionMsgs chan *resolutionMsg
// pendingFwdingEvents is the set of forwarding events which have been
// collected during the current interval, but haven't yet been written
// to the forwarding log.
fwdEventMtx sync.Mutex
pendingFwdingEvents []channeldb.ForwardingEvent
// blockEpochStream is an active block epoch event stream backed by an
// active ChainNotifier instance. This will be used to retrieve the
// latest height of the chain.
blockEpochStream *chainntnfs.BlockEpochEvent
// pendingSettleFails is the set of settle/fail entries that we need to
// ack in the forwarding package of the outgoing link. This was added to
// make pipelining settles more efficient.
pendingSettleFails []channeldb.SettleFailRef
// resMsgStore is used to store the set of ResolutionMsg that come from
// contractcourt. This is used so the Switch can properly forward them,
// even on restarts.
resMsgStore *resolutionStore
// aliasToReal is a map used for option-scid-alias feature-bit links.
// The alias SCID is the key and the real, confirmed SCID is the value.
// If the channel is unconfirmed, there will not be a mapping for it.
// Since channels can have multiple aliases, this map is essentially a
// N->1 mapping for a channel. This MUST be accessed with the indexMtx.
aliasToReal map[lnwire.ShortChannelID]lnwire.ShortChannelID
// baseIndex is a map used for option-scid-alias feature-bit links.
// The value is the SCID of the link's ShortChannelID. This value may
// be an alias for zero-conf channels or a confirmed SCID for
// non-zero-conf channels with the option-scid-alias feature-bit. The
// key includes the value itself and also any other aliases. This MUST
// be accessed with the indexMtx.
baseIndex map[lnwire.ShortChannelID]lnwire.ShortChannelID
}
// New creates a new instance of the htlc switch.
func New(cfg Config, currentHeight uint32) (*Switch, error) {
resStore := newResolutionStore(cfg.DB)
circuitMap, err := NewCircuitMap(&CircuitMapConfig{
DB: cfg.DB,
FetchAllOpenChannels: cfg.FetchAllOpenChannels,
FetchClosedChannels: cfg.FetchClosedChannels,
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
CheckResolutionMsg: resStore.checkResolutionMsg,
})
if err != nil {
return nil, err
}
s := &Switch{
bestHeight: currentHeight,
cfg: &cfg,
circuits: circuitMap,
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
interfaceIndex: make(map[[33]byte]map[lnwire.ChannelID]ChannelLink),
pendingLinkIndex: make(map[lnwire.ChannelID]ChannelLink),
linkStopIndex: make(map[lnwire.ChannelID]chan struct{}),
networkResults: newNetworkResultStore(cfg.DB),
htlcPlex: make(chan *plexPacket),
chanCloseRequests: make(chan *ChanClose),
resolutionMsgs: make(chan *resolutionMsg),
resMsgStore: resStore,
quit: make(chan struct{}),
}
s.aliasToReal = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID)
s.baseIndex = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID)
s.mailOrchestrator = newMailOrchestrator(&mailOrchConfig{
forwardPackets: s.ForwardPackets,
clock: s.cfg.Clock,
expiry: s.cfg.MailboxDeliveryTimeout,
failMailboxUpdate: s.failMailboxUpdate,
})
return s, nil
}
// resolutionMsg is a struct that wraps an existing ResolutionMsg with a done
// channel. We'll use this channel to synchronize delivery of the message with
// the caller.
type resolutionMsg struct {
contractcourt.ResolutionMsg
errChan chan error
}
// ProcessContractResolution is called by active contract resolvers once a
// contract they are watching over has been fully resolved. The message carries
// an external signal that *would* have been sent if the outgoing channel
// didn't need to go to the chain in order to fulfill a contract. We'll process
// this message just as if it came from an active outgoing channel.
func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) error {
errChan := make(chan error, 1)
select {
case s.resolutionMsgs <- &resolutionMsg{
ResolutionMsg: msg,
errChan: errChan,
}:
case <-s.quit:
return ErrSwitchExiting
}
select {
case err := <-errChan:
return err
case <-s.quit:
return ErrSwitchExiting
}
}
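// A hedged sketch of the contractcourt side of this call. The fields shown
// are a subset and the identifiers are placeholders; a settle carries a
// preimage, while a fail carries a Failure message instead:
//
//	msg := contractcourt.ResolutionMsg{
//		SourceChan: sourceChanID,
//		HtlcIndex:  htlcIndex,
//		PreImage:   &preimage,
//	}
//	if err := s.ProcessContractResolution(msg); err != nil {
//		return err
//	}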
// HasAttemptResult reads the network result store to fetch the specified
// attempt. Returns true if the attempt result exists.
func (s *Switch) HasAttemptResult(attemptID uint64) (bool, error) {
_, err := s.networkResults.getResult(attemptID)
if err == nil {
return true, nil
}
if !errors.Is(err, ErrPaymentIDNotFound) {
return false, err
}
return false, nil
}
// GetAttemptResult returns the result of the HTLC attempt with the given
// attemptID. The paymentHash should be set to the payment's overall hash, or
// in case of AMP payments the payment's unique identifier.
//
// The method returns a channel where the HTLC attempt result will be sent when
// available, or an error is encountered during forwarding. When a result is
// received on the channel, the HTLC is guaranteed to no longer be in flight.
// The switch shutting down is signaled by closing the channel. If the
// attemptID is unknown, ErrPaymentIDNotFound will be returned.
func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash,
deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) {
var (
nChan <-chan *networkResult
err error
inKey = CircuitKey{
ChanID: hop.Source,
HtlcID: attemptID,
}
)
// If the HTLC is not found in the circuit map, check whether a result
// is already available.
// Assumption: no one will add this attempt ID other than the caller.
if s.circuits.LookupCircuit(inKey) == nil {
res, err := s.networkResults.getResult(attemptID)
if err != nil {
return nil, err
}
c := make(chan *networkResult, 1)
c <- res
nChan = c
} else {
// The HTLC was committed to the circuits, subscribe for a
// result.
nChan, err = s.networkResults.subscribeResult(attemptID)
if err != nil {
return nil, err
}
}
resultChan := make(chan *PaymentResult, 1)
// Since the attempt was known, we can start a goroutine that can
// extract the result when it is available, and pass it on to the
// caller.
s.wg.Add(1)
go func() {
defer s.wg.Done()
var n *networkResult
select {
case n = <-nChan:
case <-s.quit:
// We close the result channel to signal a shutdown. We
// don't send any result in this case since the HTLC is
// still in flight.
close(resultChan)
return
}
log.Debugf("Received network result %T for attemptID=%v", n.msg,
attemptID)
// Extract the result and pass it to the result channel.
result, err := s.extractResult(
deobfuscator, n, attemptID, paymentHash,
)
if err != nil {
e := fmt.Errorf("unable to extract result: %w", err)
log.Error(e)
resultChan <- &PaymentResult{
Error: e,
}
return
}
resultChan <- result
}()
return resultChan, nil
}
// CleanStore calls the underlying result store, telling it that it is safe
// to delete all entries except the ones in the keepPids map. This should be
// called periodically to let the switch clean up payment results that we have
// handled.
func (s *Switch) CleanStore(keepPids map[uint64]struct{}) error {
return s.networkResults.cleanStore(keepPids)
}
// SendHTLC is used by other subsystems which don't belong to the htlc switch
// package in order to send the htlc update. The attemptID used MUST be unique
// for this HTLC, and MUST be used only once, otherwise the switch might reject
// it.
func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, attemptID uint64,
htlc *lnwire.UpdateAddHTLC) error {
// Generate and send a new update packet. If an error is received at
// this stage, it means that the packet hasn't left the boundaries of
// our system and something went wrong.
packet := &htlcPacket{
incomingChanID: hop.Source,
incomingHTLCID: attemptID,
outgoingChanID: firstHop,
htlc: htlc,
amount: htlc.Amount,
}
// Attempt to fetch the target link before creating a circuit so that
// we don't leave dangling circuits. The getLocalLink method does not
// require the circuit variable to be set on the *htlcPacket.
link, linkErr := s.getLocalLink(packet, htlc)
if linkErr != nil {
// Notify the htlc notifier of a link failure on our outgoing
// link. Incoming timelock/amount values are not set because
// they are not present for local sends.
s.cfg.HtlcNotifier.NotifyLinkFailEvent(
newHtlcKey(packet),
HtlcInfo{
OutgoingTimeLock: htlc.Expiry,
OutgoingAmt: htlc.Amount,
},
HtlcEventTypeSend,
linkErr,
false,
)
return linkErr
}
// Evaluate whether this HTLC would exceed our fee exposure threshold.
// If it does, don't send it out and instead return an error.
if s.dustExceedsFeeThreshold(link, htlc.Amount, false) {
// Notify the htlc notifier of a link failure on our outgoing
// link. We use the FailTemporaryChannelFailure in place of a
// more descriptive error message.
linkErr := NewLinkError(
&lnwire.FailTemporaryChannelFailure{},
)
s.cfg.HtlcNotifier.NotifyLinkFailEvent(
newHtlcKey(packet),
HtlcInfo{
OutgoingTimeLock: htlc.Expiry,
OutgoingAmt: htlc.Amount,
},
HtlcEventTypeSend,
linkErr,
false,
)
return errFeeExposureExceeded
}
circuit := newPaymentCircuit(&htlc.PaymentHash, packet)
actions, err := s.circuits.CommitCircuits(circuit)
if err != nil {
log.Errorf("unable to commit circuit in switch: %v", err)
return err
}
// Drop duplicate packet if it has already been seen.
switch {
case len(actions.Drops) == 1:
return ErrDuplicateAdd
case len(actions.Fails) == 1:
return ErrLocalAddFailed
}
// Give the packet to the link's mailbox so that HTLCs are properly
// canceled back if the mailbox timeout elapses.
packet.circuit = circuit
return link.handleSwitchPacket(packet)
}
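// Callers typically pair SendHTLC with GetAttemptResult. A hedged sketch,
// with error handling elided and attemptID assumed to come from a Sequencer
// so that it is unique:
//
//	if err := s.SendHTLC(firstHop, attemptID, htlc); err != nil {
//		return err
//	}
//	resultChan, err := s.GetAttemptResult(attemptID, hash, deobfuscator)
//	if err != nil {
//		return err
//	}
//	result, ok := <-resultChan
//	if !ok {
//		// The switch shut down while the HTLC was still in flight.
//		return ErrSwitchExiting
//	}
//	// result.Preimage is set on settle; result.Error on failure.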
// UpdateForwardingPolicies sends a message to the switch to update the
// forwarding policies for the set of target channels, keyed in chanPolicies.
//
// NOTE: This function is synchronous and will block until either the
// forwarding policies for all links have been updated, or the switch shuts
// down.
func (s *Switch) UpdateForwardingPolicies(
chanPolicies map[wire.OutPoint]models.ForwardingPolicy) {
log.Tracef("Updating link policies: %v", lnutils.SpewLogClosure(
chanPolicies))
s.indexMtx.RLock()
// Update each link in chanPolicies.
for targetLink, policy := range chanPolicies {
cid := lnwire.NewChanIDFromOutPoint(targetLink)
link, ok := s.linkIndex[cid]
if !ok {
log.Debugf("Unable to find ChannelPoint(%v) to update "+
"link policy", targetLink)
continue
}
link.UpdateForwardingPolicy(policy)
}
s.indexMtx.RUnlock()
}
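// A hedged caller-side sketch, where chanPoint and newPolicy are
// placeholders for a channel outpoint and its updated policy:
//
//	s.UpdateForwardingPolicies(map[wire.OutPoint]models.ForwardingPolicy{
//		chanPoint: newPolicy,
//	})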
// IsForwardedHTLC checks for a given channel and htlc index if it is related
// to an opened circuit that represents a forwarded payment.
func (s *Switch) IsForwardedHTLC(chanID lnwire.ShortChannelID,
htlcIndex uint64) bool {
circuit := s.circuits.LookupOpenCircuit(models.CircuitKey{
ChanID: chanID,
HtlcID: htlcIndex,
})
return circuit != nil && circuit.Incoming.ChanID != hop.Source
}
// ForwardPackets adds a list of packets to the switch for processing. Fails
// and settles are added on a first pass, simultaneously constructing circuits
// for any adds. After persisting the circuits, another pass of the adds is
// given to forward them through the router. The sending link's quit channel is
// used to prevent deadlocks when the switch stops a link in the midst of
// forwarding.
func (s *Switch) ForwardPackets(linkQuit <-chan struct{},
packets ...*htlcPacket) error {
var (
// fwdChan is a buffered channel used to receive err msgs from
// the htlcPlex when forwarding this batch.
fwdChan = make(chan error, len(packets))
// numSent keeps a running count of how many packets are
// forwarded to the switch, which determines how many responses
// we will wait for on the fwdChan.
numSent int
)
// No packets, nothing to do.
if len(packets) == 0 {
return nil
}
// Setup a barrier to prevent the background tasks from processing
// responses until this function returns to the user.
var wg sync.WaitGroup
wg.Add(1)
defer wg.Done()
// Before spawning the following goroutine to proxy our error responses,
// check to see if we have already been issued a shutdown request. If
// so, we exit early to avoid incrementing the switch's waitgroup while
// it is already in the process of shutting down.
select {
case <-linkQuit:
return nil
case <-s.quit:
return nil
default:
// Spawn a goroutine to log the errors returned from failed packets.
s.wg.Add(1)
go s.logFwdErrs(&numSent, &wg, fwdChan)
}
// Make a first pass over the packets, forwarding any settles or fails.
// As adds are found, we create a circuit and append it to our set of
// circuits to be written to disk.
var circuits []*PaymentCircuit
var addBatch []*htlcPacket
for _, packet := range packets {
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
circuit := newPaymentCircuit(&htlc.PaymentHash, packet)
packet.circuit = circuit
circuits = append(circuits, circuit)
addBatch = append(addBatch, packet)
default:
err := s.routeAsync(packet, fwdChan, linkQuit)
if err != nil {
return fmt.Errorf("failed to forward packet %w",
err)
}
numSent++
}
}
// If this batch did not contain any circuits to commit, we can return
// early.
if len(circuits) == 0 {
return nil
}
// Write any circuits that we found to disk.
actions, err := s.circuits.CommitCircuits(circuits...)
if err != nil {
log.Errorf("unable to commit circuits in switch: %v", err)
}
// Split the htlc packets by comparing an in-order seek to the head of
// the added, dropped, or failed circuits.
//
// NOTE: This assumes each list is guaranteed to be a subsequence of the
// circuits, and that the union of the sets results in the original set
// of circuits.
var addedPackets, failedPackets []*htlcPacket
for _, packet := range addBatch {
switch {
case len(actions.Adds) > 0 && packet.circuit == actions.Adds[0]:
addedPackets = append(addedPackets, packet)
actions.Adds = actions.Adds[1:]
case len(actions.Drops) > 0 && packet.circuit == actions.Drops[0]:
actions.Drops = actions.Drops[1:]
case len(actions.Fails) > 0 && packet.circuit == actions.Fails[0]:
failedPackets = append(failedPackets, packet)
actions.Fails = actions.Fails[1:]
}
}
// Now, forward any packets for circuits that were successfully added to
// the switch's circuit map.
for _, packet := range addedPackets {
err := s.routeAsync(packet, fwdChan, linkQuit)
if err != nil {
return fmt.Errorf("failed to forward packet %w", err)
}
numSent++
}
// Lastly, for any packets that failed, this implies that they were
// left in a half added state, which can happen when recovering from
// failures.
if len(failedPackets) > 0 {
var failure lnwire.FailureMessage
incomingID := failedPackets[0].incomingChanID
// If the incoming channel is an option_scid_alias channel,
// then we'll need to replace the SCID in the ChannelUpdate.
update := s.failAliasUpdate(incomingID, true)
if update == nil {
// Fallback to the original non-alias behavior.
var err error
update, err = s.cfg.FetchLastChannelUpdate(incomingID)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(
update,
)
}
} else {
// This is an option_scid_alias channel.
failure = lnwire.NewTemporaryChannelFailure(update)
}
linkError := NewDetailedLinkError(
failure, OutgoingFailureIncompleteForward,
)
for _, packet := range failedPackets {
// We don't handle the error here since this method
// always returns an error.
_ = s.failAddPacket(packet, linkError)
}
}
return nil
}
// logFwdErrs logs any errors received on `fwdChan`.
func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) {
defer s.wg.Done()
// Wait here until the outer function has finished persisting
// and routing the packets. This guarantees we don't read from num until
// the value is accurate.
wg.Wait()
numSent := *num
for i := 0; i < numSent; i++ {
select {
case err := <-fwdChan:
if err != nil {
log.Errorf("Unhandled error while reforwarding htlc "+
"settle/fail over htlcswitch: %v", err)
}
case <-s.quit:
log.Errorf("unable to forward htlc packet " +
"htlc switch was stopped")
return
}
}
}
// routeAsync sends a packet through the htlc switch, using the provided err
// chan to propagate errors back to the caller. The link's quit channel is
// provided so that the send can be canceled if either the link or the switch
// receives a shutdown request. This method does not wait for a response from
// the htlcForwarder before returning.
func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error,
linkQuit <-chan struct{}) error {
command := &plexPacket{
pkt: packet,
err: errChan,
}
select {
case s.htlcPlex <- command:
return nil
case <-linkQuit:
return ErrLinkShuttingDown
case <-s.quit:
return errors.New("htlc switch was stopped")
}
}
// getLocalLink handles the addition of a htlc for a send that originates from
// our node. It returns the link that the htlc should be forwarded outwards on,
// and a link error if the htlc cannot be forwarded.
func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) (
ChannelLink, *LinkError) {
// Try to find links by node destination.
s.indexMtx.RLock()
link, err := s.getLinkByShortID(pkt.outgoingChanID)
defer s.indexMtx.RUnlock()
if err != nil {
// If the link was not found for the outgoingChanID, an outside
// subsystem may be using the confirmed SCID of a zero-conf
// channel. In this case, we'll consult the Switch maps to see
// if an alias exists and use the alias to lookup the link.
// This extra step is a consequence of not updating the Switch
// forwardingIndex when a zero-conf channel is confirmed. We
// don't need to change the outgoingChanID since the link will
// do that upon receiving the packet.
baseScid, ok := s.baseIndex[pkt.outgoingChanID]
if !ok {
log.Errorf("Link %v not found", pkt.outgoingChanID)
return nil, NewLinkError(&lnwire.FailUnknownNextPeer{})
}
// The base SCID was found, so we'll use that to fetch the
// link.
link, err = s.getLinkByShortID(baseScid)
if err != nil {
log.Errorf("Link %v not found", baseScid)
return nil, NewLinkError(&lnwire.FailUnknownNextPeer{})
}
}
if !link.EligibleToForward() {
log.Errorf("Link %v is not available to forward",
pkt.outgoingChanID)
// The update does not need to be populated as the error
// will be returned back to the router.
return nil, NewDetailedLinkError(
lnwire.NewTemporaryChannelFailure(nil),
OutgoingFailureLinkNotEligible,
)
}
// Ensure that the htlc satisfies the outgoing channel policy.
currentHeight := atomic.LoadUint32(&s.bestHeight)
htlcErr := link.CheckHtlcTransit(
htlc.PaymentHash, htlc.Amount, htlc.Expiry, currentHeight,
htlc.CustomRecords,
)
if htlcErr != nil {
log.Errorf("Link %v policy for local forward not "+
"satisfied", pkt.outgoingChanID)
return nil, htlcErr
}
return link, nil
}
// handleLocalResponse processes a Settle or Fail responding to a
// locally-initiated payment. This is handled asynchronously to avoid blocking
// the main event loop within the switch, as these operations can require
// multiple db transactions. The guarantees of the circuit map are stringent
// enough such that we are able to tolerate reordering of these operations
// without side effects. The primary operations handled are:
// 1. Save the payment result to the pending payment store.
// 2. Notify subscribers about the payment result.
// 3. Ack settle/fail references, to avoid resending this response internally
// 4. Teardown the closing circuit in the circuit map
//
// NOTE: This method MUST be spawned as a goroutine.
func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
defer s.wg.Done()
attemptID := pkt.incomingHTLCID
// The error reason will be unencrypted in case this is a local
// failure or a converted error.
unencrypted := pkt.localFailure || pkt.convertedError
n := &networkResult{
msg: pkt.htlc,
unencrypted: unencrypted,
isResolution: pkt.isResolution,
}
// Store the result to the db. This will also notify subscribers about
// the result.
if err := s.networkResults.storeResult(attemptID, n); err != nil {
log.Errorf("Unable to store attempt result for pid=%v: %v",
attemptID, err)
return
}
// First, we'll clean up any fwdpkg references, circuit entries, and
// mark in our db that the payment for this payment hash has either
// succeeded or failed.
//
// If this response is contained in a forwarding package, we'll start by
// acking the settle/fail so that we don't continue to retransmit the
// HTLC internally.
if pkt.destRef != nil {
if err := s.ackSettleFail(*pkt.destRef); err != nil {
log.Warnf("Unable to ack settle/fail reference: %s: %v",
*pkt.destRef, err)
return
}
}
// Next, we'll remove the circuit since we are about to complete a
// fulfill/fail of this HTLC. Since we've already removed the
// settle/fail fwdpkg reference, the response from the peer cannot be
// replayed internally if this step fails. If this happens, this logic
// will be executed when a provided resolution message comes through.
// This can only happen if the circuit is still open, which is why this
// ordering is chosen.
if err := s.teardownCircuit(pkt); err != nil {
log.Errorf("Unable to teardown circuit %s: %v",
pkt.inKey(), err)
return
}
// Finally, notify on the htlc failure or success that has been handled.
key := newHtlcKey(pkt)
eventType := getEventType(pkt)
switch htlc := pkt.htlc.(type) {
case *lnwire.UpdateFulfillHTLC:
s.cfg.HtlcNotifier.NotifySettleEvent(key, htlc.PaymentPreimage,
eventType)
case *lnwire.UpdateFailHTLC:
s.cfg.HtlcNotifier.NotifyForwardingFailEvent(key, eventType)
}
}
// extractResult uses the given deobfuscator to extract the payment result from
// the given network message.
func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult,
attemptID uint64, paymentHash lntypes.Hash) (*PaymentResult, error) {
switch htlc := n.msg.(type) {
// We've received a settle update which means we can finalize the user
// payment and return successful response.
case *lnwire.UpdateFulfillHTLC:
return &PaymentResult{
Preimage: htlc.PaymentPreimage,
}, nil
// We've received a fail update which means we can finalize the
// user payment and return fail response.
case *lnwire.UpdateFailHTLC:
// TODO(yy): construct deobfuscator here to avoid creating it
// in paymentLifecycle even for settled HTLCs.
paymentErr := s.parseFailedPayment(
deobfuscator, attemptID, paymentHash, n.unencrypted,
n.isResolution, htlc,
)
return &PaymentResult{
Error: paymentErr,
}, nil
default:
return nil, fmt.Errorf("received unknown response type: %T",
htlc)
}
}
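// A hedged sketch of how a caller might branch on the extracted
// PaymentResult:
//
//	switch {
//	case result.Error == nil:
//		// Settled: result.Preimage is the proof of payment.
//	case errors.Is(result.Error, ErrUnreadableFailureMessage):
//		// The failure could not be decrypted; the attempt failed
//		// but cannot be attributed to a specific hop.
//	default:
//		// A decoded LinkError or forwarding error from a hop.
//	}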
// parseFailedPayment determines the appropriate failure message to return to
// a user initiated payment. The three cases handled are:
// 1. An unencrypted failure, which should already be plaintext.
// 2. A resolution from the chain arbitrator, which possibly has no failure
// reason attached.
// 3. A failure from the remote party, which will need to be decrypted using
// the payment deobfuscator.
func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter,
attemptID uint64, paymentHash lntypes.Hash, unencrypted,
isResolution bool, htlc *lnwire.UpdateFailHTLC) error {
switch {
// The payment never cleared the link, so we don't need to
// decrypt the error, simply decode it and report back to the
// user.
case unencrypted:
r := bytes.NewReader(htlc.Reason)
failureMsg, err := lnwire.DecodeFailure(r, 0)
if err != nil {
// If we could not decode the failure reason, return a link
// error indicating that we failed to decode the onion.
linkError := NewDetailedLinkError(
// As this didn't even clear the link, we don't
// need to apply an update here since it goes
// directly to the router.
lnwire.NewTemporaryChannelFailure(nil),
OutgoingFailureDecodeError,
)
log.Errorf("%v: (hash=%v, pid=%d): %v",
linkError.FailureDetail.FailureString(),
paymentHash, attemptID, err)
return linkError
}
// If we successfully decoded the failure reason, return it.
return NewLinkError(failureMsg)
// A payment had to be timed out on chain before it got past
// the first hop. In this case, we'll report a permanent
// channel failure as this means that we, or the remote party, had to
// go on chain.
case isResolution && htlc.Reason == nil:
linkError := NewDetailedLinkError(
&lnwire.FailPermanentChannelFailure{},
OutgoingFailureOnChainTimeout,
)
log.Infof("%v: hash=%v, pid=%d",
linkError.FailureDetail.FailureString(),
paymentHash, attemptID)
return linkError
// A regular multi-hop payment error that we'll need to
// decrypt.
default:
// We'll attempt to fully decrypt the onion encrypted
// error. If we're unable to then we'll bail early.
failure, err := deobfuscator.DecryptError(htlc.Reason)
if err != nil {
log.Errorf("unable to de-obfuscate onion failure "+
"(hash=%v, pid=%d): %v",
paymentHash, attemptID, err)
return ErrUnreadableFailureMessage
}
return failure
}
}
// handlePacketForward is used in cases when we need to forward the htlc update
// from one channel link to another and be able to propagate the settle/fail
// updates back. This behaviour is achieved by creation of payment circuits.
func (s *Switch) handlePacketForward(packet *htlcPacket) error {
switch htlc := packet.htlc.(type) {
// Channel link forwarded us a new htlc, therefore we initiate the
// payment circuit within our internal state so we can properly forward
// the ultimate settle message back later.
case *lnwire.UpdateAddHTLC:
return s.handlePacketAdd(packet, htlc)
case *lnwire.UpdateFulfillHTLC:
return s.handlePacketSettle(packet)
// Channel link forwarded us an update_fail_htlc message.
//
// NOTE: when the channel link receives an update_fail_malformed_htlc
// from upstream, it will convert the message into update_fail_htlc and
// forward it. Thus there's no need to catch `UpdateFailMalformedHTLC`
// here.
case *lnwire.UpdateFailHTLC:
return s.handlePacketFail(packet, htlc)
default:
return fmt.Errorf("wrong update type: %T", htlc)
}
}
// checkCircularForward checks whether a forward is circular (arrives and
// departs on the same link) and returns a link error if the switch is
// configured to disallow this behaviour.
func (s *Switch) checkCircularForward(incoming, outgoing lnwire.ShortChannelID,
allowCircular bool, paymentHash lntypes.Hash) *LinkError {
log.Tracef("Checking for circular route: incoming=%v, outgoing=%v "+
"(payment hash: %x)", incoming, outgoing, paymentHash[:])
// If they are equal, we can skip the alias mapping checks.
if incoming == outgoing {
// The switch may be configured to allow circular routes, so
// just log and return nil.
if allowCircular {
log.Debugf("allowing circular route over link: %v "+
"(payment hash: %x)", incoming, paymentHash)
return nil
}
// Otherwise, we'll return a temporary channel failure.
return NewDetailedLinkError(
lnwire.NewTemporaryChannelFailure(nil),
OutgoingFailureCircularRoute,
)
}
// We'll fetch the "base" SCID from the baseIndex for the incoming and
// outgoing SCIDs. If either one does not have a base SCID, then the
// two channels are not equal since one will be a channel that does not
// need a mapping and SCID equality was checked above. If the "base"
// SCIDs are equal, then this is a circular route. Otherwise, it isn't.
s.indexMtx.RLock()
incomingBaseScid, ok := s.baseIndex[incoming]
if !ok {
// This channel does not use baseIndex, bail out.
s.indexMtx.RUnlock()
return nil
}
outgoingBaseScid, ok := s.baseIndex[outgoing]
if !ok {
// This channel does not use baseIndex, bail out.
s.indexMtx.RUnlock()
return nil
}
s.indexMtx.RUnlock()
// Check base SCID equality.
if incomingBaseScid != outgoingBaseScid {
log.Tracef("Incoming base SCID %v does not match outgoing "+
"base SCID %v (payment hash: %x)", incomingBaseScid,
outgoingBaseScid, paymentHash[:])
// The base SCIDs are not equal so these are not the same
// channel.
return nil
}
// If the incoming and outgoing link are equal, the htlc is part of a
// circular route which may be used to lock up our liquidity. If the
// switch is configured to allow circular routes, log that we are
// allowing the route then return nil.
if allowCircular {
log.Debugf("allowing circular route over link: %v "+
"(payment hash: %x)", incoming, paymentHash)
return nil
}
// If our node disallows circular routes, return a temporary channel
// failure. There is nothing wrong with the policy used by the remote
// node, so we do not include a channel update.
return NewDetailedLinkError(
lnwire.NewTemporaryChannelFailure(nil),
OutgoingFailureCircularRoute,
)
}
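// To make the alias mapping concrete: suppose a channel has base SCID baseB
// and aliases aliasA1 and aliasA2. baseIndex then maps each of them to
// baseB, so a forward arriving over one alias and departing over another
// still resolves to the same channel. A hedged illustration of the map
// shape, with placeholder identifiers:
//
//	baseIndex := map[lnwire.ShortChannelID]lnwire.ShortChannelID{
//		aliasA1: baseB,
//		aliasA2: baseB,
//		baseB:   baseB,
//	}
//	// incoming=aliasA1 and outgoing=aliasA2 both map to baseB, so the
//	// forward is treated as circular.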
// failAddPacket encrypts a fail packet back to an add packet's source.
// The ciphertext will be derived from the failure message provided by
// context. This method returns the failure if all other steps complete
// successfully.
func (s *Switch) failAddPacket(packet *htlcPacket, failure *LinkError) error {
// Encrypt the failure so that the sender will be able to read the error
// message. Since we failed this packet, we use EncryptFirstHop to
// obfuscate the failure for their eyes only.
reason, err := packet.obfuscator.EncryptFirstHop(failure.WireMessage())
if err != nil {
err := fmt.Errorf("unable to obfuscate "+
"error: %v", err)
log.Error(err)
return err
}
log.Error(failure.Error())
// Create a failure packet for this htlc. The full set of
// information about the htlc failure is included so that they can
// be included in link failure notifications.
failPkt := &htlcPacket{
sourceRef: packet.sourceRef,
incomingChanID: packet.incomingChanID,
incomingHTLCID: packet.incomingHTLCID,
outgoingChanID: packet.outgoingChanID,
outgoingHTLCID: packet.outgoingHTLCID,
incomingAmount: packet.incomingAmount,
amount: packet.amount,
incomingTimeout: packet.incomingTimeout,
outgoingTimeout: packet.outgoingTimeout,
circuit: packet.circuit,
obfuscator: packet.obfuscator,
linkFailure: failure,
htlc: &lnwire.UpdateFailHTLC{
Reason: reason,
},
}
// Route a fail packet back to the source link.
err = s.mailOrchestrator.Deliver(failPkt.incomingChanID, failPkt)
if err != nil {
err = fmt.Errorf("source chanid=%v unable to "+
"handle switch packet: %v",
packet.incomingChanID, err)
log.Error(err)
return err
}
return failure
}
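// A hedged usage sketch (mirroring call sites elsewhere in this file, with a
// failure code chosen purely for illustration): a caller that cannot forward
// an add typically wraps a wire failure in a LinkError and hands it to
// failAddPacket, returning the result so only one response traverses the
// switch:
//
//	linkErr := NewDetailedLinkError(
//		&lnwire.FailUnknownNextPeer{},
//		OutgoingFailureLinkNotEligible,
//	)
//	return s.failAddPacket(packet, linkErr)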
// closeCircuit accepts a settle or fail htlc and the associated htlc packet and
// attempts to determine the source that forwarded this htlc. This method will
// set the incoming chan and htlc ID of the given packet if the source was
// found, and will properly [re]encrypt any failure messages.
func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) {
// If the packet has its source, that means it was failed locally by
// the outgoing link. We fail it here to make sure only one response
// makes it through the switch.
if pkt.hasSource {
circuit, err := s.circuits.FailCircuit(pkt.inKey())
switch err {
// Circuit successfully closed.
case nil:
return circuit, nil
// Circuit was previously closed, but has not been deleted.
// We'll just drop this response until the circuit has been
// fully removed.
case ErrCircuitClosing:
return nil, err
// Failed to close circuit because it does not exist. This is
// likely because the circuit was already successfully closed.
// Since this packet failed locally, there is no forwarding
// package entry to acknowledge.
case ErrUnknownCircuit:
return nil, err
// Unexpected error.
default:
return nil, err
}
}
// Otherwise, this packet was received from the remote party. Use the
// circuit map to find the incoming link to receive the settle/fail.
circuit, err := s.circuits.CloseCircuit(pkt.outKey())
switch err {
// Open circuit successfully closed.
case nil:
pkt.incomingChanID = circuit.Incoming.ChanID
pkt.incomingHTLCID = circuit.Incoming.HtlcID
pkt.circuit = circuit
pkt.sourceRef = &circuit.AddRef
pktType := "SETTLE"
if _, ok := pkt.htlc.(*lnwire.UpdateFailHTLC); ok {
pktType = "FAIL"
}
log.Debugf("Closed completed %s circuit for %x: "+
"(%s, %d) <-> (%s, %d)", pktType, pkt.circuit.PaymentHash,
pkt.incomingChanID, pkt.incomingHTLCID,
pkt.outgoingChanID, pkt.outgoingHTLCID)
return circuit, nil
// Circuit was previously closed, but has not been deleted. We'll just
// drop this response until the circuit has been removed.
case ErrCircuitClosing:
return nil, err
// Failed to close circuit because it does not exist. This is likely
// because the circuit was already successfully closed.
case ErrUnknownCircuit:
if pkt.destRef != nil {
// Add this SettleFailRef to the set of pending settle/fail entries
// awaiting acknowledgement.
s.pendingSettleFails = append(s.pendingSettleFails, *pkt.destRef)
}
// If this is a settle, we will not log an error message as settles
// are expected to hit the ErrUnknownCircuit case. The only way fails
// can hit this case is if the link restarts after having just sent a
// fail to the switch.
_, isSettle := pkt.htlc.(*lnwire.UpdateFulfillHTLC)
if !isSettle {
err := fmt.Errorf("unable to find target channel "+
"for HTLC fail: channel ID = %s, "+
"HTLC ID = %d", pkt.outgoingChanID,
pkt.outgoingHTLCID)
log.Error(err)
return nil, err
}
return nil, nil
// Unexpected error.
default:
return nil, err
}
}
// ackSettleFail is used by the switch to ACK any settle/fail entries in the
// forwarding package of the outgoing link for a payment circuit. We do this if
// we're the originator of the payment, so the link stops attempting to
// re-broadcast.
func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error {
return kvdb.Batch(s.cfg.DB, func(tx kvdb.RwTx) error {
return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...)
})
}
// teardownCircuit removes a pending or open circuit from the switch's circuit
// map and prints useful logging statements regarding the outcome.
func (s *Switch) teardownCircuit(pkt *htlcPacket) error {
var pktType string
switch htlc := pkt.htlc.(type) {
case *lnwire.UpdateFulfillHTLC:
pktType = "SETTLE"
case *lnwire.UpdateFailHTLC:
pktType = "FAIL"
default:
return fmt.Errorf("cannot tear down packet of type: %T", htlc)
}
var paymentHash lntypes.Hash
// Perform a defensive check to make sure we don't try to access a nil
// circuit.
circuit := pkt.circuit
if circuit != nil {
copy(paymentHash[:], circuit.PaymentHash[:])
}
log.Debugf("Tearing down circuit with %s pkt, removing circuit=%v "+
"with keystone=%v", pktType, pkt.inKey(), pkt.outKey())
err := s.circuits.DeleteCircuits(pkt.inKey())
if err != nil {
log.Warnf("Failed to tear down circuit (%s, %d) <-> (%s, %d) "+
"with payment_hash=%v using %s pkt", pkt.incomingChanID,
pkt.incomingHTLCID, pkt.outgoingChanID,
pkt.outgoingHTLCID, pkt.circuit.PaymentHash, pktType)
return err
}
log.Debugf("Closed %s circuit for %v: (%s, %d) <-> (%s, %d)", pktType,
paymentHash, pkt.incomingChanID, pkt.incomingHTLCID,
pkt.outgoingChanID, pkt.outgoingHTLCID)
return nil
}
// CloseLink creates and sends the close channel command to the target link
// directing the specified closure type. If the closure type is CloseRegular,
// targetFeePerKw parameter should be the ideal fee-per-kw that will be used as
// a starting point for close negotiation. The deliveryScript parameter is an
// optional parameter which sets a user specified script to close out to.
func (s *Switch) CloseLink(ctx context.Context, chanPoint *wire.OutPoint,
closeType contractcourt.ChannelCloseType,
targetFeePerKw, maxFee chainfee.SatPerKWeight,
deliveryScript lnwire.DeliveryAddress) (chan interface{}, chan error) {
// TODO(roasbeef) abstract out the close updates.
updateChan := make(chan interface{}, 2)
errChan := make(chan error, 1)
command := &ChanClose{
CloseType: closeType,
ChanPoint: chanPoint,
Updates: updateChan,
TargetFeePerKw: targetFeePerKw,
DeliveryScript: deliveryScript,
Err: errChan,
MaxFee: maxFee,
Ctx: ctx,
}
select {
case s.chanCloseRequests <- command:
return updateChan, errChan
case <-s.quit:
errChan <- ErrSwitchExiting
close(updateChan)
return updateChan, errChan
}
}
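// An illustrative caller sketch (the variables here are assumptions, not
// part of the original source): initiating a cooperative close and draining
// the two channels returned by CloseLink:
//
//	updates, errChan := s.CloseLink(
//		ctx, chanPoint, contractcourt.CloseRegular,
//		targetFeePerKw, maxFee, deliveryScript,
//	)
//	select {
//	case update := <-updates:
//		// Process pending/confirmed close updates.
//		_ = update
//	case err := <-errChan:
//		// The close request failed or the switch is exiting.
//		_ = err
//	}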
// htlcForwarder is responsible for optimally forwarding (and possibly
// fragmenting) incoming/outgoing HTLCs amongst all active interfaces and their
// links. The duties of the forwarder are similar to those of a network
// switch, in that it facilitates multi-hop payments by acting as a central
// messaging bus. The switch communicates with active links to create,
// manage, and tear down active onion routed payments. Each active channel
// is modeled as a networked device with metadata such as the available
// payment bandwidth, and total link capacity.
//
// NOTE: This MUST be run as a goroutine.
func (s *Switch) htlcForwarder() {
defer s.wg.Done()
defer func() {
s.blockEpochStream.Cancel()
// Remove all links once we've been signalled for shutdown.
var linksToStop []ChannelLink
s.indexMtx.Lock()
for _, link := range s.linkIndex {
activeLink := s.removeLink(link.ChanID())
if activeLink == nil {
log.Errorf("unable to remove ChannelLink(%v) "+
"on stop", link.ChanID())
continue
}
linksToStop = append(linksToStop, activeLink)
}
for _, link := range s.pendingLinkIndex {
pendingLink := s.removeLink(link.ChanID())
if pendingLink == nil {
log.Errorf("unable to remove ChannelLink(%v) "+
"on stop", link.ChanID())
continue
}
linksToStop = append(linksToStop, pendingLink)
}
s.indexMtx.Unlock()
// Now that all pending and live links have been removed from
// the forwarding indexes, stop each one before shutting down.
// We'll shut them down in parallel to make exiting as fast as
// possible.
var wg sync.WaitGroup
for _, link := range linksToStop {
wg.Add(1)
go func(l ChannelLink) {
defer wg.Done()
l.Stop()
}(link)
}
wg.Wait()
// Before we exit fully, we'll attempt to flush out any
// forwarding events that may still be lingering since the last
// batch flush.
if err := s.FlushForwardingEvents(); err != nil {
log.Errorf("unable to flush forwarding events: %v", err)
}
}()
// TODO(roasbeef): cleared vs settled distinction
var (
totalNumUpdates uint64
totalSatSent btcutil.Amount
totalSatRecv btcutil.Amount
)
s.cfg.LogEventTicker.Resume()
defer s.cfg.LogEventTicker.Stop()
// Every 15 seconds, we'll flush out the forwarding events that
// occurred during that period.
s.cfg.FwdEventTicker.Resume()
defer s.cfg.FwdEventTicker.Stop()
defer s.cfg.AckEventTicker.Stop()
out:
for {
// If the set of pending settle/fail entries is non-zero,
// reinstate the ack ticker so we can batch ack them.
if len(s.pendingSettleFails) > 0 {
s.cfg.AckEventTicker.Resume()
}
select {
case blockEpoch, ok := <-s.blockEpochStream.Epochs:
if !ok {
break out
}
atomic.StoreUint32(&s.bestHeight, uint32(blockEpoch.Height))
// A local close request has arrived, we'll forward this to the
// relevant link (if it exists) so the channel can be
// cooperatively closed (if possible).
case req := <-s.chanCloseRequests:
chanID := lnwire.NewChanIDFromOutPoint(*req.ChanPoint)
s.indexMtx.RLock()
link, ok := s.linkIndex[chanID]
if !ok {
s.indexMtx.RUnlock()
req.Err <- fmt.Errorf("no peer for channel with "+
"chan_id=%x", chanID[:])
continue
}
s.indexMtx.RUnlock()
peerPub := link.PeerPubKey()
log.Debugf("Requesting local channel close: peer=%x, "+
"chan_id=%x", link.PeerPubKey(), chanID[:])
go s.cfg.LocalChannelClose(peerPub[:], req)
case resolutionMsg := <-s.resolutionMsgs:
// We'll persist the resolution message to the Switch's
// resolution store.
resMsg := resolutionMsg.ResolutionMsg
err := s.resMsgStore.addResolutionMsg(&resMsg)
if err != nil {
// This will only fail if there is a database
// error or a serialization error. Sending the
// error prevents the contractcourt from being
// in a state where it believes the send was
// successful, when it wasn't.
log.Errorf("unable to add resolution msg: %v",
err)
resolutionMsg.errChan <- err
continue
}
// At this point, the resolution message has been
// persisted. It is safe to signal success by sending
// a nil error since the Switch will re-deliver the
// resolution message on restart.
resolutionMsg.errChan <- nil
// Create a htlc packet for this resolution. We do
// not have some of the information that we'll need
// for blinded error handling here, so we'll rely on
// our forwarding logic to fill it in later.
pkt := &htlcPacket{
outgoingChanID: resolutionMsg.SourceChan,
outgoingHTLCID: resolutionMsg.HtlcIndex,
isResolution: true,
}
// Resolution messages will either be cancelling
// backwards an existing HTLC, or settling a previously
// outgoing HTLC. Based on this, we'll map the message
// to the proper htlcPacket.
if resolutionMsg.Failure != nil {
pkt.htlc = &lnwire.UpdateFailHTLC{}
} else {
pkt.htlc = &lnwire.UpdateFulfillHTLC{
PaymentPreimage: *resolutionMsg.PreImage,
}
}
log.Debugf("Received outside contract resolution, "+
"mapping to: %v", spew.Sdump(pkt))
// We don't check the error, as the only failure we can
// encounter is due to the circuit already being
// closed. This is fine, as processing this message is
// meant to be idempotent.
err = s.handlePacketForward(pkt)
if err != nil {
log.Errorf("Unable to forward resolution msg: %v", err)
}
// A new packet has arrived for forwarding, we'll interpret the
// packet concretely, then either forward it along, or
// interpret a return packet to a locally initialized one.
case cmd := <-s.htlcPlex:
cmd.err <- s.handlePacketForward(cmd.pkt)
// When this timer ticks, it indicates that we should collect all
// the forwarding events since the last interval, and write them
// out to our log.
case <-s.cfg.FwdEventTicker.Ticks():
s.wg.Add(1)
go func() {
defer s.wg.Done()
if err := s.FlushForwardingEvents(); err != nil {
log.Errorf("Unable to flush "+
"forwarding events: %v", err)
}
}()
// The log ticker has fired, so we'll calculate some forwarding
// stats for the last 10 seconds to display within the logs to
// users.
case <-s.cfg.LogEventTicker.Ticks():
// First, we'll collate the current running tally of
// our forwarding stats.
prevSatSent := totalSatSent
prevSatRecv := totalSatRecv
prevNumUpdates := totalNumUpdates
var (
newNumUpdates uint64
newSatSent btcutil.Amount
newSatRecv btcutil.Amount
)
// Next, we'll run through all the registered links and
// compute their up-to-date forwarding stats.
s.indexMtx.RLock()
for _, link := range s.linkIndex {
// TODO(roasbeef): when links first registered
// stats printed.
updates, sent, recv := link.Stats()
newNumUpdates += updates
newSatSent += sent.ToSatoshis()
newSatRecv += recv.ToSatoshis()
}
s.indexMtx.RUnlock()
var (
diffNumUpdates uint64
diffSatSent btcutil.Amount
diffSatRecv btcutil.Amount
)
// If this is the first time we're computing these
// stats, then the diff is just the new value. We do
// this in order to avoid integer underflow issues.
if prevNumUpdates == 0 {
diffNumUpdates = newNumUpdates
diffSatSent = newSatSent
diffSatRecv = newSatRecv
} else {
diffNumUpdates = newNumUpdates - prevNumUpdates
diffSatSent = newSatSent - prevSatSent
diffSatRecv = newSatRecv - prevSatRecv
}
// If the diff of num updates is zero, then we haven't
// forwarded anything in the last 10 seconds, so we can
// skip this update.
if diffNumUpdates == 0 {
continue
}
// If the diff of num updates is negative, then some
// links may have been unregistered from the switch, so
// we'll update our stats to only include our registered
// links.
if int64(diffNumUpdates) < 0 {
totalNumUpdates = newNumUpdates
totalSatSent = newSatSent
totalSatRecv = newSatRecv
continue
}
// Otherwise, we'll log this diff, then accumulate the
// new stats into the running total.
log.Debugf("Sent %d satoshis and received %d satoshis "+
"in the last 10 seconds (%f tx/sec)",
diffSatSent, diffSatRecv,
float64(diffNumUpdates)/10)
totalNumUpdates += diffNumUpdates
totalSatSent += diffSatSent
totalSatRecv += diffSatRecv
// The ack ticker has fired so if we have any settle/fail entries
// for a forwarding package to ack, we will do so here in a batch
// db call.
case <-s.cfg.AckEventTicker.Ticks():
// If the current set is empty, pause the ticker.
if len(s.pendingSettleFails) == 0 {
s.cfg.AckEventTicker.Pause()
continue
}
// Batch ack the settle/fail entries.
if err := s.ackSettleFail(s.pendingSettleFails...); err != nil {
log.Errorf("Unable to ack batch of settle/fails: %v", err)
continue
}
log.Tracef("Acked %d settle fails: %v",
len(s.pendingSettleFails),
lnutils.SpewLogClosure(s.pendingSettleFails))
// Reset the pendingSettleFails buffer while keeping acquired
// memory.
s.pendingSettleFails = s.pendingSettleFails[:0]
case <-s.quit:
return
}
}
}
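// A brief worked example of the stats logic above (numbers hypothetical): if
// the previous tally was 100 updates and the links now report 140, the diff
// is 40 updates over the 10 second window, logged as 4.0 tx/sec. If links
// were unregistered and the new tally is only 90, int64(diff) is negative,
// so the running totals are reset to the current link values rather than
// logging an underflowed diff.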
// Start starts all helper goroutines required for the operation of the switch.
func (s *Switch) Start() error {
if !atomic.CompareAndSwapInt32(&s.started, 0, 1) {
log.Warn("Htlc Switch already started")
return errors.New("htlc switch already started")
}
log.Infof("HTLC Switch starting")
blockEpochStream, err := s.cfg.Notifier.RegisterBlockEpochNtfn(nil)
if err != nil {
return err
}
s.blockEpochStream = blockEpochStream
s.wg.Add(1)
go s.htlcForwarder()
if err := s.reforwardResponses(); err != nil {
s.Stop()
log.Errorf("unable to reforward responses: %v", err)
return err
}
if err := s.reforwardResolutions(); err != nil {
// We are already stopping so we can ignore the error.
_ = s.Stop()
log.Errorf("unable to reforward resolutions: %v", err)
return err
}
return nil
}
// reforwardResolutions fetches the set of resolution messages stored on-disk
// and reforwards them if their circuits are still open. If the circuits have
// been deleted, then we will delete the resolution message from the database.
func (s *Switch) reforwardResolutions() error {
// Fetch all stored resolution messages, deleting the ones that are
// resolved.
resMsgs, err := s.resMsgStore.fetchAllResolutionMsg()
if err != nil {
return err
}
switchPackets := make([]*htlcPacket, 0, len(resMsgs))
for _, resMsg := range resMsgs {
// If the open circuit no longer exists, then we can remove the
// message from the store.
outKey := CircuitKey{
ChanID: resMsg.SourceChan,
HtlcID: resMsg.HtlcIndex,
}
if s.circuits.LookupOpenCircuit(outKey) == nil {
// The open circuit doesn't exist.
err := s.resMsgStore.deleteResolutionMsg(&outKey)
if err != nil {
return err
}
continue
}
// The circuit is still open, so we can assume that the link or
// switch (if we are the source) hasn't cleaned it up yet.
// We rely on our forwarding logic to fill in details that
// are not currently available to us.
resPkt := &htlcPacket{
outgoingChanID: resMsg.SourceChan,
outgoingHTLCID: resMsg.HtlcIndex,
isResolution: true,
}
if resMsg.Failure != nil {
resPkt.htlc = &lnwire.UpdateFailHTLC{}
} else {
resPkt.htlc = &lnwire.UpdateFulfillHTLC{
PaymentPreimage: *resMsg.PreImage,
}
}
switchPackets = append(switchPackets, resPkt)
}
// We'll now dispatch the set of resolution messages to the proper
// destination. An error is only encountered here if the switch is
// shutting down.
if err := s.ForwardPackets(nil, switchPackets...); err != nil {
return err
}
return nil
}
// reforwardResponses, for every known, non-pending channel, loads all associated
// forwarding packages and reforwards any Settle or Fail HTLCs found. This is
// used to resurrect the switch's mailboxes after a restart. This also runs for
// waiting close channels since there may be settles or fails that need to be
// reforwarded before they completely close.
func (s *Switch) reforwardResponses() error {
openChannels, err := s.cfg.FetchAllChannels()
if err != nil {
return err
}
for _, openChannel := range openChannels {
shortChanID := openChannel.ShortChanID()
// Locally-initiated payments never need reforwarding.
if shortChanID == hop.Source {
continue
}
// If the channel is pending, it should have no forwarding
// packages, and nothing to reforward.
if openChannel.IsPending {
continue
}
// Channels in open or waiting-close may still have responses in
// their forwarding packages. We will continue to reattempt
// forwarding on startup until the channel is fully-closed.
//
// Load this channel's forwarding packages, and deliver them to
// the switch.
fwdPkgs, err := s.loadChannelFwdPkgs(shortChanID)
if err != nil {
log.Errorf("unable to load forwarding "+
"packages for %v: %v", shortChanID, err)
return err
}
s.reforwardSettleFails(fwdPkgs)
}
return nil
}
// loadChannelFwdPkgs loads all forwarding packages owned by the `source` short
// channel identifier.
func (s *Switch) loadChannelFwdPkgs(source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
var fwdPkgs []*channeldb.FwdPkg
if err := kvdb.View(s.cfg.DB, func(tx kvdb.RTx) error {
var err error
fwdPkgs, err = s.cfg.SwitchPackager.LoadChannelFwdPkgs(
tx, source,
)
return err
}, func() {
fwdPkgs = nil
}); err != nil {
return nil, err
}
return fwdPkgs, nil
}
// reforwardSettleFails parses the Settle and Fail HTLCs from the list of
// forwarding packages, and reforwards those that have not been acknowledged.
// This is intended to occur on startup, in order to recover the switch's
// mailboxes, and to ensure that responses can be propagated in case the
// outgoing link never comes back online.
//
// NOTE: This should mimic the behavior of processRemoteSettleFails.
func (s *Switch) reforwardSettleFails(fwdPkgs []*channeldb.FwdPkg) {
for _, fwdPkg := range fwdPkgs {
switchPackets := make([]*htlcPacket, 0, len(fwdPkg.SettleFails))
for i, update := range fwdPkg.SettleFails {
// Skip any settles or fails that have already been
// acknowledged by the incoming link that originated the
// forwarded Add.
if fwdPkg.SettleFailFilter.Contains(uint16(i)) {
continue
}
switch msg := update.UpdateMsg.(type) {
// A settle for an HTLC we previously forwarded has
// been received. So we'll forward the HTLC to the
// switch which will handle propagating the settle to
// the prior hop.
case *lnwire.UpdateFulfillHTLC:
destRef := fwdPkg.DestRef(uint16(i))
settlePacket := &htlcPacket{
outgoingChanID: fwdPkg.Source,
outgoingHTLCID: msg.ID,
destRef: &destRef,
htlc: msg,
}
// Add the packet to the batch to be forwarded, and
// notify the overflow queue that a spare spot has been
// freed up within the commitment state.
switchPackets = append(switchPackets, settlePacket)
// A failureCode message for a previously forwarded HTLC has been
// received. As a result a new slot will be freed up in our
// commitment state, so we'll forward this to the switch so the
// backwards undo can continue.
case *lnwire.UpdateFailHTLC:
// Fetch the reason the HTLC was canceled so
// we can continue to propagate it. This
// failure originated from another node, so
// the linkFailure field is not set on this
// packet. We rely on the link to fill in
// additional circuit information for us.
failPacket := &htlcPacket{
outgoingChanID: fwdPkg.Source,
outgoingHTLCID: msg.ID,
destRef: &channeldb.SettleFailRef{
Source: fwdPkg.Source,
Height: fwdPkg.Height,
Index: uint16(i),
},
htlc: msg,
}
// Add the packet to the batch to be forwarded, and
// notify the overflow queue that a spare spot has been
// freed up within the commitment state.
switchPackets = append(switchPackets, failPacket)
}
}
// Since this send isn't tied to a specific link, we pass a nil
// link quit channel, meaning the send will fail only if the
// switch receives a shutdown request.
if err := s.ForwardPackets(nil, switchPackets...); err != nil {
log.Errorf("Unhandled error while reforwarding packets "+
"settle/fail over htlcswitch: %v", err)
}
}
}
// Stop gracefully stops all active helper goroutines, then waits until they've
// exited.
func (s *Switch) Stop() error {
if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) {
log.Warn("Htlc Switch already stopped")
return errors.New("htlc switch already shutdown")
}
log.Info("HTLC Switch shutting down...")
defer log.Debug("HTLC Switch shutdown complete")
close(s.quit)
s.wg.Wait()
// Wait until all active goroutines have finished exiting before
// stopping the mailboxes, otherwise the mailbox map could still be
// accessed and modified.
s.mailOrchestrator.Stop()
return nil
}
// CreateAndAddLink will create a link and then add it to the internal maps
// when given a ChannelLinkConfig and LightningChannel.
func (s *Switch) CreateAndAddLink(linkCfg ChannelLinkConfig,
lnChan *lnwallet.LightningChannel) error {
link := NewChannelLink(linkCfg, lnChan)
return s.AddLink(link)
}
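// A hedged usage sketch (linkCfg and lnChannel are placeholders, not from
// the original source): once funding completes, the peer constructs the link
// config and the funded channel, then registers both with the switch in one
// call:
//
//	if err := s.CreateAndAddLink(linkCfg, lnChannel); err != nil {
//		log.Errorf("unable to create and add link: %v", err)
//	}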
// AddLink is used to initiate the handling of the add link command. The
// request will be propagated and handled in the main goroutine.
func (s *Switch) AddLink(link ChannelLink) error {
s.indexMtx.Lock()
defer s.indexMtx.Unlock()
chanID := link.ChanID()
// First, ensure that this link is not already active in the switch.
_, err := s.getLink(chanID)
if err == nil {
return fmt.Errorf("unable to add ChannelLink(%v), already "+
"active", chanID)
}
// Get and attach the mailbox for this link, which buffers packets in
// case there are packets that we tried to deliver while this link was
// offline.
shortChanID := link.ShortChanID()
mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID, shortChanID)
link.AttachMailBox(mailbox)
// Attach the Switch's failAliasUpdate function to the link.
link.attachFailAliasUpdate(s.failAliasUpdate)
if err := link.Start(); err != nil {
log.Errorf("AddLink failed to start link with chanID=%v: %v",
chanID, err)
s.removeLink(chanID)
return err
}
if shortChanID == hop.Source {
log.Infof("Adding pending link chan_id=%v, short_chan_id=%v",
chanID, shortChanID)
s.pendingLinkIndex[chanID] = link
} else {
log.Infof("Adding live link chan_id=%v, short_chan_id=%v",
chanID, shortChanID)
s.addLiveLink(link)
s.mailOrchestrator.BindLiveShortChanID(
mailbox, chanID, shortChanID,
)
}
return nil
}
// addLiveLink adds a link to all associated forwarding indexes, making it a
// candidate for forwarding HTLCs.
func (s *Switch) addLiveLink(link ChannelLink) {
linkScid := link.ShortChanID()
// We'll add the link to the linkIndex which lets us quickly
// look up a channel when we need to close or register it, and
// the forwarding index which will be used when forwarding HTLCs
// in the multi-hop setting.
s.linkIndex[link.ChanID()] = link
s.forwardingIndex[linkScid] = link
// Next we'll add the link to the interface index so we can
// quickly look up all the channels for a particular node.
peerPub := link.PeerPubKey()
if _, ok := s.interfaceIndex[peerPub]; !ok {
s.interfaceIndex[peerPub] = make(map[lnwire.ChannelID]ChannelLink)
}
s.interfaceIndex[peerPub][link.ChanID()] = link
s.updateLinkAliases(link)
}
// UpdateLinkAliases is the externally exposed wrapper for updating link
// aliases. It acquires the indexMtx and calls the internal method.
func (s *Switch) UpdateLinkAliases(link ChannelLink) {
s.indexMtx.Lock()
defer s.indexMtx.Unlock()
s.updateLinkAliases(link)
}
// updateLinkAliases updates the aliases for a given link. This will cause the
// htlcswitch to consult the alias manager on the up-to-date values of its
// alias maps.
//
// NOTE: this MUST be called with the indexMtx held.
func (s *Switch) updateLinkAliases(link ChannelLink) {
linkScid := link.ShortChanID()
aliases := link.getAliases()
if link.isZeroConf() {
if link.zeroConfConfirmed() {
// Since the zero-conf channel has confirmed, we can
// populate the aliasToReal mapping.
confirmedScid := link.confirmedScid()
for _, alias := range aliases {
s.aliasToReal[alias] = confirmedScid
}
// Add the confirmed SCID as a key in the baseIndex.
s.baseIndex[confirmedScid] = linkScid
}
// Now we populate the baseIndex which will be used to fetch
// the link given any of the channel's alias SCIDs or the real
// SCID. The link's SCID is an alias, so we don't need to
// special-case it like the option-scid-alias feature-bit case
// further down.
for _, alias := range aliases {
s.baseIndex[alias] = linkScid
}
} else if link.negotiatedAliasFeature() {
// First, we flush any alias mappings for this link's scid
// before we populate the map again, in order to get rid of old
// values that no longer exist.
for alias, real := range s.aliasToReal {
if real == linkScid {
delete(s.aliasToReal, alias)
}
}
for alias, real := range s.baseIndex {
if real == linkScid {
delete(s.baseIndex, alias)
}
}
// The link's SCID is the confirmed SCID for non-zero-conf
// option-scid-alias feature bit channels.
for _, alias := range aliases {
s.aliasToReal[alias] = linkScid
s.baseIndex[alias] = linkScid
}
// Since the link's SCID is confirmed, it was not included in
// the baseIndex above as a key. Add it now.
s.baseIndex[linkScid] = linkScid
}
}
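// To make the two cases above concrete, a hedged sketch with hypothetical
// SCIDs (A1 and A2 are aliases, C is the confirmed SCID). For a confirmed
// zero-conf channel whose link SCID is the alias A1, the maps end up as:
//
//	aliasToReal: A1 -> C, A2 -> C
//	baseIndex:   A1 -> A1, A2 -> A1, C -> A1
//
// For a confirmed non-zero-conf channel with the option-scid-alias feature
// bit negotiated, the link SCID is C itself, so:
//
//	aliasToReal: A1 -> C, A2 -> C
//	baseIndex:   A1 -> C, A2 -> C, C -> C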
// GetLink is used to initiate the handling of the get link command. The
// request will be propagated to and handled in the main goroutine.
func (s *Switch) GetLink(chanID lnwire.ChannelID) (ChannelUpdateHandler,
error) {
s.indexMtx.RLock()
defer s.indexMtx.RUnlock()
return s.getLink(chanID)
}
// getLink returns the link stored in either the pending index or the live
// link index.
func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
link, ok := s.linkIndex[chanID]
if !ok {
link, ok = s.pendingLinkIndex[chanID]
if !ok {
return nil, ErrChannelLinkNotFound
}
}
return link, nil
}
// GetLinkByShortID attempts to return the link which possesses the target short
// channel ID.
func (s *Switch) GetLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink,
error) {
s.indexMtx.RLock()
defer s.indexMtx.RUnlock()
link, err := s.getLinkByShortID(chanID)
if err != nil {
// If we failed to find the link under the passed-in SCID, we
// consult the Switch's baseIndex map to see if the confirmed
// SCID was used for a zero-conf channel.
aliasID, ok := s.baseIndex[chanID]
if !ok {
return nil, err
}
// An alias was found, use it to lookup if a link exists.
return s.getLinkByShortID(aliasID)
}
return link, nil
}
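// A hedged lookup sketch (SCIDs hypothetical, not from the original source):
// for a zero-conf channel, forwardingIndex is keyed by the link's alias A1
// while baseIndex maps the confirmed SCID C to A1. Passing C therefore
// misses the first lookup but still resolves:
//
//	link, err := s.GetLinkByShortID(C)
//	// Internally: forwardingIndex[C] misses -> baseIndex[C] == A1 ->
//	// forwardingIndex[A1] returns the link.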
// getLinkByShortID attempts to return the link which possesses the target
// short channel ID.
//
// NOTE: This MUST be called with the indexMtx held.
func (s *Switch) getLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, error) {
link, ok := s.forwardingIndex[chanID]
if !ok {
return nil, ErrChannelLinkNotFound
}
return link, nil
}
// getLinkByMapping attempts to fetch the link via the htlcPacket's
// outgoingChanID, possibly using a mapping. If it finds the link via mapping,
// the outgoingChanID will be changed so that an error can be properly
// attributed when looping over linkErrs in handlePacketForward.
//
// * If the outgoingChanID is an alias, we'll fetch the link regardless if it's
// public or not.
//
// * If the outgoingChanID is a confirmed SCID, we'll need to do more checks.
// - If there is no entry found in baseIndex, fetch the link. This channel
// did not have the option-scid-alias feature negotiated (which includes
// zero-conf and option-scid-alias channel-types).
// - If there is an entry found, fetch the link from forwardingIndex and
// fail if this is a private link.
//
// NOTE: This MUST be called with the indexMtx read lock held.
func (s *Switch) getLinkByMapping(pkt *htlcPacket) (ChannelLink, error) {
// Determine if this ShortChannelID is an alias or a confirmed SCID.
chanID := pkt.outgoingChanID
aliasID := s.cfg.IsAlias(chanID)
log.Debugf("Querying outgoing link using chanID=%v, aliasID=%v", chanID,
aliasID)
// Set the originalOutgoingChanID so the proper channel_update can be
// sent back if the option-scid-alias feature bit was negotiated.
pkt.originalOutgoingChanID = chanID
if aliasID {
// Since outgoingChanID is an alias, we'll fetch the link via
// baseIndex.
baseScid, ok := s.baseIndex[chanID]
if !ok {
// No mapping exists, bail.
return nil, ErrChannelLinkNotFound
}
// A mapping exists, so use baseScid to find the link in the
// forwardingIndex.
link, ok := s.forwardingIndex[baseScid]
if !ok {
// Link not found, bail.
return nil, ErrChannelLinkNotFound
}
// Change the packet's outgoingChanID field so that errors are
// properly attributed.
pkt.outgoingChanID = baseScid
// Return the link without checking if it's private or not.
return link, nil
}
// The outgoingChanID is a confirmed SCID. Attempt to fetch the base
// SCID from baseIndex.
baseScid, ok := s.baseIndex[chanID]
if !ok {
// outgoingChanID is not a key in base index meaning this
// channel did not have the option-scid-alias feature bit
// negotiated. We'll fetch the link and return it.
link, ok := s.forwardingIndex[chanID]
if !ok {
// The link wasn't found, bail out.
return nil, ErrChannelLinkNotFound
}
return link, nil
}
// Fetch the link whose internal SCID is baseScid.
link, ok := s.forwardingIndex[baseScid]
if !ok {
// Link wasn't found, bail out.
return nil, ErrChannelLinkNotFound
}
// If the link is unadvertised, we fail since the real SCID was used to
// forward over it and this is a channel where the option-scid-alias
// feature bit was negotiated.
if link.IsUnadvertised() {
log.Debugf("Link is unadvertised, chanID=%v, baseScid=%v",
chanID, baseScid)
return nil, ErrChannelLinkNotFound
}
// The link is public so the confirmed SCID can be used to forward over
// it. We'll also replace pkt's outgoingChanID field so errors can
// properly be attributed in the calling function.
pkt.outgoingChanID = baseScid
return link, nil
}
// HasActiveLink returns true if the given channel ID has a link in the link
// index AND the link is eligible to forward.
func (s *Switch) HasActiveLink(chanID lnwire.ChannelID) bool {
s.indexMtx.RLock()
defer s.indexMtx.RUnlock()
if link, ok := s.linkIndex[chanID]; ok {
return link.EligibleToForward()
}
return false
}
// RemoveLink purges the switch of any link associated with chanID. If a pending
// or active link is not found, this method does nothing. Otherwise, the method
// returns after the link has been completely shut down.
func (s *Switch) RemoveLink(chanID lnwire.ChannelID) {
s.indexMtx.Lock()
link, err := s.getLink(chanID)
if err != nil {
// If err is non-nil, this means that link is also nil. The
// link variable cannot be nil without err being non-nil.
s.indexMtx.Unlock()
log.Tracef("Unable to remove link for ChannelID(%v): %v",
chanID, err)
return
}
// Check if the link is already stopping and grab the stop chan if it
// is.
stopChan, ok := s.linkStopIndex[chanID]
if !ok {
// If the link is non-nil, it is not currently stopping, so
// we'll add a stop chan to the linkStopIndex.
stopChan = make(chan struct{})
s.linkStopIndex[chanID] = stopChan
}
s.indexMtx.Unlock()
if ok {
// If the stop chan exists, we will wait for it to be closed.
// Once it is closed, we will exit.
select {
case <-stopChan:
return
case <-s.quit:
return
}
}
// Stop the link before removing it from the maps.
link.Stop()
s.indexMtx.Lock()
_ = s.removeLink(chanID)
// Close stopChan and remove this link from the linkStopIndex.
// Deleting from the index and removing from the link must be done
// in the same block while the mutex is held.
close(stopChan)
delete(s.linkStopIndex, chanID)
s.indexMtx.Unlock()
}
// removeLink is used to remove and stop the channel link.
//
// NOTE: This MUST be called with the indexMtx held.
func (s *Switch) removeLink(chanID lnwire.ChannelID) ChannelLink {
log.Infof("Removing channel link with ChannelID(%v)", chanID)
link, err := s.getLink(chanID)
if err != nil {
return nil
}
// Remove the channel from live link indexes.
delete(s.pendingLinkIndex, link.ChanID())
delete(s.linkIndex, link.ChanID())
delete(s.forwardingIndex, link.ShortChanID())
// If the link has been added to the peer index, then we'll move to
// delete the entry within the index.
peerPub := link.PeerPubKey()
if peerIndex, ok := s.interfaceIndex[peerPub]; ok {
delete(peerIndex, link.ChanID())
// If after deletion, there are no longer any links, then we'll
// remove the interface map all together.
if len(peerIndex) == 0 {
delete(s.interfaceIndex, peerPub)
}
}
return link
}
// UpdateShortChanID locates the link with the passed-in chanID and updates the
// underlying channel state. This is only used in zero-conf channels to allow
// the confirmed SCID to be updated.
func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID) error {
s.indexMtx.Lock()
defer s.indexMtx.Unlock()
// Locate the target link in the link index. If no such link exists,
// then we will ignore the request.
link, ok := s.linkIndex[chanID]
if !ok {
return fmt.Errorf("link %v not found", chanID)
}
// Try to update the link's underlying channel state, returning early
// if this update failed.
_, err := link.UpdateShortChanID()
if err != nil {
return err
}
// Since the zero-conf channel is confirmed, we should populate the
// aliasToReal map and update the baseIndex.
aliases := link.getAliases()
confirmedScid := link.confirmedScid()
for _, alias := range aliases {
s.aliasToReal[alias] = confirmedScid
}
s.baseIndex[confirmedScid] = link.ShortChanID()
return nil
}
// GetLinksByInterface fetches all the links connected to a particular node
// identified by the serialized compressed form of its public key.
func (s *Switch) GetLinksByInterface(hop [33]byte) ([]ChannelUpdateHandler,
error) {
s.indexMtx.RLock()
defer s.indexMtx.RUnlock()
var handlers []ChannelUpdateHandler
links, err := s.getLinks(hop)
if err != nil {
return nil, err
}
// Range over the returned []ChannelLink to convert them into
// []ChannelUpdateHandler.
for _, link := range links {
handlers = append(handlers, link)
}
return handlers, nil
}
// getLinks returns the channel links of the peer identified by the hop
// destination id.
//
// NOTE: This MUST be called with the indexMtx held.
func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
links, ok := s.interfaceIndex[destination]
if !ok {
return nil, ErrNoLinksFound
}
channelLinks := make([]ChannelLink, 0, len(links))
for _, link := range links {
channelLinks = append(channelLinks, link)
}
return channelLinks, nil
}
// CircuitModifier returns a reference to a subset of the interfaces provided
// by the circuit map, to allow links to open and close circuits.
func (s *Switch) CircuitModifier() CircuitModifier {
return s.circuits
}
// CircuitLookup returns a reference to a subset of the interfaces provided
// by the circuit map, to allow looking up circuits.
func (s *Switch) CircuitLookup() CircuitLookup {
return s.circuits
}
// commitCircuits persistently adds a circuit to the switch's circuit map.
func (s *Switch) commitCircuits(circuits ...*PaymentCircuit) (
*CircuitFwdActions, error) {
return s.circuits.CommitCircuits(circuits...)
}
// FlushForwardingEvents flushes out the set of pending forwarding events to
// the persistent log. This will be used by the switch to periodically flush
// out the set of forwarding events to disk. External callers can also use this
// method to ensure all data is flushed to disk before querying the log.
func (s *Switch) FlushForwardingEvents() error {
// First, we'll obtain a copy of the current set of pending forwarding
// events.
s.fwdEventMtx.Lock()
// If we don't have any forwarding events, then we can exit early.
if len(s.pendingFwdingEvents) == 0 {
s.fwdEventMtx.Unlock()
return nil
}
events := make([]channeldb.ForwardingEvent, len(s.pendingFwdingEvents))
copy(events[:], s.pendingFwdingEvents[:])
// With the copy obtained, we can now clear out the header pointer of
// the current slice. This way, we can re-use the underlying storage
// allocated for the slice.
s.pendingFwdingEvents = s.pendingFwdingEvents[:0]
s.fwdEventMtx.Unlock()
// Finally, we'll write out the copied events to the persistent
// forwarding log.
return s.cfg.FwdingLog.AddForwardingEvents(events)
}
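// The flush above uses the common Go copy-then-truncate idiom to drain a
// buffer while retaining its allocated capacity; a minimal sketch (variable
// names are placeholders, not from the original source):
//
//	events := make([]channeldb.ForwardingEvent, len(pending))
//	copy(events, pending)
//	pending = pending[:0] // length zero, capacity kept for reuse
//
// The snapshot is taken while fwdEventMtx is held, so concurrent appends to
// pendingFwdingEvents cannot race with the flush.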
// BestHeight returns the best height known to the switch.
func (s *Switch) BestHeight() uint32 {
return atomic.LoadUint32(&s.bestHeight)
}
// dustExceedsFeeThreshold takes in a ChannelLink, HTLC amount, and a boolean
// to determine whether the default fee threshold has been exceeded. This
// heuristic takes into account the trimmed-to-dust mechanism. The sum of the
// commitment's dust with the mailbox's dust with the amount is checked against
// the fee exposure threshold. If incoming is true, then the amount is not
// included in the sum as it was already included in the commitment's dust. A
// boolean is returned telling the caller whether the HTLC should be failed
// back.
func (s *Switch) dustExceedsFeeThreshold(link ChannelLink,
amount lnwire.MilliSatoshi, incoming bool) bool {
// Retrieve the link's current commitment feerate and dustClosure.
feeRate := link.getFeeRate()
isDust := link.getDustClosure()
// Evaluate if the HTLC is dust on either side's commitment.
isLocalDust := isDust(
feeRate, incoming, lntypes.Local, amount.ToSatoshis(),
)
isRemoteDust := isDust(
feeRate, incoming, lntypes.Remote, amount.ToSatoshis(),
)
if !(isLocalDust || isRemoteDust) {
// If the HTLC is not dust on either commitment, it's fine to
// forward.
return false
}
// Fetch the dust sums currently in the mailbox for this link.
cid := link.ChanID()
sid := link.ShortChanID()
mailbox := s.mailOrchestrator.GetOrCreateMailBox(cid, sid)
localMailDust, remoteMailDust := mailbox.DustPackets()
// If the htlc is dust on the local commitment, we'll obtain the dust
// sum for it.
if isLocalDust {
localSum := link.getDustSum(
lntypes.Local, fn.None[chainfee.SatPerKWeight](),
)
localSum += localMailDust
// Optionally include the HTLC amount only for outgoing
// HTLCs.
if !incoming {
localSum += amount
}
// Finally check against the defined fee threshold.
if localSum > s.cfg.MaxFeeExposure {
return true
}
}
// Also check if the htlc is dust on the remote commitment, if we've
// reached this point.
if isRemoteDust {
remoteSum := link.getDustSum(
lntypes.Remote, fn.None[chainfee.SatPerKWeight](),
)
remoteSum += remoteMailDust
// Optionally include the HTLC amount only for outgoing
// HTLCs.
if !incoming {
remoteSum += amount
}
// Finally check against the defined fee threshold.
if remoteSum > s.cfg.MaxFeeExposure {
return true
}
}
// If we reached this point, this HTLC is fine to forward.
return false
}
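// A worked example under stated assumptions (all numbers hypothetical, in
// millisatoshi): with MaxFeeExposure = 500,000,000, a local commitment dust
// sum of 480,000,000, local mailbox dust of 15,000,000, and an outgoing HTLC
// of 10,000,000 that is dust on the local commitment:
//
//	localSum = 480,000,000 + 15,000,000 + 10,000,000 = 505,000,000
//
// Since 505,000,000 > 500,000,000, the threshold is exceeded and the HTLC is
// failed back. Had the HTLC been incoming, its amount would be omitted from
// the sum (it is already counted in the commitment dust), yielding
// 495,000,000, which passes.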
// failMailboxUpdate is passed to the mailbox orchestrator which in turn passes
// it to individual mailboxes. It allows the mailboxes to construct a
// FailureMessage when failing back HTLCs due to expiry and may include an
// alias in the ShortChannelID field. The outgoingScid is the SCID originally
// used in the onion. The mailboxScid is the SCID that the mailbox and link
// use. The mailboxScid is only used in the non-alias case, so it is always
// the confirmed SCID.
func (s *Switch) failMailboxUpdate(outgoingScid,
mailboxScid lnwire.ShortChannelID) lnwire.FailureMessage {
// Try to use the failAliasUpdate function in case this is a channel
// that uses aliases. If it returns nil, we'll fallback to the original
// pre-alias behavior.
update := s.failAliasUpdate(outgoingScid, false)
if update == nil {
// Execute the fallback behavior.
var err error
update, err = s.cfg.FetchLastChannelUpdate(mailboxScid)
if err != nil {
return &lnwire.FailTemporaryNodeFailure{}
}
}
return lnwire.NewTemporaryChannelFailure(update)
}
// failAliasUpdate prepares a ChannelUpdate for a failed incoming or outgoing
// HTLC on a channel where the option-scid-alias feature bit was negotiated. If
// the associated channel is not one of these, this function will return nil
// and the caller is expected to handle this properly. In this case, a return
// to the original non-alias behavior is expected.
func (s *Switch) failAliasUpdate(scid lnwire.ShortChannelID,
incoming bool) *lnwire.ChannelUpdate1 {
// This function does not defer the unlocking because of the database
// lookups for ChannelUpdate.
s.indexMtx.RLock()
if s.cfg.IsAlias(scid) {
// The alias SCID was used. In the incoming case this means
// the channel is zero-conf as the link sets the scid. In the
// outgoing case, the sender set the scid to use, which may be
// either the alias or the confirmed one, if it exists.
realScid, ok := s.aliasToReal[scid]
if !ok {
// The real, confirmed SCID does not exist yet. Find
// the "base" SCID that the link uses via the
// baseIndex. If we can't find it, return nil. This
// means the channel is zero-conf.
baseScid, ok := s.baseIndex[scid]
s.indexMtx.RUnlock()
if !ok {
return nil
}
update, err := s.cfg.FetchLastChannelUpdate(baseScid)
if err != nil {
return nil
}
// Replace the baseScid with the passed-in alias.
update.ShortChannelID = scid
sig, err := s.cfg.SignAliasUpdate(update)
if err != nil {
return nil
}
update.Signature, err = lnwire.NewSigFromSignature(sig)
if err != nil {
return nil
}
return update
}
s.indexMtx.RUnlock()
// Fetch the SCID via the confirmed SCID and replace it with
// the alias.
update, err := s.cfg.FetchLastChannelUpdate(realScid)
if err != nil {
return nil
}
// In the incoming case, we want to ensure that we don't leak
// the UTXO in case the channel is private. In the outgoing
// case, since the alias was used, we do the same thing.
update.ShortChannelID = scid
sig, err := s.cfg.SignAliasUpdate(update)
if err != nil {
return nil
}
update.Signature, err = lnwire.NewSigFromSignature(sig)
if err != nil {
return nil
}
return update
}
// If the confirmed SCID is not in baseIndex, this is not an
// option-scid-alias or zero-conf channel.
baseScid, ok := s.baseIndex[scid]
if !ok {
s.indexMtx.RUnlock()
return nil
}
// Fetch the link so we can get an alias to use in the ShortChannelID
// of the ChannelUpdate.
link, ok := s.forwardingIndex[baseScid]
s.indexMtx.RUnlock()
if !ok {
// This should never happen, but if it does for some reason,
// fallback to the old behavior.
return nil
}
aliases := link.getAliases()
if len(aliases) == 0 {
// This should never happen, but if it does, fallback.
return nil
}
// Fetch the ChannelUpdate via the real, confirmed SCID.
update, err := s.cfg.FetchLastChannelUpdate(scid)
if err != nil {
return nil
}
// The incoming case will replace the ShortChannelID in the retrieved
// ChannelUpdate with the alias to ensure no privacy leak occurs. This
// would happen if a private non-zero-conf option-scid-alias
// feature-bit channel leaked its UTXO here rather than supplying an
// alias. In the outgoing case, the confirmed SCID was actually used
// for forwarding in the onion, so no replacement is necessary as the
// sender knows the scid.
if incoming {
// We will replace and sign the update with the first alias.
// Since this happens on the incoming side, it's not actually
// possible to know what the sender used in the onion.
update.ShortChannelID = aliases[0]
sig, err := s.cfg.SignAliasUpdate(update)
if err != nil {
return nil
}
update.Signature, err = lnwire.NewSigFromSignature(sig)
if err != nil {
return nil
}
}
return update
}
// AddAliasForLink instructs the Switch to update its in-memory maps to reflect
// that a link has a new alias.
func (s *Switch) AddAliasForLink(chanID lnwire.ChannelID,
alias lnwire.ShortChannelID) error {
// Fetch the link so that we can update the underlying channel's set of
// aliases.
s.indexMtx.RLock()
link, err := s.getLink(chanID)
s.indexMtx.RUnlock()
if err != nil {
return err
}
// If the link is a channel where the option-scid-alias feature bit was
// not negotiated, we'll return an error.
if !link.negotiatedAliasFeature() {
return fmt.Errorf("attempted to update non-alias channel")
}
linkScid := link.ShortChanID()
// We'll update the maps so the Switch includes this alias in its
// forwarding decisions.
if link.isZeroConf() {
if link.zeroConfConfirmed() {
// If the channel has confirmed on-chain, we'll
// add this alias to the aliasToReal map.
confirmedScid := link.confirmedScid()
s.aliasToReal[alias] = confirmedScid
}
// Add this alias to the baseIndex mapping.
s.baseIndex[alias] = linkScid
} else if link.negotiatedAliasFeature() {
// The channel is confirmed, so we'll populate the aliasToReal
// and baseIndex maps.
s.aliasToReal[alias] = linkScid
s.baseIndex[alias] = linkScid
}
return nil
}
// handlePacketAdd handles forwarding an Add packet.
func (s *Switch) handlePacketAdd(packet *htlcPacket,
htlc *lnwire.UpdateAddHTLC) error {
// Check if the node is set to reject all onward HTLCs and also make
// sure that HTLC is not from the source node.
if s.cfg.RejectHTLC {
failure := NewDetailedLinkError(
&lnwire.FailChannelDisabled{},
OutgoingFailureForwardsDisabled,
)
return s.failAddPacket(packet, failure)
}
// Before we attempt to find a non-strict forwarding path for this
// htlc, check whether the htlc is being routed over the same incoming
// and outgoing channel. If our node does not allow forwards of this
// nature, we fail the htlc early. This check is in place to disallow
// inefficiently routed htlcs from locking up our balance. With
// channels where the option-scid-alias feature was negotiated, we also
// have to be sure that the IDs aren't the same since one or both could
// be an alias.
linkErr := s.checkCircularForward(
packet.incomingChanID, packet.outgoingChanID,
s.cfg.AllowCircularRoute, htlc.PaymentHash,
)
if linkErr != nil {
return s.failAddPacket(packet, linkErr)
}
s.indexMtx.RLock()
targetLink, err := s.getLinkByMapping(packet)
if err != nil {
s.indexMtx.RUnlock()
log.Debugf("unable to find link with "+
"destination %v", packet.outgoingChanID)
// If the packet was forwarded from another channel link, then we
// should notify this link that some error occurred.
linkError := NewLinkError(
&lnwire.FailUnknownNextPeer{},
)
return s.failAddPacket(packet, linkError)
}
targetPeerKey := targetLink.PeerPubKey()
interfaceLinks, _ := s.getLinks(targetPeerKey)
s.indexMtx.RUnlock()
// We'll keep track of any HTLC failures during the link selection
// process. This way we can return the error for the precise link that the
// sender selected, while optimistically trying all links to utilize
// our available bandwidth.
linkErrs := make(map[lnwire.ShortChannelID]*LinkError)
// Find all destination channel links with appropriate bandwidth.
var destinations []ChannelLink
for _, link := range interfaceLinks {
var failure *LinkError
// We'll skip any links that aren't yet eligible for
// forwarding.
if !link.EligibleToForward() {
failure = NewDetailedLinkError(
&lnwire.FailUnknownNextPeer{},
OutgoingFailureLinkNotEligible,
)
} else {
// We'll ensure that the HTLC satisfies the current
// forwarding conditions of this target link.
currentHeight := atomic.LoadUint32(&s.bestHeight)
failure = link.CheckHtlcForward(
htlc.PaymentHash, packet.incomingAmount,
packet.amount, packet.incomingTimeout,
packet.outgoingTimeout, packet.inboundFee,
currentHeight, packet.originalOutgoingChanID,
htlc.CustomRecords,
)
}
// If this link can forward the htlc, add it to the set of
// destinations.
if failure == nil {
destinations = append(destinations, link)
continue
}
linkErrs[link.ShortChanID()] = failure
}
// If we had a forwarding failure due to the HTLC not satisfying the
// current policy, then we'll send back an error, but ensure we send
// back the error sourced at the *target* link.
if len(destinations) == 0 {
// At this point, some or all of the links rejected the HTLC so
// we couldn't forward it. So we'll try to look up the error
// that came from the source.
linkErr, ok := linkErrs[packet.outgoingChanID]
if !ok {
// If we can't find the error of the source, then we'll
// return an unknown next peer, though this should
// never happen.
linkErr = NewLinkError(
&lnwire.FailUnknownNextPeer{},
)
log.Warnf("unable to find err source for "+
"outgoing_link=%v, errors=%v",
packet.outgoingChanID,
lnutils.SpewLogClosure(linkErrs))
}
log.Tracef("incoming HTLC(%x) violated "+
"target outgoing link (id=%v) policy: %v",
htlc.PaymentHash[:], packet.outgoingChanID,
linkErr)
return s.failAddPacket(packet, linkErr)
}
// Choose a random link out of the set of links that can forward this
// htlc. The reason for randomization is to evenly distribute the htlc
// load without making assumptions about what the best channel is.
//nolint:gosec
destination := destinations[rand.Intn(len(destinations))]
// Retrieve the incoming link by its ShortChannelID. Note that the
// incomingChanID is never set to hop.Source here.
s.indexMtx.RLock()
incomingLink, err := s.getLinkByShortID(packet.incomingChanID)
s.indexMtx.RUnlock()
if err != nil {
// If we couldn't find the incoming link, we can't evaluate the
// incoming link's exposure to dust, so we just fail the HTLC back.
linkErr := NewLinkError(
&lnwire.FailTemporaryChannelFailure{},
)
return s.failAddPacket(packet, linkErr)
}
// Evaluate whether this HTLC would increase our fee exposure over the
// threshold on the incoming link. If it does, fail it backwards.
if s.dustExceedsFeeThreshold(
incomingLink, packet.incomingAmount, true,
) {
// The incoming dust exceeds the threshold, so we fail the add
// back.
linkErr := NewLinkError(
&lnwire.FailTemporaryChannelFailure{},
)
return s.failAddPacket(packet, linkErr)
}
// Also evaluate whether this HTLC would increase our fee exposure over
// the threshold on the destination link. If it does, fail it back.
if s.dustExceedsFeeThreshold(
destination, packet.amount, false,
) {
// The outgoing dust exceeds the threshold, so we fail the add
// back.
linkErr := NewLinkError(
&lnwire.FailTemporaryChannelFailure{},
)
return s.failAddPacket(packet, linkErr)
}
// Send the packet to the destination channel link which manages the
// channel.
packet.outgoingChanID = destination.ShortChanID()
return destination.handleSwitchPacket(packet)
}
// handlePacketSettle handles forwarding a settle packet.
func (s *Switch) handlePacketSettle(packet *htlcPacket) error {
// If the source of this packet has not been set, use the circuit map
// to lookup the origin.
circuit, err := s.closeCircuit(packet)
// If the circuit is in the process of closing, we will return nil as
// there's already another packet handler processing it.
if errors.Is(err, ErrCircuitClosing) {
log.Debugf("Circuit is closing for packet=%v", packet)
return nil
}
// Exit early if there's another error.
if err != nil {
return err
}
// closeCircuit returns a nil circuit when a settle packet returns an
// ErrUnknownCircuit error upon the inner call to CloseCircuit.
//
// NOTE: We can only get a nil circuit when it has already been deleted
// and an `UpdateFulfillHTLC` is received; the `RevokeAndAck` that follows
// invokes `processRemoteSettleFails` in its link.
if circuit == nil {
log.Debugf("Circuit already closed for packet=%v", packet)
return nil
}
localHTLC := packet.incomingChanID == hop.Source
// If this is a locally initiated HTLC, we need to handle the packet by
// storing the network result.
//
// A blank IncomingChanID in a circuit indicates that it is a pending
// user-initiated payment.
//
// NOTE: `closeCircuit` modifies the state of `packet`.
if localHTLC {
// TODO(yy): remove the goroutine and send back the error here.
s.wg.Add(1)
go s.handleLocalResponse(packet)
// If this is a locally initiated HTLC, there's no need to
// forward it so we exit.
return nil
}
// If this is an HTLC settle, and it wasn't from a locally initiated
// HTLC, then we'll log a forwarding event so we can flush it to disk
// later.
if circuit.Outgoing != nil {
log.Infof("Forwarded HTLC(%x) of %v (fee: %v) "+
"from IncomingChanID(%v) to OutgoingChanID(%v)",
circuit.PaymentHash[:], circuit.OutgoingAmount,
circuit.IncomingAmount-circuit.OutgoingAmount,
circuit.Incoming.ChanID, circuit.Outgoing.ChanID)
s.fwdEventMtx.Lock()
s.pendingFwdingEvents = append(
s.pendingFwdingEvents,
channeldb.ForwardingEvent{
Timestamp: time.Now(),
IncomingChanID: circuit.Incoming.ChanID,
OutgoingChanID: circuit.Outgoing.ChanID,
AmtIn: circuit.IncomingAmount,
AmtOut: circuit.OutgoingAmount,
},
)
s.fwdEventMtx.Unlock()
}
// Deliver this packet.
return s.mailOrchestrator.Deliver(packet.incomingChanID, packet)
}
// handlePacketFail handles forwarding a fail packet.
func (s *Switch) handlePacketFail(packet *htlcPacket,
htlc *lnwire.UpdateFailHTLC) error {
// If the source of this packet has not been set, use the circuit map
// to lookup the origin.
circuit, err := s.closeCircuit(packet)
if err != nil {
return err
}
// If this is a locally initiated HTLC, we need to handle the packet by
// storing the network result.
//
// A blank IncomingChanID in a circuit indicates that it is a pending
// user-initiated payment.
//
// NOTE: `closeCircuit` modifies the state of `packet`.
if packet.incomingChanID == hop.Source {
// TODO(yy): remove the goroutine and send back the error here.
s.wg.Add(1)
go s.handleLocalResponse(packet)
// If this is a locally initiated HTLC, there's no need to
// forward it so we exit.
return nil
}
// Exit early if hasSource is true. This flag is only set via
// mailbox's `FailAdd`. This method has two callsites,
// - the packet has timed out after `MailboxDeliveryTimeout`, defaults
// to 1 min.
// - the HTLC fails the validation in `channel.AddHTLC`.
// In either case, the `Reason` field is populated. Thus there's no
// need to proceed and extract the failure reason below.
if packet.hasSource {
// Deliver this packet.
return s.mailOrchestrator.Deliver(packet.incomingChanID, packet)
}
// HTLC resolutions and messages restored from disk don't have the
// obfuscator set from the original htlc add packet - set it here for
// use in blinded errors.
packet.obfuscator = circuit.ErrorEncrypter
switch {
// No message to encrypt, locally sourced payment.
case circuit.ErrorEncrypter == nil:
// TODO(yy) further check this case as we shouldn't end up here
// as `isLocal` is already false.
// If this is a resolution message, then we'll need to encrypt it as
// it's actually internally sourced.
case packet.isResolution:
var err error
// TODO(roasbeef): don't need to pass actually?
failure := &lnwire.FailPermanentChannelFailure{}
htlc.Reason, err = circuit.ErrorEncrypter.EncryptFirstHop(
failure,
)
if err != nil {
err = fmt.Errorf("unable to obfuscate error: %w", err)
log.Error(err)
}
// Alternatively, if the remote party sends us an
// UpdateFailMalformedHTLC, then we'll need to convert this into a
// properly formatted onion error as there's no HMAC currently.
case packet.convertedError:
log.Infof("Converting malformed HTLC error for circuit for "+
"Circuit(%x: (%s, %d) <-> (%s, %d))",
packet.circuit.PaymentHash,
packet.incomingChanID, packet.incomingHTLCID,
packet.outgoingChanID, packet.outgoingHTLCID)
htlc.Reason = circuit.ErrorEncrypter.EncryptMalformedError(
htlc.Reason,
)
default:
// Otherwise, it's a forwarded error, so we'll perform a
// wrapper encryption as normal.
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
htlc.Reason,
)
}
// Deliver this packet.
return s.mailOrchestrator.Deliver(packet.incomingChanID, packet)
}
package htlcswitch
import (
"bytes"
"context"
crand "crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"net"
"os"
"runtime"
"runtime/pprof"
"sync/atomic"
"testing"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/go-errors/errors"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lntest/channels"
"github.com/lightningnetwork/lnd/lntest/wait"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/ticker"
"github.com/stretchr/testify/require"
)
// maxInflightHtlcs specifies the max number of inflight HTLCs. This number is
// chosen to be smaller than the default 483 so the test can run faster.
const maxInflightHtlcs = 50
var (
alicePrivKey = []byte("alice priv key")
bobPrivKey = []byte("bob priv key")
carolPrivKey = []byte("carol priv key")
testRBytes, _ = hex.DecodeString("8ce2bc69281ce27da07e6683571319d18e949ddfa2965fb6caa1bf0314f882d7")
testSBytes, _ = hex.DecodeString("299105481d63e0f4bc2a88121167221b6700d72a0ead154c03be696a292d24ae")
testRScalar = new(btcec.ModNScalar)
testSScalar = new(btcec.ModNScalar)
_ = testRScalar.SetByteSlice(testRBytes)
_ = testSScalar.SetByteSlice(testSBytes)
testSig = ecdsa.NewSignature(testRScalar, testSScalar)
wireSig, _ = lnwire.NewSigFromSignature(testSig)
testBatchTimeout = 50 * time.Millisecond
)
var idSeqNum uint64
// genID generates a unique tuple to identify a test channel.
func genID() (lnwire.ChannelID, lnwire.ShortChannelID) {
id := atomic.AddUint64(&idSeqNum, 1)
var scratch [8]byte
binary.BigEndian.PutUint64(scratch[:], id)
hash1, _ := chainhash.NewHash(bytes.Repeat(scratch[:], 4))
chanPoint1 := wire.NewOutPoint(hash1, uint32(id))
chanID1 := lnwire.NewChanIDFromOutPoint(*chanPoint1)
aliceChanID := lnwire.NewShortChanIDFromInt(id)
return chanID1, aliceChanID
}
// genIDs generates ids for two test channels.
func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID,
lnwire.ShortChannelID) {
chanID1, aliceChanID := genID()
chanID2, bobChanID := genID()
return chanID1, chanID2, aliceChanID, bobChanID
}
// mockGetChanUpdateMessage is a helper function which returns a topology
// update for the channel.
func mockGetChanUpdateMessage(_ lnwire.ShortChannelID) (*lnwire.ChannelUpdate1,
error) {
return &lnwire.ChannelUpdate1{
Signature: wireSig,
}, nil
}
// generateRandomBytes returns securely generated random bytes.
// It will return an error if the system's secure random
// number generator fails to function correctly, in which
// case the caller should not continue.
func generateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
// TODO(roasbeef): should use counter in tests (atomic) rather than
// this
_, err := crand.Read(b)
// Note that err == nil only if we read len(b) bytes.
if err != nil {
return nil, err
}
return b, nil
}
type testLightningChannel struct {
channel *lnwallet.LightningChannel
restore func() (*lnwallet.LightningChannel, error)
}
// createTestChannel creates the channel and returns our and remote channels
// representations.
//
// TODO(roasbeef): need to factor out, similar func re-used in many parts of codebase
func createTestChannel(t *testing.T, alicePrivKey, bobPrivKey []byte,
aliceAmount, bobAmount, aliceReserve, bobReserve btcutil.Amount,
chanID lnwire.ShortChannelID) (*testLightningChannel,
*testLightningChannel, error) {
aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(alicePrivKey)
bobKeyPriv, bobKeyPub := btcec.PrivKeyFromBytes(bobPrivKey)
channelCapacity := aliceAmount + bobAmount
csvTimeoutAlice := uint32(5)
csvTimeoutBob := uint32(4)
isAliceInitiator := true
aliceBounds := channeldb.ChannelStateBounds{
MaxPendingAmount: lnwire.NewMSatFromSatoshis(
channelCapacity),
ChanReserve: aliceReserve,
MinHTLC: 0,
MaxAcceptedHtlcs: maxInflightHtlcs,
}
aliceCommitParams := channeldb.CommitmentParams{
DustLimit: btcutil.Amount(200),
CsvDelay: uint16(csvTimeoutAlice),
}
bobBounds := channeldb.ChannelStateBounds{
MaxPendingAmount: lnwire.NewMSatFromSatoshis(
channelCapacity),
ChanReserve: bobReserve,
MinHTLC: 0,
MaxAcceptedHtlcs: maxInflightHtlcs,
}
bobCommitParams := channeldb.CommitmentParams{
DustLimit: btcutil.Amount(800),
CsvDelay: uint16(csvTimeoutBob),
}
var hash [sha256.Size]byte
randomSeed, err := generateRandomBytes(sha256.Size)
if err != nil {
return nil, nil, err
}
copy(hash[:], randomSeed)
prevOut := &wire.OutPoint{
Hash: chainhash.Hash(hash),
Index: 0,
}
fundingTxIn := wire.NewTxIn(prevOut, nil, nil)
aliceCfg := channeldb.ChannelConfig{
ChannelStateBounds: aliceBounds,
CommitmentParams: aliceCommitParams,
MultiSigKey: keychain.KeyDescriptor{
PubKey: aliceKeyPub,
},
RevocationBasePoint: keychain.KeyDescriptor{
PubKey: aliceKeyPub,
},
PaymentBasePoint: keychain.KeyDescriptor{
PubKey: aliceKeyPub,
},
DelayBasePoint: keychain.KeyDescriptor{
PubKey: aliceKeyPub,
},
HtlcBasePoint: keychain.KeyDescriptor{
PubKey: aliceKeyPub,
},
}
bobCfg := channeldb.ChannelConfig{
ChannelStateBounds: bobBounds,
CommitmentParams: bobCommitParams,
MultiSigKey: keychain.KeyDescriptor{
PubKey: bobKeyPub,
},
RevocationBasePoint: keychain.KeyDescriptor{
PubKey: bobKeyPub,
},
PaymentBasePoint: keychain.KeyDescriptor{
PubKey: bobKeyPub,
},
DelayBasePoint: keychain.KeyDescriptor{
PubKey: bobKeyPub,
},
HtlcBasePoint: keychain.KeyDescriptor{
PubKey: bobKeyPub,
},
}
bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize())
if err != nil {
return nil, nil, err
}
bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot)
bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
if err != nil {
return nil, nil, err
}
bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize())
if err != nil {
return nil, nil, err
}
alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot)
aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
if err != nil {
return nil, nil, err
}
aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
aliceCommitTx, bobCommitTx, err := lnwallet.CreateCommitmentTxns(
aliceAmount, bobAmount, &aliceCfg, &bobCfg, aliceCommitPoint,
bobCommitPoint, *fundingTxIn, channeldb.SingleFunderTweaklessBit,
isAliceInitiator, 0,
)
if err != nil {
return nil, nil, err
}
dbAlice := channeldb.OpenForTesting(t, t.TempDir())
dbBob := channeldb.OpenForTesting(t, t.TempDir())
estimator := chainfee.NewStaticEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil {
return nil, nil, err
}
commitFee := feePerKw.FeeForWeight(724)
const broadcastHeight = 1
bobAddr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
aliceAddr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18556,
}
aliceCommit := channeldb.ChannelCommitment{
CommitHeight: 0,
LocalBalance: lnwire.NewMSatFromSatoshis(aliceAmount - commitFee),
RemoteBalance: lnwire.NewMSatFromSatoshis(bobAmount),
CommitFee: commitFee,
FeePerKw: btcutil.Amount(feePerKw),
CommitTx: aliceCommitTx,
CommitSig: bytes.Repeat([]byte{1}, 71),
}
bobCommit := channeldb.ChannelCommitment{
CommitHeight: 0,
LocalBalance: lnwire.NewMSatFromSatoshis(bobAmount),
RemoteBalance: lnwire.NewMSatFromSatoshis(aliceAmount - commitFee),
CommitFee: commitFee,
FeePerKw: btcutil.Amount(feePerKw),
CommitTx: bobCommitTx,
CommitSig: bytes.Repeat([]byte{1}, 71),
}
aliceChannelState := &channeldb.OpenChannel{
LocalChanCfg: aliceCfg,
RemoteChanCfg: bobCfg,
IdentityPub: aliceKeyPub,
FundingOutpoint: *prevOut,
ChanType: channeldb.SingleFunderTweaklessBit,
IsInitiator: isAliceInitiator,
Capacity: channelCapacity,
RemoteCurrentRevocation: bobCommitPoint,
RevocationProducer: alicePreimageProducer,
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: aliceCommit,
RemoteCommitment: aliceCommit,
ShortChannelID: chanID,
Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(chanID),
FundingTxn: channels.TestFundingTx,
}
bobChannelState := &channeldb.OpenChannel{
LocalChanCfg: bobCfg,
RemoteChanCfg: aliceCfg,
IdentityPub: bobKeyPub,
FundingOutpoint: *prevOut,
ChanType: channeldb.SingleFunderTweaklessBit,
IsInitiator: !isAliceInitiator,
Capacity: channelCapacity,
RemoteCurrentRevocation: aliceCommitPoint,
RevocationProducer: bobPreimageProducer,
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: bobCommit,
RemoteCommitment: bobCommit,
ShortChannelID: chanID,
Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(chanID),
}
if err := aliceChannelState.SyncPending(bobAddr, broadcastHeight); err != nil {
return nil, nil, err
}
if err := bobChannelState.SyncPending(aliceAddr, broadcastHeight); err != nil {
return nil, nil, err
}
aliceSigner := input.NewMockSigner(
[]*btcec.PrivateKey{aliceKeyPriv}, nil,
)
bobSigner := input.NewMockSigner(
[]*btcec.PrivateKey{bobKeyPriv}, nil,
)
alicePool := lnwallet.NewSigPool(runtime.NumCPU(), aliceSigner)
signerMock := lnwallet.NewDefaultAuxSignerMock(t)
channelAlice, err := lnwallet.NewLightningChannel(
aliceSigner, aliceChannelState, alicePool,
lnwallet.WithLeafStore(&lnwallet.MockAuxLeafStore{}),
lnwallet.WithAuxSigner(signerMock),
)
if err != nil {
return nil, nil, err
}
alicePool.Start()
bobPool := lnwallet.NewSigPool(runtime.NumCPU(), bobSigner)
channelBob, err := lnwallet.NewLightningChannel(
bobSigner, bobChannelState, bobPool,
lnwallet.WithLeafStore(&lnwallet.MockAuxLeafStore{}),
lnwallet.WithAuxSigner(signerMock),
)
if err != nil {
return nil, nil, err
}
bobPool.Start()
// Now that the channels are open, simulate the start of a session by
// having Alice and Bob extend their revocation windows to each other.
aliceNextRevoke, err := channelAlice.NextRevocationKey()
if err != nil {
return nil, nil, err
}
if err := channelBob.InitNextRevocation(aliceNextRevoke); err != nil {
return nil, nil, err
}
bobNextRevoke, err := channelBob.NextRevocationKey()
if err != nil {
return nil, nil, err
}
if err := channelAlice.InitNextRevocation(bobNextRevoke); err != nil {
return nil, nil, err
}
restoreAlice := func() (*lnwallet.LightningChannel, error) {
aliceStoredChannels, err := dbAlice.ChannelStateDB().
FetchOpenChannels(aliceKeyPub)
switch err {
case nil:
case kvdb.ErrDatabaseNotOpen:
dbAlice = channeldb.OpenForTesting(t, dbAlice.Path())
aliceStoredChannels, err = dbAlice.ChannelStateDB().
FetchOpenChannels(aliceKeyPub)
if err != nil {
return nil, errors.Errorf("unable to fetch alice "+
"channel: %v", err)
}
default:
return nil, errors.Errorf("unable to fetch alice channel: "+
"%v", err)
}
var aliceStoredChannel *channeldb.OpenChannel
for _, channel := range aliceStoredChannels {
if channel.FundingOutpoint.String() == prevOut.String() {
aliceStoredChannel = channel
break
}
}
if aliceStoredChannel == nil {
return nil, errors.New("unable to find stored alice channel")
}
newAliceChannel, err := lnwallet.NewLightningChannel(
aliceSigner, aliceStoredChannel, alicePool,
lnwallet.WithLeafStore(&lnwallet.MockAuxLeafStore{}),
lnwallet.WithAuxSigner(signerMock),
)
if err != nil {
return nil, errors.Errorf("unable to create new channel: %v",
err)
}
return newAliceChannel, nil
}
restoreBob := func() (*lnwallet.LightningChannel, error) {
bobStoredChannels, err := dbBob.ChannelStateDB().
FetchOpenChannels(bobKeyPub)
switch err {
case nil:
case kvdb.ErrDatabaseNotOpen:
dbBob = channeldb.OpenForTesting(t, dbBob.Path())
bobStoredChannels, err = dbBob.ChannelStateDB().
FetchOpenChannels(bobKeyPub)
if err != nil {
return nil, errors.Errorf("unable to fetch bob "+
"channel: %v", err)
}
default:
return nil, errors.Errorf("unable to fetch bob channel: "+
"%v", err)
}
var bobStoredChannel *channeldb.OpenChannel
for _, channel := range bobStoredChannels {
if channel.FundingOutpoint.String() == prevOut.String() {
bobStoredChannel = channel
break
}
}
if bobStoredChannel == nil {
return nil, errors.New("unable to find stored bob channel")
}
newBobChannel, err := lnwallet.NewLightningChannel(
bobSigner, bobStoredChannel, bobPool,
lnwallet.WithLeafStore(&lnwallet.MockAuxLeafStore{}),
lnwallet.WithAuxSigner(signerMock),
)
if err != nil {
return nil, errors.Errorf("unable to create new channel: %v",
err)
}
return newBobChannel, nil
}
testLightningChannelAlice := &testLightningChannel{
channel: channelAlice,
restore: restoreAlice,
}
testLightningChannelBob := &testLightningChannel{
channel: channelBob,
restore: restoreBob,
}
return testLightningChannelAlice, testLightningChannelBob, nil
}
// getChanID retrieves the channel ID from an lnwire message.
func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) {
var chanID lnwire.ChannelID
switch msg := msg.(type) {
case *lnwire.UpdateAddHTLC:
chanID = msg.ChanID
case *lnwire.UpdateFulfillHTLC:
chanID = msg.ChanID
case *lnwire.UpdateFailHTLC:
chanID = msg.ChanID
case *lnwire.RevokeAndAck:
chanID = msg.ChanID
case *lnwire.CommitSig:
chanID = msg.ChanID
case *lnwire.ChannelReestablish:
chanID = msg.ChanID
case *lnwire.ChannelReady:
chanID = msg.ChanID
case *lnwire.UpdateFee:
chanID = msg.ChanID
default:
return chanID, fmt.Errorf("unknown type: %T", msg)
}
return chanID, nil
}
// generatePaymentWithPreimage generates the htlc add request given a path
// blob and an invoice which should be added by the destination peer.
func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi,
timelock uint32, blob [lnwire.OnionPacketSize]byte,
preimage *lntypes.Preimage, rhash, payAddr [32]byte) (
*invoices.Invoice, *lnwire.UpdateAddHTLC, uint64, error) {
// Create the db invoice. Normally the payment request needs to be set,
// because it is decoded in InvoiceRegistry to obtain the cltv expiry.
// But because the mock registry used in tests is mocking the decode
// step and always returning the value of testInvoiceCltvExpiry, we
// don't need to bother here with creating and signing a payment
// request.
invoice := &invoices.Invoice{
CreationDate: time.Now(),
Terms: invoices.ContractTerm{
FinalCltvDelta: testInvoiceCltvExpiry,
Value: invoiceAmt,
PaymentPreimage: preimage,
PaymentAddr: payAddr,
Features: lnwire.NewFeatureVector(
nil, lnwire.Features,
),
},
HodlInvoice: preimage == nil,
}
htlc := &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: htlcAmt,
Expiry: timelock,
OnionBlob: blob,
}
pid, err := generateRandomBytes(8)
if err != nil {
return nil, nil, 0, err
}
paymentID := binary.BigEndian.Uint64(pid)
return invoice, htlc, paymentID, nil
}
// generatePayment generates the htlc add request given a path blob and an
// invoice which should be added by the destination peer.
func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32,
blob [lnwire.OnionPacketSize]byte) (*invoices.Invoice,
*lnwire.UpdateAddHTLC, uint64, error) {
var preimage lntypes.Preimage
r, err := generateRandomBytes(sha256.Size)
if err != nil {
return nil, nil, 0, err
}
copy(preimage[:], r)
rhash := sha256.Sum256(preimage[:])
var payAddr [sha256.Size]byte
r, err = generateRandomBytes(sha256.Size)
if err != nil {
return nil, nil, 0, err
}
copy(payAddr[:], r)
return generatePaymentWithPreimage(
invoiceAmt, htlcAmt, timelock, blob, &preimage, rhash, payAddr,
)
}
// generateRoute generates the path blob from the given array of hop
// payloads.
func generateRoute(hops ...*hop.Payload) (
[lnwire.OnionPacketSize]byte, error) {
var blob [lnwire.OnionPacketSize]byte
if len(hops) == 0 {
return blob, errors.New("empty path")
}
iterator := newMockHopIterator(hops...)
w := bytes.NewBuffer(blob[0:0])
if err := iterator.EncodeNextHop(w); err != nil {
return blob, err
}
return blob, nil
}
// threeHopNetwork is used for managing the created cluster of 3 hops.
type threeHopNetwork struct {
aliceServer *mockServer
aliceChannelLink *channelLink
aliceOnionDecoder *mockIteratorDecoder
bobServer *mockServer
firstBobChannelLink *channelLink
secondBobChannelLink *channelLink
bobOnionDecoder *mockIteratorDecoder
carolServer *mockServer
carolChannelLink *channelLink
carolOnionDecoder *mockIteratorDecoder
hopNetwork
}
// generateHops creates the per hop payload, the total amount to be sent, and
// also the time lock value needed to route an HTLC with the target amount over
// the specified path.
func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32,
path ...*channelLink) (lnwire.MilliSatoshi, uint32, []*hop.Payload) {
totalTimelock := startingHeight
runningAmt := payAmt
hops := make([]*hop.Payload, len(path))
for i := len(path) - 1; i >= 0; i-- {
// If this is the last hop, then the next hop is the special
// "exit node". Otherwise, we look to the "prior" hop.
nextHop := hop.Exit
if i != len(path)-1 {
nextHop = path[i+1].channel.ShortChanID()
}
var timeLock uint32
// If this is the last hop, then the time lock will be their
// specified delta policy plus our starting height.
if i == len(path)-1 {
totalTimelock += testInvoiceCltvExpiry
timeLock = totalTimelock
} else {
// Otherwise, the outgoing time lock should be the
// incoming timelock minus their specified delta.
delta := path[i+1].cfg.FwrdingPolicy.TimeLockDelta
totalTimelock += delta
timeLock = totalTimelock - delta
}
// Finally, we'll need to calculate the amount to forward. For
// the last hop, it's just the payment amount.
amount := payAmt
if i != len(path)-1 {
prevHop := hops[i+1]
prevAmount := prevHop.ForwardingInfo().AmountToForward
fee := ExpectedFee(path[i].cfg.FwrdingPolicy, prevAmount)
runningAmt += fee
// For a node to forward an HTLC, the
// following inequality must hold true:
// * amt_in - fee >= amt_to_forward
amount = runningAmt - fee
}
var nextHopBytes [8]byte
binary.BigEndian.PutUint64(nextHopBytes[:], nextHop.ToUint64())
hops[i] = hop.NewLegacyPayload(&sphinx.HopData{
Realm: [1]byte{}, // hop.BitcoinNetwork
NextAddress: nextHopBytes,
ForwardAmount: uint64(amount),
OutgoingCltv: timeLock,
})
}
return runningAmt, totalTimelock, hops
}
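// As a worked example of the accumulation above (a sketch with symbolic
// values): for a two-link path path[0] -> path[1], payment amount p,
// starting height h and invoice expiry e, the loop runs backwards:
//
// i=1 (exit): hops[1].amount = p, hops[1].timeLock = h + e
// i=0: delta = path[1].TimeLockDelta, fee = ExpectedFee(path[0].policy, p),
// hops[0].amount = p, hops[0].timeLock = h + e
//
// The function then returns runningAmt = p + fee (what the sender places in
// the first HTLC) and totalTimelock = h + e + delta, so the forwarding node
// satisfies amt_in - fee >= amt_to_forward.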
type paymentResponse struct {
rhash lntypes.Hash
err chan error
}
// Wait blocks until the payment result arrives or the timeout expires, and
// returns the payment hash along with any error.
func (r *paymentResponse) Wait(d time.Duration) (lntypes.Hash, error) {
return r.rhash, waitForPaymentResult(r.err, d)
}
// waitForPaymentResult waits for either an error to be received on c or a
// timeout.
func waitForPaymentResult(c chan error, d time.Duration) error {
select {
case err := <-c:
close(c)
return err
case <-time.After(d):
return errors.New("htlc was not settled in time")
}
}
// waitForPayFuncResult executes the given function and waits for a result with
// a timeout.
func waitForPayFuncResult(payFunc func() error, d time.Duration) error {
errChan := make(chan error)
go func() {
errChan <- payFunc()
}()
return waitForPaymentResult(errChan, d)
}
// makePayment takes the destination node and amount as input, sends the
// payment, and returns the error channel to wait for an error to be received
// as well as the invoice in order to check its status after the payment
// finishes.
//
// With this function you can send payments:
// * from Alice to Bob
// * from Alice to Carol through Bob
// * from Alice to some other peer through Bob
func makePayment(sendingPeer, receivingPeer lnpeer.Peer,
firstHop lnwire.ShortChannelID, hops []*hop.Payload,
invoiceAmt, htlcAmt lnwire.MilliSatoshi,
timelock uint32) *paymentResponse {
paymentErr := make(chan error, 1)
var rhash lntypes.Hash
invoice, payFunc, err := preparePayment(sendingPeer, receivingPeer,
firstHop, hops, invoiceAmt, htlcAmt, timelock,
)
if err != nil {
paymentErr <- err
return &paymentResponse{
rhash: rhash,
err: paymentErr,
}
}
rhash = invoice.Terms.PaymentPreimage.Hash()
// Send payment and expose err channel.
go func() {
paymentErr <- payFunc()
}()
return &paymentResponse{
rhash: rhash,
err: paymentErr,
}
}
// preparePayment creates an invoice at the receivingPeer and returns a function
// that, when called, launches the payment from the sendingPeer.
func preparePayment(sendingPeer, receivingPeer lnpeer.Peer,
firstHop lnwire.ShortChannelID, hops []*hop.Payload,
invoiceAmt, htlcAmt lnwire.MilliSatoshi,
timelock uint32) (*invoices.Invoice, func() error, error) {
sender := sendingPeer.(*mockServer)
receiver := receivingPeer.(*mockServer)
// Generate the route, convert it to a blob, and return the next
// destination for the htlc add request.
blob, err := generateRoute(hops...)
if err != nil {
return nil, nil, err
}
// Generate payment: invoice and htlc.
invoice, htlc, pid, err := generatePayment(
invoiceAmt, htlcAmt, timelock, blob,
)
if err != nil {
return nil, nil, err
}
// Check who is last in the route and add the invoice to the
// receiver's registry.
hash := invoice.Terms.PaymentPreimage.Hash()
if err := receiver.registry.AddInvoice(
context.Background(), *invoice, hash,
); err != nil {
return nil, nil, err
}
// Send payment and expose err channel.
return invoice, func() error {
err := sender.htlcSwitch.SendHTLC(
firstHop, pid, htlc,
)
if err != nil {
return err
}
resultChan, err := sender.htlcSwitch.GetAttemptResult(
pid, hash, newMockDeobfuscator(),
)
if err != nil {
return err
}
result, ok := <-resultChan
if !ok {
return fmt.Errorf("shutting down")
}
if result.Error != nil {
return result.Error
}
return nil
}, nil
}
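// An illustrative usage sketch (the amounts are arbitrary and the starting
// height is assumed to be supplied by the surrounding test): the hop
// payloads are generated over the bob and carol links, and the payment is
// launched from alice towards carol.
//
// amt := lnwire.NewMSatFromSatoshis(10_000)
// htlcAmt, totalTimelock, hops := generateHops(
// amt, startingHeight, n.firstBobChannelLink, n.carolChannelLink,
// )
// resp := makePayment(
// n.aliceServer, n.carolServer,
// n.firstBobChannelLink.ShortChanID(), hops, amt, htlcAmt,
// totalTimelock,
// )
// _, err := resp.Wait(30 * time.Second)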
// start starts the alice, bob and carol servers of the three hop network.
func (n *threeHopNetwork) start() error {
if err := n.aliceServer.Start(); err != nil {
return err
}
if err := n.bobServer.Start(); err != nil {
return err
}
if err := n.carolServer.Start(); err != nil {
return err
}
return waitLinksEligible(map[string]*channelLink{
"alice": n.aliceChannelLink,
"bob first": n.firstBobChannelLink,
"bob second": n.secondBobChannelLink,
"carol": n.carolChannelLink,
})
}
// stop stops the nodes and cleans up their databases.
func (n *threeHopNetwork) stop() {
done := make(chan struct{})
go func() {
n.aliceServer.Stop()
done <- struct{}{}
}()
go func() {
n.bobServer.Stop()
done <- struct{}{}
}()
go func() {
n.carolServer.Stop()
done <- struct{}{}
}()
for i := 0; i < 3; i++ {
<-done
}
}
type clusterChannels struct {
aliceToBob *lnwallet.LightningChannel
bobToAlice *lnwallet.LightningChannel
bobToCarol *lnwallet.LightningChannel
carolToBob *lnwallet.LightningChannel
}
// createClusterChannels creates the lightning channels which are needed for
// the network cluster to be initialized.
func createClusterChannels(t *testing.T, aliceToBob, bobToCarol btcutil.Amount) (
*clusterChannels, func() (*clusterChannels, error), error) {
_, _, firstChanID, secondChanID := genIDs()
// Create lightning channels between Alice<->Bob and Bob<->Carol
aliceChannel, firstBobChannel, err := createTestChannel(t, alicePrivKey,
bobPrivKey, aliceToBob, aliceToBob, 0, 0, firstChanID,
)
if err != nil {
return nil, nil, errors.Errorf("unable to create "+
"alice<->bob channel: %v", err)
}
secondBobChannel, carolChannel, err := createTestChannel(t, bobPrivKey,
carolPrivKey, bobToCarol, bobToCarol, 0, 0, secondChanID,
)
if err != nil {
return nil, nil, errors.Errorf("unable to create "+
"bob<->carol channel: %v", err)
}
restoreFromDb := func() (*clusterChannels, error) {
a2b, err := aliceChannel.restore()
if err != nil {
return nil, err
}
b2a, err := firstBobChannel.restore()
if err != nil {
return nil, err
}
b2c, err := secondBobChannel.restore()
if err != nil {
return nil, err
}
c2b, err := carolChannel.restore()
if err != nil {
return nil, err
}
return &clusterChannels{
aliceToBob: a2b,
bobToAlice: b2a,
bobToCarol: b2c,
carolToBob: c2b,
}, nil
}
return &clusterChannels{
aliceToBob: aliceChannel.channel,
bobToAlice: firstBobChannel.channel,
bobToCarol: secondBobChannel.channel,
carolToBob: carolChannel.channel,
}, restoreFromDb, nil
}
// newThreeHopNetwork function creates the following topology and returns the
// control object to manage this cluster:
//
//	alice                     bob                        carol
//	server - <-connection-> - server - - <-connection-> - - - server
//	  |                         |                            |
//	alice htlc              bob htlc                    carol htlc
//	switch                   switch    \                  switch
//	  |                        |        \                    |
//	alice              first bob      second bob           carol
//	channel link      channel link   channel link       channel link
//
// This function takes server options which can be used to apply custom
// settings to alice, bob and carol.
func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
secondBobChannel, carolChannel *lnwallet.LightningChannel,
startingHeight uint32, opts ...serverOption) *threeHopNetwork {
aliceDb := aliceChannel.State().Db.GetParentDB()
bobDb := firstBobChannel.State().Db.GetParentDB()
carolDb := carolChannel.State().Db.GetParentDB()
hopNetwork := newHopNetwork()
// Create three peers/servers.
aliceServer, err := newMockServer(
t, "alice", startingHeight, aliceDb, hopNetwork.defaultDelta,
)
require.NoError(t, err, "unable to create alice server")
bobServer, err := newMockServer(
t, "bob", startingHeight, bobDb, hopNetwork.defaultDelta,
)
require.NoError(t, err, "unable to create bob server")
carolServer, err := newMockServer(
t, "carol", startingHeight, carolDb, hopNetwork.defaultDelta,
)
require.NoError(t, err, "unable to create carol server")
// Apply all additional functional options to the servers before
// creating any links.
for _, option := range opts {
option(aliceServer, bobServer, carolServer)
}
// Create mock decoders instead of the sphinx ones in order to mock the
// route which the htlc should follow.
aliceDecoder := newMockIteratorDecoder()
bobDecoder := newMockIteratorDecoder()
carolDecoder := newMockIteratorDecoder()
aliceChannelLink, err := hopNetwork.createChannelLink(aliceServer,
bobServer, aliceChannel, aliceDecoder,
)
if err != nil {
t.Fatal(err)
}
firstBobChannelLink, err := hopNetwork.createChannelLink(bobServer,
aliceServer, firstBobChannel, bobDecoder)
if err != nil {
t.Fatal(err)
}
secondBobChannelLink, err := hopNetwork.createChannelLink(bobServer,
carolServer, secondBobChannel, bobDecoder)
if err != nil {
t.Fatal(err)
}
carolChannelLink, err := hopNetwork.createChannelLink(carolServer,
bobServer, carolChannel, carolDecoder)
if err != nil {
t.Fatal(err)
}
return &threeHopNetwork{
aliceServer: aliceServer,
aliceChannelLink: aliceChannelLink.(*channelLink),
aliceOnionDecoder: aliceDecoder,
bobServer: bobServer,
firstBobChannelLink: firstBobChannelLink.(*channelLink),
secondBobChannelLink: secondBobChannelLink.(*channelLink),
bobOnionDecoder: bobDecoder,
carolServer: carolServer,
carolChannelLink: carolChannelLink.(*channelLink),
carolOnionDecoder: carolDecoder,
hopNetwork: *hopNetwork,
}
}
// serverOption is a function which alters the three servers created for
// a three hop network to allow custom settings on each server.
type serverOption func(aliceServer, bobServer, carolServer *mockServer)
// serverOptionWithHtlcNotifier is a functional option for the creation of
// three hop network servers which allows setting of htlc notifiers.
// Note that these notifiers should be started and stopped by the calling
// function.
func serverOptionWithHtlcNotifier(alice, bob,
carol *HtlcNotifier) serverOption {
return func(aliceServer, bobServer, carolServer *mockServer) {
aliceServer.htlcSwitch.cfg.HtlcNotifier = alice
bobServer.htlcSwitch.cfg.HtlcNotifier = bob
carolServer.htlcSwitch.cfg.HtlcNotifier = carol
}
}
// serverOptionRejectHtlc is a functional option for setting the reject
// htlc config option in each server's switch.
func serverOptionRejectHtlc(alice, bob, carol bool) serverOption {
return func(aliceServer, bobServer, carolServer *mockServer) {
aliceServer.htlcSwitch.cfg.RejectHTLC = alice
bobServer.htlcSwitch.cfg.RejectHTLC = bob
carolServer.htlcSwitch.cfg.RejectHTLC = carol
}
}
// createMirroredChannel creates two LightningChannel objects which represent
// the state machines on either side of a single channel between alice and bob.
func createMirroredChannel(t *testing.T, aliceToBob,
bobToAlice btcutil.Amount) (*testLightningChannel,
*testLightningChannel, error) {
_, _, firstChanID, _ := genIDs()
// Create the Alice<->Bob channel state machines for both Alice and Bob.
alice, bob, err := createTestChannel(t, alicePrivKey, bobPrivKey,
aliceToBob, bobToAlice, 0, 0, firstChanID,
)
if err != nil {
return nil, nil, errors.Errorf("unable to create "+
"alice<->bob channel: %v", err)
}
return alice, bob, nil
}
// hopNetwork is the base struct for two and three hop networks.
type hopNetwork struct {
feeEstimator *mockFeeEstimator
globalPolicy models.ForwardingPolicy
obfuscator hop.ErrorEncrypter
defaultDelta uint32
}
// newHopNetwork returns a hopNetwork instance with a default forwarding
// policy, fee estimator and obfuscator.
func newHopNetwork() *hopNetwork {
defaultDelta := uint32(6)
globalPolicy := models.ForwardingPolicy{
MinHTLCOut: lnwire.NewMSatFromSatoshis(5),
BaseFee: lnwire.NewMSatFromSatoshis(1),
TimeLockDelta: defaultDelta,
}
obfuscator := NewMockObfuscator()
return &hopNetwork{
feeEstimator: newMockFeeEstimator(),
globalPolicy: globalPolicy,
obfuscator: obfuscator,
defaultDelta: defaultDelta,
}
}
// createChannelLink creates a channel link between the given server and
// peer, registers it with the server's switch, and starts a helper goroutine
// that drains contract update notifications.
func (h *hopNetwork) createChannelLink(server, peer *mockServer,
channel *lnwallet.LightningChannel,
decoder *mockIteratorDecoder) (ChannelLink, error) {
const (
fwdPkgTimeout = 15 * time.Second
minFeeUpdateTimeout = 30 * time.Minute
maxFeeUpdateTimeout = 40 * time.Minute
)
notifyUpdateChan := make(chan *contractcourt.ContractUpdate)
doneChan := make(chan struct{})
notifyContractUpdate := func(u *contractcourt.ContractUpdate) error {
select {
case notifyUpdateChan <- u:
case <-doneChan:
}
return nil
}
getAliases := func(
base lnwire.ShortChannelID) []lnwire.ShortChannelID {
return nil
}
forwardPackets := func(linkQuit <-chan struct{}, _ bool,
packets ...*htlcPacket) error {
return server.htlcSwitch.ForwardPackets(linkQuit, packets...)
}
//nolint:ll
link := NewChannelLink(
ChannelLinkConfig{
BestHeight: server.htlcSwitch.BestHeight,
FwrdingPolicy: h.globalPolicy,
Peer: peer,
Circuits: server.htlcSwitch.CircuitModifier(),
ForwardPackets: forwardPackets,
DecodeHopIterators: decoder.DecodeHopIterators,
ExtractErrorEncrypter: func(*btcec.PublicKey) (
hop.ErrorEncrypter, lnwire.FailCode) {
return h.obfuscator, lnwire.CodeNone
},
FetchLastChannelUpdate: mockGetChanUpdateMessage,
Registry: server.registry,
FeeEstimator: h.feeEstimator,
PreimageCache: server.pCache,
UpdateContractSignals: func(*contractcourt.ContractSignals) error {
return nil
},
NotifyContractUpdate: notifyContractUpdate,
ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true,
BatchSize: 10,
BatchTicker: ticker.NewForce(testBatchTimeout),
FwdPkgGCTicker: ticker.NewForce(fwdPkgTimeout),
PendingCommitTicker: ticker.New(2 * time.Minute),
MinUpdateTimeout: minFeeUpdateTimeout,
MaxUpdateTimeout: maxFeeUpdateTimeout,
OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {},
OutgoingCltvRejectDelta: 3,
MaxOutgoingCltvExpiry: DefaultMaxOutgoingCltvExpiry,
MaxFeeAllocation: DefaultMaxLinkFeeAllocation,
MaxAnchorsCommitFeeRate: chainfee.SatPerKVByte(10 * 1000).FeePerKWeight(),
NotifyActiveLink: func(wire.OutPoint) {},
NotifyActiveChannel: func(wire.OutPoint) {},
NotifyInactiveChannel: func(wire.OutPoint) {},
NotifyInactiveLinkEvent: func(wire.OutPoint) {},
HtlcNotifier: server.htlcSwitch.cfg.HtlcNotifier,
GetAliases: getAliases,
ShouldFwdExpEndorsement: func() bool { return true },
},
channel,
)
if err := server.htlcSwitch.AddLink(link); err != nil {
return nil, fmt.Errorf("unable to add channel link: %w", err)
}
go func() {
if chanLink, ok := link.(*channelLink); ok {
for {
select {
case <-notifyUpdateChan:
case <-chanLink.cg.Done():
close(doneChan)
return
}
}
}
}()
return link, nil
}
// twoHopNetwork is used for managing the created cluster of 2 hops.
type twoHopNetwork struct {
hopNetwork
aliceServer *mockServer
aliceChannelLink *channelLink
bobServer *mockServer
bobChannelLink *channelLink
}
// newTwoHopNetwork function creates and starts the following topology and
// returns the control object to manage this cluster:
//
//	alice                     bob
//	server - <-connection-> - server
//	  |                         |
//	alice htlc              bob htlc
//	switch                   switch
//	  |                         |
//	alice                     bob
//	channel link          channel link
func newTwoHopNetwork(t testing.TB,
aliceChannel, bobChannel *lnwallet.LightningChannel,
startingHeight uint32) *twoHopNetwork {
aliceDb := aliceChannel.State().Db.GetParentDB()
bobDb := bobChannel.State().Db.GetParentDB()
hopNetwork := newHopNetwork()
// Create two peers/servers.
aliceServer, err := newMockServer(
t, "alice", startingHeight, aliceDb, hopNetwork.defaultDelta,
)
require.NoError(t, err, "unable to create alice server")
bobServer, err := newMockServer(
t, "bob", startingHeight, bobDb, hopNetwork.defaultDelta,
)
require.NoError(t, err, "unable to create bob server")
// Create mock decoders instead of the sphinx ones in order to mock the
// route which the htlc should follow.
aliceDecoder := newMockIteratorDecoder()
bobDecoder := newMockIteratorDecoder()
aliceChannelLink, err := hopNetwork.createChannelLink(
aliceServer, bobServer, aliceChannel, aliceDecoder,
)
if err != nil {
t.Fatal(err)
}
bobChannelLink, err := hopNetwork.createChannelLink(
bobServer, aliceServer, bobChannel, bobDecoder,
)
if err != nil {
t.Fatal(err)
}
n := &twoHopNetwork{
aliceServer: aliceServer,
aliceChannelLink: aliceChannelLink.(*channelLink),
bobServer: bobServer,
bobChannelLink: bobChannelLink.(*channelLink),
hopNetwork: *hopNetwork,
}
require.NoError(t, n.start())
t.Cleanup(n.stop)
return n
}
// start starts the alice and bob servers of the two hop network.
func (n *twoHopNetwork) start() error {
if err := n.aliceServer.Start(); err != nil {
return err
}
if err := n.bobServer.Start(); err != nil {
n.aliceServer.Stop()
return err
}
return waitLinksEligible(map[string]*channelLink{
"alice": n.aliceChannelLink,
"bob": n.bobChannelLink,
})
}
// stop stops the nodes and cleans up their databases.
func (n *twoHopNetwork) stop() {
done := make(chan struct{})
go func() {
n.aliceServer.Stop()
done <- struct{}{}
}()
go func() {
n.bobServer.Stop()
done <- struct{}{}
}()
for i := 0; i < 2; i++ {
<-done
}
}
// makeHoldPayment sends a hold-invoice HTLC from the sending peer to the
// receiving peer and returns a channel over which the payment result is
// delivered.
func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
firstHop lnwire.ShortChannelID, hops []*hop.Payload,
invoiceAmt, htlcAmt lnwire.MilliSatoshi,
timelock uint32, preimage lntypes.Preimage) chan error {
paymentErr := make(chan error, 1)
sender := sendingPeer.(*mockServer)
receiver := receivingPeer.(*mockServer)
// Generate the route, convert it to a blob, and return the next
// destination for the htlc add request.
blob, err := generateRoute(hops...)
if err != nil {
paymentErr <- err
return paymentErr
}
rhash := preimage.Hash()
var payAddr [32]byte
if _, err := crand.Read(payAddr[:]); err != nil {
panic(err)
}
// Generate payment: invoice and htlc.
invoice, htlc, pid, err := generatePaymentWithPreimage(
invoiceAmt, htlcAmt, timelock, blob,
nil, rhash, payAddr,
)
if err != nil {
paymentErr <- err
return paymentErr
}
// Check who is last in the route and add the invoice to the
// receiver's registry.
if err := receiver.registry.AddInvoice(
context.Background(), *invoice, rhash,
); err != nil {
paymentErr <- err
return paymentErr
}
// Send payment and expose err channel.
err = sender.htlcSwitch.SendHTLC(firstHop, pid, htlc)
if err != nil {
paymentErr <- err
return paymentErr
}
go func() {
resultChan, err := sender.htlcSwitch.GetAttemptResult(
pid, rhash, newMockDeobfuscator(),
)
if err != nil {
paymentErr <- err
return
}
result, ok := <-resultChan
if !ok {
paymentErr <- fmt.Errorf("shutting down")
return
}
if result.Error != nil {
paymentErr <- result.Error
return
}
paymentErr <- nil
}()
return paymentErr
}
// waitLinksEligible blocks until all links in the provided name-to-link map
// are eligible to forward HTLCs.
func waitLinksEligible(links map[string]*channelLink) error {
return wait.NoError(func() error {
for name, link := range links {
if link.EligibleToForward() {
continue
}
return fmt.Errorf("%s channel link not eligible", name)
}
return nil
}, 3*time.Second)
}
// timeout implements a test level timeout.
func timeout() func() {
done := make(chan struct{})
go func() {
select {
case <-time.After(20 * time.Second):
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
panic("test timeout")
case <-done:
}
}()
return func() {
close(done)
}
}
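// A typical usage sketch: arm the watchdog at the start of a test and
// disarm it when the test returns.
//
// func TestSomething(t *testing.T) {
// defer timeout()()
// ...
// }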
package input
import (
"fmt"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/tlv"
)
// EmptyOutPoint is a zeroed outpoint.
var EmptyOutPoint wire.OutPoint
// Input represents an abstract UTXO which is to be spent using a sweeping
// transaction. The methods provided give the caller all the information
// needed to construct a valid input within a sweeping transaction to sweep
// this lingering UTXO.
type Input interface {
// OutPoint returns the reference to the output being spent, used to
// construct the corresponding transaction input.
OutPoint() wire.OutPoint
// RequiredTxOut returns a non-nil TxOut if input commits to a certain
// transaction output. This is used in the SINGLE|ANYONECANPAY case to
// make sure any presigned input is still valid by including the
// output.
RequiredTxOut() *wire.TxOut
// RequiredLockTime returns whether this input commits to a tx locktime
// that must be used in the transaction including it.
RequiredLockTime() (uint32, bool)
// WitnessType returns an enum specifying the type of witness that must
// be generated in order to spend this output.
WitnessType() WitnessType
// SignDesc returns a reference to a spendable output's sign
// descriptor, which is used during signing to compute a valid witness
// that spends this output.
SignDesc() *SignDescriptor
// CraftInputScript returns a valid set of input scripts allowing this
// output to be spent. The returned input scripts should target the
// input at location txIndex within the passed transaction. The input
// scripts generated by this method support spending p2wkh, p2wsh, and
// also nested p2sh outputs.
CraftInputScript(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher,
txinIdx int) (*Script, error)
// BlocksToMaturity returns the relative timelock, as a number of
// blocks, that must be built on top of the confirmation height before
// the output can be spent. For non-CSV locked inputs this is always
// zero.
BlocksToMaturity() uint32
// HeightHint returns the minimum height at which a confirmed spending
// tx can occur.
HeightHint() uint32
// UnconfParent returns information about a possibly unconfirmed parent
// tx.
UnconfParent() *TxInfo
// ResolutionBlob returns a special opaque blob to be used to
// sweep/resolve this input.
ResolutionBlob() fn.Option[tlv.Blob]
// Preimage returns the preimage for the input if it is an HTLC input.
Preimage() fn.Option[lntypes.Preimage]
}
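// The helper below is an illustrative sketch (not part of the package API)
// showing how a sweeping transaction might consume the Input interface:
// once the transaction is assembled, each input's witness is generated via
// CraftInputScript and attached at the input's index. The signer and
// fetcher are assumed to be supplied by the caller.
func sketchAttachWitnesses(signer Signer, tx *wire.MsgTx,
fetcher txscript.PrevOutputFetcher, inputs []Input) error {
hashCache := txscript.NewTxSigHashes(tx, fetcher)
for idx, inp := range inputs {
script, err := inp.CraftInputScript(
signer, tx, hashCache, fetcher, idx,
)
if err != nil {
return err
}
// Attach the generated witness to the corresponding input.
tx.TxIn[idx].Witness = script.Witness
}
return nil
}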
// TxInfo describes properties of a parent tx that are relevant for CPFP.
type TxInfo struct {
// Fee is the fee of the tx.
Fee btcutil.Amount
// Weight is the weight of the tx.
Weight lntypes.WeightUnit
}
// String returns a human readable version of the tx info.
func (t *TxInfo) String() string {
return fmt.Sprintf("fee=%v, weight=%v", t.Fee, t.Weight)
}
// SignDetails is a struct containing information needed to re-sign certain
// inputs. It is used to re-sign 2nd level HTLC transactions that use the
// SINGLE|ANYONECANPAY sighash type, as we have a signature provided by our
// peer, but we can aggregate multiple of these 2nd level transactions into a
// new transaction that needs to be signed by us.
type SignDetails struct {
// SignDesc is the sign descriptor needed for us to sign the input.
SignDesc SignDescriptor
// PeerSig is the peer's signature for this input.
PeerSig Signature
// SigHashType is the sighash signed by the peer.
SigHashType txscript.SigHashType
}
type inputKit struct {
outpoint wire.OutPoint
witnessType WitnessType
signDesc SignDescriptor
heightHint uint32
blockToMaturity uint32
cltvExpiry uint32
// unconfParent contains information about a potential unconfirmed
// parent transaction.
unconfParent *TxInfo
// resolutionBlob is an optional blob that can be used to resolve an
// input.
resolutionBlob fn.Option[tlv.Blob]
}
// OutPoint returns the breached output's identifier that is to be included as
// a transaction input.
func (i *inputKit) OutPoint() wire.OutPoint {
return i.outpoint
}
// RequiredTxOut returns nil for the base input type.
func (i *inputKit) RequiredTxOut() *wire.TxOut {
return nil
}
// RequiredLockTime returns whether this input commits to a tx locktime that
// must be used in the transaction including it. This will be false for the
// base input type since we can re-sign for any lock time.
func (i *inputKit) RequiredLockTime() (uint32, bool) {
return i.cltvExpiry, i.cltvExpiry > 0
}
// WitnessType returns the type of witness that must be generated to spend the
// breached output.
func (i *inputKit) WitnessType() WitnessType {
return i.witnessType
}
// SignDesc returns the breached output's SignDescriptor, which is used during
// signing to compute the witness.
func (i *inputKit) SignDesc() *SignDescriptor {
return &i.signDesc
}
// HeightHint returns the minimum height at which a confirmed spending
// tx can occur.
func (i *inputKit) HeightHint() uint32 {
return i.heightHint
}
// BlocksToMaturity returns the relative timelock, as a number of blocks, that
// must be built on top of the confirmation height before the output can be
// spent. For non-CSV locked inputs this is always zero.
func (i *inputKit) BlocksToMaturity() uint32 {
return i.blockToMaturity
}
// UnconfParent returns information about a possibly unconfirmed parent tx.
func (i *inputKit) UnconfParent() *TxInfo {
return i.unconfParent
}
// ResolutionBlob returns a special opaque blob to be used to sweep/resolve
// this input.
func (i *inputKit) ResolutionBlob() fn.Option[tlv.Blob] {
return i.resolutionBlob
}
// inputOpts contains options for constructing a new input.
type inputOpts struct {
// resolutionBlob is an optional blob that can be used to resolve an
// input.
resolutionBlob fn.Option[tlv.Blob]
}
// defaultInputOpts returns a new inputOpts with default values.
func defaultInputOpts() *inputOpts {
return &inputOpts{}
}
// InputOpt is a functional option that can be used to modify the default input
// options.
type InputOpt func(*inputOpts) //nolint:revive
// WithResolutionBlob is an option that can be used to set a resolution blob
// for an input.
func WithResolutionBlob(b fn.Option[tlv.Blob]) InputOpt {
return func(o *inputOpts) {
o.resolutionBlob = b
}
}
// BaseInput contains all the information needed to sweep a basic
// output (CSV/CLTV/no time lock).
type BaseInput struct {
inputKit
}
// MakeBaseInput assembles a new BaseInput that can be used to construct a
// sweep transaction.
func MakeBaseInput(outpoint *wire.OutPoint, witnessType WitnessType,
signDescriptor *SignDescriptor, heightHint uint32,
unconfParent *TxInfo, opts ...InputOpt) BaseInput {
opt := defaultInputOpts()
for _, optF := range opts {
optF(opt)
}
return BaseInput{
inputKit{
outpoint: *outpoint,
witnessType: witnessType,
signDesc: *signDescriptor,
heightHint: heightHint,
unconfParent: unconfParent,
resolutionBlob: opt.resolutionBlob,
},
}
}
// NewBaseInput allocates and assembles a new *BaseInput that can be used to
// construct a sweep transaction.
func NewBaseInput(outpoint *wire.OutPoint, witnessType WitnessType,
signDescriptor *SignDescriptor, heightHint uint32,
opts ...InputOpt) *BaseInput {
input := MakeBaseInput(
outpoint, witnessType, signDescriptor, heightHint, nil, opts...,
)
return &input
}
// NewCsvInput assembles a new csv-locked input that can be used to
// construct a sweep transaction.
func NewCsvInput(outpoint *wire.OutPoint, witnessType WitnessType,
signDescriptor *SignDescriptor, heightHint uint32,
blockToMaturity uint32, opts ...InputOpt) *BaseInput {
input := MakeBaseInput(
outpoint, witnessType, signDescriptor, heightHint, nil, opts...,
)
input.blockToMaturity = blockToMaturity
return &input
}
// NewCsvInputWithCltv assembles a new csv and cltv locked input that can be
// used to construct a sweep transaction.
func NewCsvInputWithCltv(outpoint *wire.OutPoint, witnessType WitnessType,
signDescriptor *SignDescriptor, heightHint uint32,
csvDelay uint32, cltvExpiry uint32, opts ...InputOpt) *BaseInput {
input := MakeBaseInput(
outpoint, witnessType, signDescriptor, heightHint, nil, opts...,
)
input.blockToMaturity = csvDelay
input.cltvExpiry = cltvExpiry
return &input
}
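// A hedged usage sketch (the outpoint, sign descriptor and lock values are
// arbitrary placeholders): constructing an input that is both CSV and CLTV
// locked, and querying its required lock time.
//
// inp := NewCsvInputWithCltv(
// &op, CommitmentTimeLock, &signDesc, heightHint, csvDelay,
// cltvExpiry,
// )
// lockTime, ok := inp.RequiredLockTime() // ok is true iff cltvExpiry > 0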
// CraftInputScript returns a valid set of input scripts allowing this output
// to be spent. The returned input scripts should target the input at location
// txIndex within the passed transaction. The input scripts generated by this
// method support spending p2wkh, p2wsh, and also nested p2sh outputs.
func (bi *BaseInput) CraftInputScript(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher, txinIdx int) (*Script,
error) {
signDesc := bi.SignDesc()
signDesc.PrevOutputFetcher = prevOutputFetcher
witnessFunc := bi.witnessType.WitnessGenerator(signer, signDesc)
return witnessFunc(txn, hashCache, txinIdx)
}
// Preimage returns the preimage for the input if it is an HTLC input.
func (bi *BaseInput) Preimage() fn.Option[lntypes.Preimage] {
return fn.None[lntypes.Preimage]()
}
// HtlcSucceedInput constitutes a sweep input that needs a pre-image. The input
// is expected to reside on the commitment tx of the remote party and should
// not be a second level tx output.
type HtlcSucceedInput struct {
inputKit
preimage []byte
}
// MakeHtlcSucceedInput assembles a new redeem input that can be used to
// construct a sweep transaction.
func MakeHtlcSucceedInput(outpoint *wire.OutPoint,
signDescriptor *SignDescriptor, preimage []byte, heightHint,
blocksToMaturity uint32, opts ...InputOpt) HtlcSucceedInput {
input := MakeBaseInput(
outpoint, HtlcAcceptedRemoteSuccess, signDescriptor,
heightHint, nil, opts...,
)
input.blockToMaturity = blocksToMaturity
return HtlcSucceedInput{
inputKit: input.inputKit,
preimage: preimage,
}
}
// MakeTaprootHtlcSucceedInput creates a new HtlcSucceedInput that can be used
// to spend an HTLC output for a taproot channel on the remote party's
// commitment transaction.
func MakeTaprootHtlcSucceedInput(op *wire.OutPoint, signDesc *SignDescriptor,
preimage []byte, heightHint, blocksToMaturity uint32,
opts ...InputOpt) HtlcSucceedInput {
input := MakeBaseInput(
op, TaprootHtlcAcceptedRemoteSuccess, signDesc,
heightHint, nil, opts...,
)
input.blockToMaturity = blocksToMaturity
return HtlcSucceedInput{
inputKit: input.inputKit,
preimage: preimage,
}
}
// CraftInputScript returns a valid set of input scripts allowing this output
// to be spent. The returned input scripts should target the input at location
// txIndex within the passed transaction. The input scripts generated by this
// method support spending p2wkh, p2wsh, and also nested p2sh outputs.
func (h *HtlcSucceedInput) CraftInputScript(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher, txinIdx int) (*Script,
error) {
desc := h.signDesc
desc.SigHashes = hashCache
desc.InputIndex = txinIdx
desc.PrevOutputFetcher = prevOutputFetcher
isTaproot := txscript.IsPayToTaproot(desc.Output.PkScript)
var (
witness wire.TxWitness
err error
)
if isTaproot {
if desc.ControlBlock == nil {
return nil, fmt.Errorf("ctrl block must be set")
}
desc.SignMethod = TaprootScriptSpendSignMethod
witness, err = SenderHTLCScriptTaprootRedeem(
signer, &desc, txn, h.preimage, nil, nil,
)
} else {
witness, err = SenderHtlcSpendRedeem(
signer, &desc, txn, h.preimage,
)
}
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
}
// Preimage returns the preimage for the input if it is an HTLC input.
func (h *HtlcSucceedInput) Preimage() fn.Option[lntypes.Preimage] {
if len(h.preimage) == 0 {
return fn.None[lntypes.Preimage]()
}
return fn.Some(lntypes.Preimage(h.preimage))
}
// HtlcSecondLevelAnchorInput is an input type used to spend HTLC outputs
// using a re-signed second level transaction, either via the timeout or success
// paths.
type HtlcSecondLevelAnchorInput struct {
inputKit
// SignedTx is the original second level transaction signed by the
// channel peer.
SignedTx *wire.MsgTx
// createWitness creates a witness allowing the passed transaction to
// spend the input.
createWitness func(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher,
txinIdx int) (wire.TxWitness, error)
preimage []byte
}
// RequiredTxOut returns the tx out needed to be present on the sweep tx for
// the spend of the input to be valid.
func (i *HtlcSecondLevelAnchorInput) RequiredTxOut() *wire.TxOut {
return i.SignedTx.TxOut[0]
}
// RequiredLockTime returns the locktime needed for the sweep tx for the spend
// of the input to be valid. For a second level HTLC timeout this will be the
// CLTV expiry, for HTLC success it will be zero.
func (i *HtlcSecondLevelAnchorInput) RequiredLockTime() (uint32, bool) {
return i.SignedTx.LockTime, true
}
// CraftInputScript returns a valid set of input scripts allowing this output
// to be spent. The returned input scripts should target the input at location
// txIndex within the passed transaction. The input scripts generated by this
// method support spending p2wkh, p2wsh, and also nested p2sh outputs.
func (i *HtlcSecondLevelAnchorInput) CraftInputScript(signer Signer,
txn *wire.MsgTx, hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher, txinIdx int) (*Script,
error) {
witness, err := i.createWitness(
signer, txn, hashCache, prevOutputFetcher, txinIdx,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
}
// Preimage returns the preimage for the input if it is an HTLC input.
func (i *HtlcSecondLevelAnchorInput) Preimage() fn.Option[lntypes.Preimage] {
if len(i.preimage) == 0 {
return fn.None[lntypes.Preimage]()
}
return fn.Some(lntypes.Preimage(i.preimage))
}
// MakeHtlcSecondLevelTimeoutAnchorInput creates an input allowing the sweeper
// to spend the HTLC output on our commit using the second level timeout
// transaction.
func MakeHtlcSecondLevelTimeoutAnchorInput(signedTx *wire.MsgTx,
signDetails *SignDetails, heightHint uint32,
opts ...InputOpt) HtlcSecondLevelAnchorInput {
// Spend an HTLC output on our local commitment tx using the
// 2nd timeout transaction.
createWitness := func(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher,
txinIdx int) (wire.TxWitness, error) {
desc := signDetails.SignDesc
desc.SigHashes = txscript.NewTxSigHashes(txn, prevOutputFetcher)
desc.InputIndex = txinIdx
desc.PrevOutputFetcher = prevOutputFetcher
return SenderHtlcSpendTimeout(
signDetails.PeerSig, signDetails.SigHashType, signer,
&desc, txn,
)
}
input := MakeBaseInput(
&signedTx.TxIn[0].PreviousOutPoint,
HtlcOfferedTimeoutSecondLevelInputConfirmed,
&signDetails.SignDesc, heightHint, nil, opts...,
)
input.blockToMaturity = 1
return HtlcSecondLevelAnchorInput{
inputKit: input.inputKit,
SignedTx: signedTx,
createWitness: createWitness,
}
}
// MakeHtlcSecondLevelTimeoutTaprootInput creates an input that allows the
// sweeper to spend an HTLC output to the second level on our commitment
// transaction. The sweeper is also able to generate witnesses on demand to
// sweep the second level HTLC aggregated with other transactions.
func MakeHtlcSecondLevelTimeoutTaprootInput(signedTx *wire.MsgTx,
signDetails *SignDetails,
heightHint uint32, opts ...InputOpt) HtlcSecondLevelAnchorInput {
createWitness := func(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher,
txinIdx int) (wire.TxWitness, error) {
desc := signDetails.SignDesc
if desc.ControlBlock == nil {
return nil, fmt.Errorf("ctrl block must be set")
}
desc.SigHashes = txscript.NewTxSigHashes(txn, prevOutputFetcher)
desc.InputIndex = txinIdx
desc.PrevOutputFetcher = prevOutputFetcher
desc.SignMethod = TaprootScriptSpendSignMethod
return SenderHTLCScriptTaprootTimeout(
signDetails.PeerSig, signDetails.SigHashType, signer,
&desc, txn, nil, nil,
)
}
input := MakeBaseInput(
&signedTx.TxIn[0].PreviousOutPoint,
TaprootHtlcLocalOfferedTimeout,
&signDetails.SignDesc, heightHint, nil, opts...,
)
input.blockToMaturity = 1
return HtlcSecondLevelAnchorInput{
inputKit: input.inputKit,
SignedTx: signedTx,
createWitness: createWitness,
}
}
// MakeHtlcSecondLevelSuccessAnchorInput creates an input allowing the sweeper
// to spend the HTLC output on our commit using the second level success
// transaction.
func MakeHtlcSecondLevelSuccessAnchorInput(signedTx *wire.MsgTx,
signDetails *SignDetails, preimage lntypes.Preimage,
heightHint uint32, opts ...InputOpt) HtlcSecondLevelAnchorInput {
// Spend an HTLC output on our local commitment tx using the 2nd
// success transaction.
createWitness := func(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher,
txinIdx int) (wire.TxWitness, error) {
desc := signDetails.SignDesc
desc.SigHashes = hashCache
desc.InputIndex = txinIdx
desc.PrevOutputFetcher = prevOutputFetcher
return ReceiverHtlcSpendRedeem(
signDetails.PeerSig, signDetails.SigHashType,
preimage[:], signer, &desc, txn,
)
}
input := MakeBaseInput(
&signedTx.TxIn[0].PreviousOutPoint,
HtlcAcceptedSuccessSecondLevelInputConfirmed,
&signDetails.SignDesc, heightHint, nil, opts...,
)
input.blockToMaturity = 1
return HtlcSecondLevelAnchorInput{
SignedTx: signedTx,
inputKit: input.inputKit,
createWitness: createWitness,
preimage: preimage[:],
}
}
// MakeHtlcSecondLevelSuccessTaprootInput creates an input that allows the
// sweeper to spend an HTLC output to the second level on our taproot
// commitment transaction.
func MakeHtlcSecondLevelSuccessTaprootInput(signedTx *wire.MsgTx,
signDetails *SignDetails, preimage lntypes.Preimage,
heightHint uint32, opts ...InputOpt) HtlcSecondLevelAnchorInput {
createWitness := func(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher,
txinIdx int) (wire.TxWitness, error) {
desc := signDetails.SignDesc
if desc.ControlBlock == nil {
return nil, fmt.Errorf("ctrl block must be set")
}
desc.SigHashes = txscript.NewTxSigHashes(txn, prevOutputFetcher)
desc.InputIndex = txinIdx
desc.PrevOutputFetcher = prevOutputFetcher
desc.SignMethod = TaprootScriptSpendSignMethod
return ReceiverHTLCScriptTaprootRedeem(
signDetails.PeerSig, signDetails.SigHashType,
preimage[:], signer, &desc, txn, nil, nil,
)
}
input := MakeBaseInput(
&signedTx.TxIn[0].PreviousOutPoint,
TaprootHtlcAcceptedLocalSuccess,
&signDetails.SignDesc, heightHint, nil, opts...,
)
input.blockToMaturity = 1
return HtlcSecondLevelAnchorInput{
inputKit: input.inputKit,
SignedTx: signedTx,
createWitness: createWitness,
preimage: preimage[:],
}
}
// Compile-time constraints to ensure each input struct implement the Input
// interface.
var _ Input = (*BaseInput)(nil)
var _ Input = (*HtlcSucceedInput)(nil)
var _ Input = (*HtlcSecondLevelAnchorInput)(nil)
package input
import (
"crypto/sha256"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/mock"
)
// MockInput implements the `Input` interface and is used by other packages for
// mock testing.
type MockInput struct {
mock.Mock
}
// Compile time assertion that MockInput implements Input.
var _ Input = (*MockInput)(nil)
// OutPoint returns the reference to the output being spent, used to construct
// the corresponding transaction input.
func (m *MockInput) OutPoint() wire.OutPoint {
args := m.Called()
op := args.Get(0)
return op.(wire.OutPoint)
}
// RequiredTxOut returns a non-nil TxOut if input commits to a certain
// transaction output. This is used in the SINGLE|ANYONECANPAY case to make
// sure any presigned input is still valid by including the output.
func (m *MockInput) RequiredTxOut() *wire.TxOut {
args := m.Called()
txOut := args.Get(0)
if txOut == nil {
return nil
}
return txOut.(*wire.TxOut)
}
// RequiredLockTime returns whether this input commits to a tx locktime that
// must be used in the transaction including it.
func (m *MockInput) RequiredLockTime() (uint32, bool) {
args := m.Called()
return args.Get(0).(uint32), args.Bool(1)
}
// WitnessType returns an enum specifying the type of witness that must be
// generated in order to spend this output.
func (m *MockInput) WitnessType() WitnessType {
args := m.Called()
wt := args.Get(0)
if wt == nil {
return nil
}
return wt.(WitnessType)
}
// SignDesc returns a reference to a spendable output's sign descriptor, which
// is used during signing to compute a valid witness that spends this output.
func (m *MockInput) SignDesc() *SignDescriptor {
args := m.Called()
sd := args.Get(0)
if sd == nil {
return nil
}
return sd.(*SignDescriptor)
}
// CraftInputScript returns a valid set of input scripts allowing this output
// to be spent. The returned input scripts should target the input at location
// txIndex within the passed transaction. The input scripts generated by this
// method support spending p2wkh, p2wsh, and also nested p2sh outputs.
func (m *MockInput) CraftInputScript(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher,
txinIdx int) (*Script, error) {
args := m.Called(signer, txn, hashCache, prevOutputFetcher, txinIdx)
s := args.Get(0)
if s == nil {
return nil, args.Error(1)
}
return s.(*Script), args.Error(1)
}
// BlocksToMaturity returns the relative timelock, as a number of blocks, that
// must be built on top of the confirmation height before the output can be
// spent. For non-CSV locked inputs this is always zero.
func (m *MockInput) BlocksToMaturity() uint32 {
args := m.Called()
return args.Get(0).(uint32)
}
// HeightHint returns the minimum height at which a confirmed spending tx can
// occur.
func (m *MockInput) HeightHint() uint32 {
args := m.Called()
return args.Get(0).(uint32)
}
// UnconfParent returns information about a possibly unconfirmed parent tx.
func (m *MockInput) UnconfParent() *TxInfo {
args := m.Called()
info := args.Get(0)
if info == nil {
return nil
}
return info.(*TxInfo)
}
// ResolutionBlob returns the optional resolution blob associated with this
// input.
func (m *MockInput) ResolutionBlob() fn.Option[tlv.Blob] {
args := m.Called()
info := args.Get(0)
if info == nil {
return fn.None[tlv.Blob]()
}
return info.(fn.Option[tlv.Blob])
}
// Preimage returns the optional preimage associated with this input.
func (m *MockInput) Preimage() fn.Option[lntypes.Preimage] {
args := m.Called()
info := args.Get(0)
if info == nil {
return fn.None[lntypes.Preimage]()
}
return info.(fn.Option[lntypes.Preimage])
}
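// The following is a brief usage sketch, not part of the package API: it
// shows how a test in another package might program MockInput with testify
// expectations before handing it to code that consumes the Input interface.
// The function name is illustrative only.
func sketchMockInputUsage() {
	mockInput := &MockInput{}

	// Program canned return values for the methods the code under test
	// is expected to call.
	mockInput.On("OutPoint").Return(wire.OutPoint{Index: 1})
	mockInput.On("RequiredTxOut").Return(nil)
	mockInput.On("BlocksToMaturity").Return(uint32(0))

	// Calls now resolve to the canned values configured above.
	_ = mockInput.OutPoint()
	_ = mockInput.RequiredTxOut()
	_ = mockInput.BlocksToMaturity()
}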
// MockWitnessType implements the `WitnessType` interface and is used by other
// packages for mock testing.
type MockWitnessType struct {
mock.Mock
}
// Compile time assertion that MockWitnessType implements WitnessType.
var _ WitnessType = (*MockWitnessType)(nil)
// String returns a human readable version of the WitnessType.
func (m *MockWitnessType) String() string {
args := m.Called()
return args.String(0)
}
// WitnessGenerator will return a WitnessGenerator function that an output uses
// to generate the witness and optionally the sigScript for a sweep
// transaction.
func (m *MockWitnessType) WitnessGenerator(signer Signer,
descriptor *SignDescriptor) WitnessGenerator {
args := m.Called()
return args.Get(0).(WitnessGenerator)
}
// SizeUpperBound returns the maximum length of the witness of this WitnessType
// if it would be included in a tx. It also returns if the output itself is a
// nested p2sh output, if so then we need to take into account the extra
// sigScript data size.
func (m *MockWitnessType) SizeUpperBound() (lntypes.WeightUnit, bool, error) {
args := m.Called()
return args.Get(0).(lntypes.WeightUnit), args.Bool(1), args.Error(2)
}
// AddWeightEstimation adds the estimated size of the witness in bytes to the
// given weight estimator.
func (m *MockWitnessType) AddWeightEstimation(e *TxWeightEstimator) error {
args := m.Called()
return args.Error(0)
}
// MockInputSigner is a mock implementation of the Signer interface.
type MockInputSigner struct {
mock.Mock
}
// Compile-time constraint to ensure MockInputSigner implements Signer.
var _ Signer = (*MockInputSigner)(nil)
// SignOutputRaw generates a signature for the passed transaction according to
// the data within the passed SignDescriptor.
func (m *MockInputSigner) SignOutputRaw(tx *wire.MsgTx,
signDesc *SignDescriptor) (Signature, error) {
args := m.Called(tx, signDesc)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(Signature), args.Error(1)
}
// ComputeInputScript generates a complete InputScript for the passed
// transaction with the signature as defined within the passed SignDescriptor.
func (m *MockInputSigner) ComputeInputScript(tx *wire.MsgTx,
signDesc *SignDescriptor) (*Script, error) {
args := m.Called(tx, signDesc)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*Script), args.Error(1)
}
// MuSig2CreateSession creates a new MuSig2 signing session using the local key
// identified by the key locator.
func (m *MockInputSigner) MuSig2CreateSession(version MuSig2Version,
locator keychain.KeyLocator, pubkey []*btcec.PublicKey,
tweak *MuSig2Tweaks, pubNonces [][musig2.PubNonceSize]byte,
nonces *musig2.Nonces) (*MuSig2SessionInfo, error) {
args := m.Called(version, locator, pubkey, tweak, pubNonces, nonces)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*MuSig2SessionInfo), args.Error(1)
}
// MuSig2RegisterNonces registers one or more public nonces of other signing
// participants for a session identified by its ID.
func (m *MockInputSigner) MuSig2RegisterNonces(sessionID MuSig2SessionID,
pubNonces [][musig2.PubNonceSize]byte) (bool, error) {
args := m.Called(sessionID, pubNonces)
if args.Get(0) == nil {
return false, args.Error(1)
}
return args.Bool(0), args.Error(1)
}
// MuSig2Sign creates a partial signature using the local signing key that was
// specified when the session was created.
func (m *MockInputSigner) MuSig2Sign(sessionID MuSig2SessionID,
msg [sha256.Size]byte, withSortedKeys bool) (
*musig2.PartialSignature, error) {
args := m.Called(sessionID, msg, withSortedKeys)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*musig2.PartialSignature), args.Error(1)
}
// MuSig2CombineSig combines the given partial signature(s) with the local one,
// if it already exists.
func (m *MockInputSigner) MuSig2CombineSig(sessionID MuSig2SessionID,
partialSig []*musig2.PartialSignature) (
*schnorr.Signature, bool, error) {
args := m.Called(sessionID, partialSig)
if args.Get(0) == nil {
return nil, false, args.Error(2)
}
return args.Get(0).(*schnorr.Signature), args.Bool(1), args.Error(2)
}
// MuSig2Cleanup removes a session from memory to free up resources.
func (m *MockInputSigner) MuSig2Cleanup(sessionID MuSig2SessionID) error {
args := m.Called(sessionID)
return args.Error(0)
}
package input
import (
"bytes"
"crypto/sha256"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/internal/musig2v040"
"github.com/lightningnetwork/lnd/keychain"
)
// MuSig2Version is a type that defines the different versions of MuSig2
// as defined in the BIP draft:
// (https://github.com/jonasnick/bips/blob/musig2/bip-musig2.mediawiki)
type MuSig2Version uint8
const (
// MuSig2Version040 is version 0.4.0 of the MuSig2 BIP draft. This will
// use the lnd internal/musig2v040 package.
MuSig2Version040 MuSig2Version = 0
// MuSig2Version100RC2 is version 1.0.0rc2 of the MuSig2 BIP draft. This
// uses the github.com/btcsuite/btcd/btcec/v2/schnorr/musig2 package
// at git tag `btcec/v2.3.1`.
MuSig2Version100RC2 MuSig2Version = 1
)
const (
// MuSig2PartialSigSize is the size of a MuSig2 partial signature.
// Because a partial signature is just the s value, this corresponds to
// the length of a scalar.
MuSig2PartialSigSize = 32
)
// MuSig2SessionID is a type for a session ID that is just a hash of the MuSig2
// combined key and the local public nonces.
type MuSig2SessionID [sha256.Size]byte
// MuSig2Signer is an interface that declares all methods that a MuSig2
// compatible signer needs to implement.
type MuSig2Signer interface {
// MuSig2CreateSession creates a new MuSig2 signing session using the
// local key identified by the key locator. The complete list of all
// public keys of all signing parties must be provided, including the
// public key of the local signing key. If nonces of other parties are
// already known, they can be submitted as well to reduce the number of
// method calls necessary later on.
//
// The localNonces field is optional. If it is set, then the specified
// nonces will be used instead of generating from scratch. This is
// useful in instances where the nonces are generated ahead of time
// before the set of signers is known.
MuSig2CreateSession(MuSig2Version, keychain.KeyLocator,
[]*btcec.PublicKey, *MuSig2Tweaks, [][musig2.PubNonceSize]byte,
*musig2.Nonces) (*MuSig2SessionInfo, error)
// MuSig2RegisterNonces registers one or more public nonces of other
// signing participants for a session identified by its ID. This method
// returns true once we have all nonces for all other signing
// participants.
MuSig2RegisterNonces(MuSig2SessionID,
[][musig2.PubNonceSize]byte) (bool, error)
// MuSig2Sign creates a partial signature using the local signing key
// that was specified when the session was created. This can only be
// called when all public nonces of all participants are known and have
// been registered with the session. If this node isn't responsible for
// combining all the partial signatures, then the cleanup parameter
// should be set, indicating that the session can be removed from memory
// once the signature was produced.
MuSig2Sign(MuSig2SessionID, [sha256.Size]byte,
bool) (*musig2.PartialSignature, error)
// MuSig2CombineSig combines the given partial signature(s) with the
// local one, if it already exists. Once a partial signature of all
// participants is registered, the final signature will be combined and
// returned.
MuSig2CombineSig(MuSig2SessionID,
[]*musig2.PartialSignature) (*schnorr.Signature, bool, error)
// MuSig2Cleanup removes a session from memory to free up resources.
MuSig2Cleanup(MuSig2SessionID) error
}
// MuSig2Context is an interface that is an abstraction over the MuSig2 signing
// context. This interface does not contain all of the methods the underlying
// implementations have because those use package specific types which cannot
// easily be made compatible. Those calls (such as NewSession) are implemented
// in this package instead and do the necessary type switch (see
// MuSig2CreateContext).
type MuSig2Context interface {
// SigningKeys returns the set of keys used for signing.
SigningKeys() []*btcec.PublicKey
// CombinedKey returns the combined public key that will be used to
// generate multi-signatures against.
CombinedKey() (*btcec.PublicKey, error)
// TaprootInternalKey returns the internal taproot key, which is the
// aggregated key _before_ the tweak is applied. If a taproot tweak was
// specified, then CombinedKey() will return the fully tweaked output
// key, with this method returning the internal key. If a taproot tweak
// wasn't specified, then this method will return an error.
TaprootInternalKey() (*btcec.PublicKey, error)
}
// MuSig2Session is an interface that is an abstraction over the MuSig2 signing
// session. This interface does not contain all of the methods the underlying
// implementations have because those use package specific types which cannot
// easily be made compatible. Those calls (such as CombineSig or Sign) are
// implemented in this package instead and do the necessary type switch (see
// MuSig2CombineSig or MuSig2Sign).
type MuSig2Session interface {
// FinalSig returns the final combined multi-signature, if present.
FinalSig() *schnorr.Signature
// PublicNonce returns the public nonce for a signer. This should be
// sent to other parties before signing begins, so they can compute the
// aggregated public nonce.
PublicNonce() [musig2.PubNonceSize]byte
// NumRegisteredNonces returns the total number of nonces that have been
// registered so far.
NumRegisteredNonces() int
// RegisterPubNonce should be called for each public nonce from the set
// of signers. This method returns true once all the public nonces have
// been accounted for.
RegisterPubNonce(nonce [musig2.PubNonceSize]byte) (bool, error)
}
// MuSig2SessionInfo is a struct for keeping track of a signing session's
// information in memory.
type MuSig2SessionInfo struct {
// SessionID is the wallet's internal unique ID of this session. The ID
// is the hash over the combined public key and the local public nonces.
SessionID [32]byte
// Version is the version of the MuSig2 BIP this signing session is
// using.
Version MuSig2Version
// PublicNonce contains the public nonce of the local signer session.
PublicNonce [musig2.PubNonceSize]byte
// CombinedKey is the combined public key with all tweaks applied to it.
CombinedKey *btcec.PublicKey
// TaprootTweak indicates whether a taproot tweak (BIP-0086 or script
// path) was used. The TaprootInternalKey will only be set if this is
// set to true.
TaprootTweak bool
// TaprootInternalKey is the raw combined public key without any tweaks
// applied to it. This is only set if TaprootTweak is true.
TaprootInternalKey *btcec.PublicKey
// HaveAllNonces indicates whether this session already has all nonces
// of all other signing participants registered.
HaveAllNonces bool
// HaveAllSigs indicates whether this session already has all partial
// signatures of all other signing participants registered.
HaveAllSigs bool
}
// MuSig2Tweaks is a struct that contains all tweaks that can be applied to a
// MuSig2 combined public key.
type MuSig2Tweaks struct {
// GenericTweaks is a list of normal tweaks to apply to the combined
// public key (and to the private key when signing).
GenericTweaks []musig2.KeyTweakDesc
// TaprootBIP0086Tweak indicates that the final key should use the
// taproot tweak as defined in BIP 341, with the BIP 86 modification:
// outputKey = internalKey + h_tapTweak(internalKey)*G.
// In this case, the aggregated key before the tweak will be used as the
// internal key. If this is set to true then TaprootTweak will be
// ignored.
TaprootBIP0086Tweak bool
// TaprootTweak specifies that the final key should use the taproot
// tweak as defined in BIP 341:
// outputKey = internalKey + h_tapTweak(internalKey || scriptRoot).
// In this case, the aggregated key before the tweak will be used as the
// internal key. Will be ignored if TaprootBIP0086Tweak is set to true.
TaprootTweak []byte
}
// HasTaprootTweak returns true if either a taproot BIP0086 tweak or a taproot
// script root tweak is set.
func (t *MuSig2Tweaks) HasTaprootTweak() bool {
return t.TaprootBIP0086Tweak || len(t.TaprootTweak) > 0
}
// ToContextOptions converts the tweak descriptor to context options.
func (t *MuSig2Tweaks) ToContextOptions() []musig2.ContextOption {
var tweakOpts []musig2.ContextOption
if len(t.GenericTweaks) > 0 {
tweakOpts = append(tweakOpts, musig2.WithTweakedContext(
t.GenericTweaks...,
))
}
// The BIP0086 tweak and the taproot script tweak are mutually
// exclusive.
if t.TaprootBIP0086Tweak {
tweakOpts = append(tweakOpts, musig2.WithBip86TweakCtx())
} else if len(t.TaprootTweak) > 0 {
tweakOpts = append(tweakOpts, musig2.WithTaprootTweakCtx(
t.TaprootTweak,
))
}
return tweakOpts
}
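// The sketch below is illustrative only (the helper name is an assumption):
// it builds a BIP 86 tweak descriptor and converts it into functional
// options for the v1.0.0rc2 musig2 API.
func sketchTweakOptions() []musig2.ContextOption {
	// A pure BIP 86 tweak: no generic tweaks and no script root.
	tweaks := &MuSig2Tweaks{
		TaprootBIP0086Tweak: true,
	}

	// HasTaprootTweak reports true here, and ToContextOptions yields a
	// single musig2.WithBip86TweakCtx() option.
	if !tweaks.HasTaprootTweak() {
		return nil
	}

	return tweaks.ToContextOptions()
}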
// ToV040ContextOptions converts the tweak descriptor to v0.4.0 context options.
func (t *MuSig2Tweaks) ToV040ContextOptions() []musig2v040.ContextOption {
var tweakOpts []musig2v040.ContextOption
if len(t.GenericTweaks) > 0 {
genericTweaksCopy := make(
[]musig2v040.KeyTweakDesc, len(t.GenericTweaks),
)
for idx := range t.GenericTweaks {
genericTweaksCopy[idx] = musig2v040.KeyTweakDesc{
Tweak: t.GenericTweaks[idx].Tweak,
IsXOnly: t.GenericTweaks[idx].IsXOnly,
}
}
tweakOpts = append(tweakOpts, musig2v040.WithTweakedContext(
genericTweaksCopy...,
))
}
// The BIP0086 tweak and the taproot script tweak are mutually
// exclusive.
if t.TaprootBIP0086Tweak {
tweakOpts = append(tweakOpts, musig2v040.WithBip86TweakCtx())
} else if len(t.TaprootTweak) > 0 {
tweakOpts = append(tweakOpts, musig2v040.WithTaprootTweakCtx(
t.TaprootTweak,
))
}
return tweakOpts
}
// MuSig2ParsePubKeys parses a list of raw public keys as the signing keys of a
// MuSig2 signing session.
func MuSig2ParsePubKeys(bipVersion MuSig2Version,
rawPubKeys [][]byte) ([]*btcec.PublicKey, error) {
if len(rawPubKeys) < 2 {
return nil, fmt.Errorf("need at least two signing public keys")
}
allSignerPubKeys := make([]*btcec.PublicKey, len(rawPubKeys))
for idx, pubKeyBytes := range rawPubKeys {
switch bipVersion {
case MuSig2Version040:
pubKey, err := schnorr.ParsePubKey(pubKeyBytes)
if err != nil {
return nil, fmt.Errorf("error parsing signer "+
"public key %d for v0.4.0 (x-only "+
"format): %v", idx, err)
}
allSignerPubKeys[idx] = pubKey
case MuSig2Version100RC2:
pubKey, err := btcec.ParsePubKey(pubKeyBytes)
if err != nil {
return nil, fmt.Errorf("error parsing signer "+
"public key %d for v1.0.0rc2 ("+
"compressed format): %v", idx, err)
}
allSignerPubKeys[idx] = pubKey
default:
return nil, fmt.Errorf("unknown MuSig2 version: <%d>",
bipVersion)
}
}
return allSignerPubKeys, nil
}
// MuSig2CombineKeys combines the given set of public keys into a single
// combined MuSig2 combined public key, applying the given tweaks.
func MuSig2CombineKeys(bipVersion MuSig2Version,
allSignerPubKeys []*btcec.PublicKey, sortKeys bool,
tweaks *MuSig2Tweaks) (*musig2.AggregateKey, error) {
switch bipVersion {
case MuSig2Version040:
return combineKeysV040(allSignerPubKeys, sortKeys, tweaks)
case MuSig2Version100RC2:
return combineKeysV100RC2(allSignerPubKeys, sortKeys, tweaks)
default:
return nil, fmt.Errorf("unknown MuSig2 version: <%d>",
bipVersion)
}
}
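// The following sketch (the function name and raw keys are assumptions)
// parses two compressed public keys for the v1.0.0rc2 version and then
// aggregates them with a BIP 86 tweak, mirroring how a taproot output key
// would be derived from two signers.
func sketchCombineKeys(rawKeyA, rawKeyB []byte) (*musig2.AggregateKey, error) {
	// v1.0.0rc2 expects 33-byte compressed keys; v0.4.0 would expect
	// 32-byte x-only keys instead.
	signerKeys, err := MuSig2ParsePubKeys(
		MuSig2Version100RC2, [][]byte{rawKeyA, rawKeyB},
	)
	if err != nil {
		return nil, err
	}

	// Sort the keys and apply the BIP 86 tweak so the combined key can
	// only be spent via the aggregated key path.
	return MuSig2CombineKeys(
		MuSig2Version100RC2, signerKeys, true, &MuSig2Tweaks{
			TaprootBIP0086Tweak: true,
		},
	)
}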
// combineKeysV100RC2 implements the MuSig2CombineKeys logic for the MuSig2 BIP
// draft version 1.0.0rc2.
func combineKeysV100RC2(allSignerPubKeys []*btcec.PublicKey, sortKeys bool,
tweaks *MuSig2Tweaks) (*musig2.AggregateKey, error) {
// Convert the tweak options into the appropriate MuSig2 API functional
// options.
var keyAggOpts []musig2.KeyAggOption
switch {
case tweaks.TaprootBIP0086Tweak:
keyAggOpts = append(keyAggOpts, musig2.WithBIP86KeyTweak())
case len(tweaks.TaprootTweak) > 0:
keyAggOpts = append(keyAggOpts, musig2.WithTaprootKeyTweak(
tweaks.TaprootTweak,
))
case len(tweaks.GenericTweaks) > 0:
keyAggOpts = append(keyAggOpts, musig2.WithKeyTweaks(
tweaks.GenericTweaks...,
))
}
// Then we'll use this information to compute the aggregated public key.
combinedKey, _, _, err := musig2.AggregateKeys(
allSignerPubKeys, sortKeys, keyAggOpts...,
)
return combinedKey, err
}
// combineKeysV040 implements the MuSig2CombineKeys logic for the MuSig2 BIP
// draft version 0.4.0.
func combineKeysV040(allSignerPubKeys []*btcec.PublicKey, sortKeys bool,
tweaks *MuSig2Tweaks) (*musig2.AggregateKey, error) {
// Convert the tweak options into the appropriate MuSig2 API functional
// options.
var keyAggOpts []musig2v040.KeyAggOption
switch {
case tweaks.TaprootBIP0086Tweak:
keyAggOpts = append(keyAggOpts, musig2v040.WithBIP86KeyTweak())
case len(tweaks.TaprootTweak) > 0:
keyAggOpts = append(keyAggOpts, musig2v040.WithTaprootKeyTweak(
tweaks.TaprootTweak,
))
case len(tweaks.GenericTweaks) > 0:
genericTweaksCopy := make(
[]musig2v040.KeyTweakDesc, len(tweaks.GenericTweaks),
)
for idx := range tweaks.GenericTweaks {
genericTweaksCopy[idx] = musig2v040.KeyTweakDesc{
Tweak: tweaks.GenericTweaks[idx].Tweak,
IsXOnly: tweaks.GenericTweaks[idx].IsXOnly,
}
}
keyAggOpts = append(keyAggOpts, musig2v040.WithKeyTweaks(
genericTweaksCopy...,
))
}
// Then we'll use this information to compute the aggregated public key.
combinedKey, _, _, err := musig2v040.AggregateKeys(
allSignerPubKeys, sortKeys, keyAggOpts...,
)
// Copy the result back into the default version's native type.
return &musig2.AggregateKey{
FinalKey: combinedKey.FinalKey,
PreTweakedKey: combinedKey.PreTweakedKey,
}, err
}
// MuSig2CreateContext creates a new MuSig2 signing context.
func MuSig2CreateContext(bipVersion MuSig2Version, privKey *btcec.PrivateKey,
allSignerPubKeys []*btcec.PublicKey, tweaks *MuSig2Tweaks,
localNonces *musig2.Nonces,
) (MuSig2Context, MuSig2Session, error) {
switch bipVersion {
case MuSig2Version040:
return createContextV040(
privKey, allSignerPubKeys, tweaks, localNonces,
)
case MuSig2Version100RC2:
return createContextV100RC2(
privKey, allSignerPubKeys, tweaks, localNonces,
)
default:
return nil, nil, fmt.Errorf("unknown MuSig2 version: <%d>",
bipVersion)
}
}
// createContextV100RC2 implements the MuSig2CreateContext logic for the MuSig2
// BIP draft version 1.0.0rc2.
func createContextV100RC2(privKey *btcec.PrivateKey,
allSignerPubKeys []*btcec.PublicKey, tweaks *MuSig2Tweaks,
localNonces *musig2.Nonces,
) (*musig2.Context, *musig2.Session, error) {
// The context keeps track of all signing keys and our local key.
allOpts := append(
[]musig2.ContextOption{
musig2.WithKnownSigners(allSignerPubKeys),
},
tweaks.ToContextOptions()...,
)
muSigContext, err := musig2.NewContext(privKey, true, allOpts...)
if err != nil {
return nil, nil, fmt.Errorf("error creating MuSig2 signing "+
"context: %v", err)
}
var sessionOpts []musig2.SessionOption
if localNonces != nil {
sessionOpts = append(
sessionOpts, musig2.WithPreGeneratedNonce(localNonces),
)
}
muSigSession, err := muSigContext.NewSession(sessionOpts...)
if err != nil {
return nil, nil, fmt.Errorf("error creating MuSig2 signing "+
"session: %v", err)
}
return muSigContext, muSigSession, nil
}
// createContextV040 implements the MuSig2CreateContext logic for the MuSig2 BIP
// draft version 0.4.0.
func createContextV040(privKey *btcec.PrivateKey,
allSignerPubKeys []*btcec.PublicKey, tweaks *MuSig2Tweaks,
_ *musig2.Nonces,
) (*musig2v040.Context, *musig2v040.Session, error) {
// The context keeps track of all signing keys and our local key.
allOpts := append(
[]musig2v040.ContextOption{
musig2v040.WithKnownSigners(allSignerPubKeys),
},
tweaks.ToV040ContextOptions()...,
)
muSigContext, err := musig2v040.NewContext(privKey, true, allOpts...)
if err != nil {
return nil, nil, fmt.Errorf("error creating MuSig2 signing "+
"context: %v", err)
}
muSigSession, err := muSigContext.NewSession()
if err != nil {
return nil, nil, fmt.Errorf("error creating MuSig2 signing "+
"session: %v", err)
}
return muSigContext, muSigSession, nil
}
// MuSig2Sign calls the Sign() method on the given versioned signing session and
// returns the result in the most recent version of the MuSig2 API.
func MuSig2Sign(session MuSig2Session, msg [32]byte,
withSortedKeys bool) (*musig2.PartialSignature, error) {
switch s := session.(type) {
case *musig2.Session:
var opts []musig2.SignOption
if withSortedKeys {
opts = append(opts, musig2.WithSortedKeys())
}
partialSig, err := s.Sign(msg, opts...)
if err != nil {
return nil, fmt.Errorf("error signing with local key: "+
"%v", err)
}
return partialSig, nil
case *musig2v040.Session:
var opts []musig2v040.SignOption
if withSortedKeys {
opts = append(opts, musig2v040.WithSortedKeys())
}
partialSig, err := s.Sign(msg, opts...)
if err != nil {
return nil, fmt.Errorf("error signing with local key: "+
"%v", err)
}
return &musig2.PartialSignature{
S: partialSig.S,
R: partialSig.R,
}, nil
default:
return nil, fmt.Errorf("invalid session type <%T>", s)
}
}
// MuSig2CombineSig calls the CombineSig() method on the given versioned signing
// session and returns the result in the most recent version of the MuSig2 API.
func MuSig2CombineSig(session MuSig2Session,
otherPartialSig *musig2.PartialSignature) (bool, error) {
switch s := session.(type) {
case *musig2.Session:
haveAllSigs, err := s.CombineSig(otherPartialSig)
if err != nil {
return false, fmt.Errorf("error combining partial "+
"signature: %v", err)
}
return haveAllSigs, nil
case *musig2v040.Session:
haveAllSigs, err := s.CombineSig(&musig2v040.PartialSignature{
S: otherPartialSig.S,
R: otherPartialSig.R,
})
if err != nil {
return false, fmt.Errorf("error combining partial "+
"signature: %v", err)
}
return haveAllSigs, nil
default:
return false, fmt.Errorf("invalid session type <%T>", s)
}
}
// NewMuSig2SessionID returns the unique ID of a MuSig2 session by using the
// combined key and the local public nonces and hashing that data.
func NewMuSig2SessionID(combinedKey *btcec.PublicKey,
publicNonces [musig2.PubNonceSize]byte) MuSig2SessionID {
// We hash the data to save some bytes in memory.
hash := sha256.New()
_, _ = hash.Write(combinedKey.SerializeCompressed())
_, _ = hash.Write(publicNonces[:])
id := MuSig2SessionID{}
copy(id[:], hash.Sum(nil))
return id
}
// SerializePartialSignature encodes the partial signature to a fixed size byte
// array.
func SerializePartialSignature(
sig *musig2.PartialSignature) ([MuSig2PartialSigSize]byte, error) {
var (
buf bytes.Buffer
result [MuSig2PartialSigSize]byte
)
if err := sig.Encode(&buf); err != nil {
return result, fmt.Errorf("error encoding partial signature: "+
"%v", err)
}
if buf.Len() != MuSig2PartialSigSize {
return result, fmt.Errorf("invalid partial signature length, "+
"got %d wanted %d", buf.Len(), MuSig2PartialSigSize)
}
copy(result[:], buf.Bytes())
return result, nil
}
// DeserializePartialSignature decodes a partial signature from a byte slice.
func DeserializePartialSignature(scalarBytes []byte) (*musig2.PartialSignature,
error) {
if len(scalarBytes) != MuSig2PartialSigSize {
return nil, fmt.Errorf("invalid partial signature length, got "+
"%d wanted %d", len(scalarBytes), MuSig2PartialSigSize)
}
sig := &musig2.PartialSignature{}
if err := sig.Decode(bytes.NewReader(scalarBytes)); err != nil {
return nil, fmt.Errorf("error decoding partial signature: %w",
err)
}
return sig, nil
}
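// A minimal round-trip sketch (the helper name is an assumption) showing
// that a partial signature survives the fixed-size encoding defined above:
// the encoding is exactly MuSig2PartialSigSize (32) bytes, just the s value.
func sketchPartialSigRoundTrip(
	sig *musig2.PartialSignature) (*musig2.PartialSignature, error) {

	encoded, err := SerializePartialSignature(sig)
	if err != nil {
		return nil, err
	}

	return DeserializePartialSignature(encoded[:])
}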
package input
import (
"crypto/sha256"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/multimutex"
)
// MuSig2State is a struct that holds on to the internal signing session state
// of a MuSig2 session.
type MuSig2State struct {
// MuSig2SessionInfo is the associated meta information of the signing
// session.
MuSig2SessionInfo
// context is the signing context responsible for keeping track of the
// public keys involved in the signing process.
context MuSig2Context
// session is the signing session responsible for keeping track of the
// nonces and partial signatures involved in the signing process.
session MuSig2Session
}
// PrivKeyFetcher is used to fetch a private key that matches a given key desc.
type PrivKeyFetcher func(*keychain.KeyDescriptor) (*btcec.PrivateKey, error)
// MusigSessionManager houses the state needed to manage concurrent
// musig sessions. Each session is identified by a unique session ID which is
// used by callers to interact with a given session.
type MusigSessionManager struct {
keyFetcher PrivKeyFetcher
sessionMtx *multimutex.Mutex[MuSig2SessionID]
musig2Sessions *lnutils.SyncMap[MuSig2SessionID, *MuSig2State]
}
// NewMusigSessionManager creates a new musig manager given an abstract key
// fetcher.
func NewMusigSessionManager(keyFetcher PrivKeyFetcher) *MusigSessionManager {
return &MusigSessionManager{
keyFetcher: keyFetcher,
musig2Sessions: &lnutils.SyncMap[
MuSig2SessionID, *MuSig2State,
]{},
sessionMtx: multimutex.NewMutex[MuSig2SessionID](),
}
}
// MuSig2CreateSession creates a new MuSig2 signing session using the local key
// identified by the key locator. The complete list of all public keys of all
// signing parties must be provided, including the public key of the local
// signing key. If nonces of other parties are already known, they can be
// submitted as well to reduce the number of method calls necessary later on.
//
// The localNonces parameter is optional. If set, the specified nonces will
// be used instead of being generated from scratch, which is useful when the
// local nonce was created ahead of time.
func (m *MusigSessionManager) MuSig2CreateSession(bipVersion MuSig2Version,
keyLoc keychain.KeyLocator, allSignerPubKeys []*btcec.PublicKey,
tweaks *MuSig2Tweaks, otherSignerNonces [][musig2.PubNonceSize]byte,
localNonces *musig2.Nonces) (*MuSig2SessionInfo, error) {
// We need to derive the private key for signing. In the remote signing
// setup, this whole RPC call will be forwarded to the signing
// instance, which requires it to be stateful.
privKey, err := m.keyFetcher(&keychain.KeyDescriptor{
KeyLocator: keyLoc,
})
if err != nil {
return nil, fmt.Errorf("error deriving private key: %w", err)
}
// Create a signing context and session with the given private key and
// list of all known signer public keys.
musigContext, musigSession, err := MuSig2CreateContext(
bipVersion, privKey, allSignerPubKeys, tweaks, localNonces,
)
if err != nil {
return nil, fmt.Errorf("error creating signing context: %w",
err)
}
// Add all nonces we might've learned so far.
haveAllNonces := false
for _, otherSignerNonce := range otherSignerNonces {
haveAllNonces, err = musigSession.RegisterPubNonce(
otherSignerNonce,
)
if err != nil {
return nil, fmt.Errorf("error registering other "+
"signer public nonce: %v", err)
}
}
// Register the new session.
combinedKey, err := musigContext.CombinedKey()
if err != nil {
return nil, fmt.Errorf("error getting combined key: %w", err)
}
session := &MuSig2State{
MuSig2SessionInfo: MuSig2SessionInfo{
SessionID: NewMuSig2SessionID(
combinedKey, musigSession.PublicNonce(),
),
Version: bipVersion,
PublicNonce: musigSession.PublicNonce(),
CombinedKey: combinedKey,
TaprootTweak: tweaks.HasTaprootTweak(),
HaveAllNonces: haveAllNonces,
},
context: musigContext,
session: musigSession,
}
// The internal key is only calculated if we are using a taproot tweak
// and need to know it for a potential script spend.
if tweaks.HasTaprootTweak() {
internalKey, err := musigContext.TaprootInternalKey()
if err != nil {
return nil, fmt.Errorf("error getting internal key: %w",
err)
}
session.TaprootInternalKey = internalKey
}
// Since we generate new nonces for every session, there is no way that
// a session with the same ID already exists. So even if we call the API
// twice with the same signers, we still get a new ID.
//
// We'll use just all zeroes as the session ID for the mutex, as this
// is a "global" action.
m.sessionMtx.Lock(MuSig2SessionID{})
defer m.sessionMtx.Unlock(MuSig2SessionID{})
m.musig2Sessions.Store(session.SessionID, session)
return &session.MuSig2SessionInfo, nil
}
// MuSig2Sign creates a partial signature using the local signing key
// that was specified when the session was created. This can only be
// called when all public nonces of all participants are known and have
// been registered with the session. If this node isn't responsible for
// combining all the partial signatures, then the cleanup parameter
// should be set, indicating that the session can be removed from memory
// once the signature was produced.
func (m *MusigSessionManager) MuSig2Sign(sessionID MuSig2SessionID,
msg [sha256.Size]byte, cleanUp bool) (*musig2.PartialSignature, error) {
// We hold the lock during the whole operation, as we don't want any
// interference with calls that might come through in parallel for the
// same session.
m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions.Load(sessionID)
if !ok {
return nil, fmt.Errorf("session with ID %x not found",
sessionID[:])
}
// We can only sign once we have all other signers' nonces.
if !session.HaveAllNonces {
return nil, fmt.Errorf("only have %d of %d required nonces",
session.session.NumRegisteredNonces(),
len(session.context.SigningKeys()))
}
// Create our own partial signature with the local signing key.
partialSig, err := MuSig2Sign(session.session, msg, true)
if err != nil {
return nil, fmt.Errorf("error signing with local key: %w", err)
}
// Clean up our local state if requested.
if cleanUp {
m.musig2Sessions.Delete(sessionID)
}
return partialSig, nil
}
// MuSig2CombineSig combines the given partial signature(s) with the
// local one, if it already exists. Once a partial signature of all
// participants is registered, the final signature will be combined and
// returned.
func (m *MusigSessionManager) MuSig2CombineSig(sessionID MuSig2SessionID,
partialSigs []*musig2.PartialSignature) (*schnorr.Signature, bool,
error) {
// We hold the lock during the whole operation, as we don't want any
// interference with calls that might come through in parallel for the
// same session.
m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions.Load(sessionID)
if !ok {
return nil, false, fmt.Errorf("session with ID %x not found",
sessionID[:])
}
// Make sure we don't already have all partial signatures, as combining
// more would indicate something is wrong with the signing setup.
if session.HaveAllSigs {
return nil, true, fmt.Errorf("already have all partial " +
"signatures")
}
// Add all sigs we got so far.
var (
finalSig *schnorr.Signature
err error
)
for _, otherPartialSig := range partialSigs {
session.HaveAllSigs, err = MuSig2CombineSig(
session.session, otherPartialSig,
)
if err != nil {
return nil, false, fmt.Errorf("error combining "+
"partial signature: %w", err)
}
}
// If we have all partial signatures, we should be able to get the
// complete signature now. We also remove this session from memory since
// there is nothing more left to do.
if session.HaveAllSigs {
finalSig = session.session.FinalSig()
m.musig2Sessions.Delete(sessionID)
}
return finalSig, session.HaveAllSigs, nil
}
// MuSig2Cleanup removes a session from memory to free up resources.
func (m *MusigSessionManager) MuSig2Cleanup(sessionID MuSig2SessionID) error {
// We hold the lock during the whole operation, as we don't want any
// interference with calls that might come through in parallel for the
// same session.
m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID)
_, ok := m.musig2Sessions.Load(sessionID)
if !ok {
return fmt.Errorf("session with ID %x not found", sessionID[:])
}
m.musig2Sessions.Delete(sessionID)
return nil
}
// MuSig2RegisterNonces registers one or more public nonces of other signing
// participants for a session identified by its ID. This method returns true
// once we have all nonces for all other signing participants.
func (m *MusigSessionManager) MuSig2RegisterNonces(sessionID MuSig2SessionID,
otherSignerNonces [][musig2.PubNonceSize]byte) (bool, error) {
// We hold the lock during the whole operation, as we don't want any
// interference with calls that might come through in parallel for the
// same session.
m.sessionMtx.Lock(sessionID)
defer m.sessionMtx.Unlock(sessionID)
session, ok := m.musig2Sessions.Load(sessionID)
if !ok {
return false, fmt.Errorf("session with ID %x not found",
sessionID[:])
}
// Make sure we don't exceed the number of expected nonces as that would
// indicate something is wrong with the signing setup.
if session.HaveAllNonces {
return true, fmt.Errorf("already have all nonces")
}
numSigners := len(session.context.SigningKeys())
remainingNonces := numSigners - session.session.NumRegisteredNonces()
if len(otherSignerNonces) > remainingNonces {
return false, fmt.Errorf("only %d other nonces remaining but "+
"trying to register %d more", remainingNonces,
len(otherSignerNonces))
}
// Add all nonces we've learned so far.
var err error
for _, otherSignerNonce := range otherSignerNonces {
session.HaveAllNonces, err = session.session.RegisterPubNonce(
otherSignerNonce,
)
if err != nil {
return false, fmt.Errorf("error registering other "+
"signer public nonce: %v", err)
}
}
return session.HaveAllNonces, nil
}
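// The sketch below (the names and the simplified two-party setup are
// assumptions) walks through the intended lifecycle of a session managed by
// MusigSessionManager: create, register the remote nonce, sign, and combine.
func sketchManagerFlow(m *MusigSessionManager, keyLoc keychain.KeyLocator,
	localPub, remotePub *btcec.PublicKey,
	remoteNonce [musig2.PubNonceSize]byte,
	remotePartialSig *musig2.PartialSignature,
	msg [sha256.Size]byte) (*schnorr.Signature, error) {

	// Create the session with both signer keys; we don't know the remote
	// nonce yet, so none are supplied up front.
	info, err := m.MuSig2CreateSession(
		MuSig2Version100RC2, keyLoc,
		[]*btcec.PublicKey{localPub, remotePub}, &MuSig2Tweaks{
			TaprootBIP0086Tweak: true,
		}, nil, nil,
	)
	if err != nil {
		return nil, err
	}

	// Register the remote signer's public nonce; with two signers this
	// single nonce completes the set.
	haveAll, err := m.MuSig2RegisterNonces(
		info.SessionID, [][musig2.PubNonceSize]byte{remoteNonce},
	)
	if err != nil {
		return nil, err
	}
	if !haveAll {
		return nil, fmt.Errorf("still missing nonces")
	}

	// Produce our partial signature, keeping the session alive since we
	// act as the combiner here. In a real flow this partial signature
	// would be sent to the remote signer.
	if _, err := m.MuSig2Sign(info.SessionID, msg, false); err != nil {
		return nil, err
	}

	// Combining the remote partial signature with our own yields the
	// final Schnorr signature and tears down the session.
	finalSig, _, err := m.MuSig2CombineSig(
		info.SessionID, []*musig2.PartialSignature{remotePartialSig},
	)
	return finalSig, err
}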
package input
import (
"errors"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/txscript"
"github.com/lightningnetwork/lnd/lnutils"
)
// ErrUnknownScriptType is returned when an unknown script type is encountered.
var ErrUnknownScriptType = errors.New("unknown script type")
// ScriptPath is used to indicate the spending path of a given script. Possible
// paths include: timeout, success, revocation, and others.
type ScriptPath uint8
const (
// ScriptPathTimeout is a script path that can be taken only after a
// timeout has elapsed.
ScriptPathTimeout ScriptPath = iota
// ScriptPathSuccess is a script path that can be taken only with some
// secret data.
ScriptPathSuccess
// ScriptPathRevocation is a script path used when a contract has been
// breached.
ScriptPathRevocation
// ScriptPathDelay is a script path used when a contract has relative
// delay that must elapse before it can be swept.
ScriptPathDelay
)
// ScriptDescriptor is an interface that abstracts over the various ways a
// pkScript can be spent from an output. This supports both normal p2wsh
// (witness script, etc.), and also tapscript paths which have distinct
// tapscript leaves.
type ScriptDescriptor interface {
// PkScript is the public key script that commits to the final
// contract.
PkScript() []byte
// WitnessScriptToSign returns the witness script that we'll use when
// signing for the remote party, and also verifying signatures on our
// transactions. As an example, when we create an outgoing HTLC for the
// remote party, we want to sign their success path.
//
// TODO(roasbeef): break out into HTLC specific desc? or Branching Desc
// w/ the below?
WitnessScriptToSign() []byte
// WitnessScriptForPath returns the witness script for the given
// spending path. An error is returned if the path is unknown. This is
// useful, as when constructing a control block for a given path, one
// also needs the witness script being signed.
WitnessScriptForPath(path ScriptPath) ([]byte, error)
}
// TapscriptDescriptor is a super-set of the normal script multiplexer that
// adds in taproot specific details such as the control block, or top-level tap
// tweak.
type TapscriptDescriptor interface {
ScriptDescriptor
// CtrlBlockForPath returns the control block for the given spending
// path. For unknown paths, an error is returned.
CtrlBlockForPath(path ScriptPath) (*txscript.ControlBlock, error)
// TapTweak returns the top-level taproot tweak for the script.
TapTweak() []byte
// TapScriptTree returns the underlying tapscript tree.
TapScriptTree() *txscript.IndexedTapScriptTree
// Tree returns the underlying ScriptTree.
Tree() ScriptTree
}
// ScriptTree holds the contents needed to spend a script within a tapscript
// tree.
type ScriptTree struct {
// InternalKey is the internal key of the Taproot output key.
InternalKey *btcec.PublicKey
// TaprootKey is the key that will be used to generate the taproot
// output.
TaprootKey *btcec.PublicKey
// TapscriptTree is the full tapscript tree that also includes the
// control block needed to spend each of the leaves.
TapscriptTree *txscript.IndexedTapScriptTree
// TapscriptRoot is the root hash of the tapscript tree.
TapscriptRoot []byte
}
// PkScript is the public key script that commits to the final contract.
func (s *ScriptTree) PkScript() []byte {
// Script building can never internally return an error, so we ignore
// the error to simplify the interface.
pkScript, _ := PayToTaprootScript(s.TaprootKey)
return pkScript
}
// TapTweak returns the top-level taproot tweak for the script.
func (s *ScriptTree) TapTweak() []byte {
return lnutils.ByteSlice(s.TapscriptTree.RootNode.TapHash())
}
// TapScriptTree returns the underlying tapscript tree.
func (s *ScriptTree) TapScriptTree() *txscript.IndexedTapScriptTree {
return s.TapscriptTree
}
package input
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"golang.org/x/crypto/ripemd160"
)
var (
// TODO(roasbeef): remove these and use the ones defined in txscript
// within testnet-L.
// SequenceLockTimeSeconds is the 22nd bit which indicates the lock
// time is in seconds.
SequenceLockTimeSeconds = uint32(1 << 22)
)
// MustParsePubKey parses a hex encoded public key string into a public key
// and panics if parsing fails.
func MustParsePubKey(pubStr string) btcec.PublicKey {
pubBytes, err := hex.DecodeString(pubStr)
if err != nil {
panic(err)
}
pub, err := btcec.ParsePubKey(pubBytes)
if err != nil {
panic(err)
}
return *pub
}
// TaprootNUMSHex is the hex encoded version of the taproot NUMs key.
const TaprootNUMSHex = "02dca094751109d0bd055d03565874e8276dd53e926b44e3bd1bb" +
"6bf4bc130a279"
var (
// TaprootNUMSKey is a NUMS key (nothing up my sleeves number) that has
// no known private key. This was generated using the following script:
// https://github.com/lightninglabs/lightning-node-connect/tree/
// master/mailbox/numsgen, with the seed phrase "Lightning Simple
// Taproot".
TaprootNUMSKey = MustParsePubKey(TaprootNUMSHex)
)
// Signature is an interface for objects that can populate signatures during
// witness construction.
type Signature interface {
// Serialize returns a DER-encoded ECDSA signature.
Serialize() []byte
// Verify returns true if the ECDSA signature is valid for the passed
// message digest under the provided public key.
Verify([]byte, *btcec.PublicKey) bool
}
// ParseSignature parses a raw signature into an input.Signature instance. This
// routine supports parsing normal ECDSA DER encoded signatures, as well as
// schnorr signatures.
func ParseSignature(rawSig []byte) (Signature, error) {
if len(rawSig) == schnorr.SignatureSize {
return schnorr.ParseSignature(rawSig)
}
return ecdsa.ParseDERSignature(rawSig)
}
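// A small sketch (the helper name is an assumption) showing the length-based
// dispatch above: 64-byte signatures parse as Schnorr, anything else is
// treated as DER-encoded ECDSA.
func sketchParseBothSignatureKinds(schnorrSig, derSig []byte) error {
	// schnorr.SignatureSize is 64 bytes, so this takes the Schnorr path.
	if _, err := ParseSignature(schnorrSig); err != nil {
		return err
	}

	// DER-encoded ECDSA signatures are variable length (roughly 70-72
	// bytes) and take the ECDSA path.
	_, err := ParseSignature(derSig)
	return err
}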
// WitnessScriptHash generates a pay-to-witness-script-hash public key script
// paying to a version 0 witness program paying to the passed redeem script.
func WitnessScriptHash(witnessScript []byte) ([]byte, error) {
bldr := txscript.NewScriptBuilder(
txscript.WithScriptAllocSize(P2WSHSize),
)
bldr.AddOp(txscript.OP_0)
scriptHash := sha256.Sum256(witnessScript)
bldr.AddData(scriptHash[:])
return bldr.Script()
}
// WitnessPubKeyHash generates a pay-to-witness-pubkey-hash public key script
// paying to a version 0 witness program containing the passed serialized
// public key.
func WitnessPubKeyHash(pubkey []byte) ([]byte, error) {
bldr := txscript.NewScriptBuilder(
txscript.WithScriptAllocSize(P2WPKHSize),
)
bldr.AddOp(txscript.OP_0)
pkhash := btcutil.Hash160(pubkey)
bldr.AddData(pkhash)
return bldr.Script()
}
// GenerateP2SH generates a pay-to-script-hash public key script paying to the
// passed redeem script.
func GenerateP2SH(script []byte) ([]byte, error) {
bldr := txscript.NewScriptBuilder(
txscript.WithScriptAllocSize(NestedP2WPKHSize),
)
bldr.AddOp(txscript.OP_HASH160)
scripthash := btcutil.Hash160(script)
bldr.AddData(scripthash)
bldr.AddOp(txscript.OP_EQUAL)
return bldr.Script()
}
// GenerateP2PKH generates a pay-to-public-key-hash public key script paying to
// the passed serialized public key.
func GenerateP2PKH(pubkey []byte) ([]byte, error) {
bldr := txscript.NewScriptBuilder(
txscript.WithScriptAllocSize(P2PKHSize),
)
bldr.AddOp(txscript.OP_DUP)
bldr.AddOp(txscript.OP_HASH160)
pkhash := btcutil.Hash160(pubkey)
bldr.AddData(pkhash)
bldr.AddOp(txscript.OP_EQUALVERIFY)
bldr.AddOp(txscript.OP_CHECKSIG)
return bldr.Script()
}
// GenerateUnknownWitness generates the maximum-sized witness public key script
// consisting of a version push and a 40-byte data push.
func GenerateUnknownWitness() ([]byte, error) {
bldr := txscript.NewScriptBuilder()
bldr.AddOp(txscript.OP_0)
witnessScript := make([]byte, 40)
bldr.AddData(witnessScript)
return bldr.Script()
}
// GenMultiSigScript generates the non-p2sh'd multisig script for 2 of 2
// pubkeys.
func GenMultiSigScript(aPub, bPub []byte) ([]byte, error) {
if len(aPub) != 33 || len(bPub) != 33 {
return nil, fmt.Errorf("pubkey size error: compressed " +
"pubkeys only")
}
// Swap to sort pubkeys if needed. Keys are sorted in lexicographical
// order. The signatures within the scriptSig must also adhere to the
// order, ensuring that the signatures for each public key appears in
// the proper order on the stack.
if bytes.Compare(aPub, bPub) == 1 {
aPub, bPub = bPub, aPub
}
bldr := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
MultiSigSize,
))
bldr.AddOp(txscript.OP_2)
bldr.AddData(aPub) // Add both pubkeys (sorted).
bldr.AddData(bPub)
bldr.AddOp(txscript.OP_2)
bldr.AddOp(txscript.OP_CHECKMULTISIG)
return bldr.Script()
}
// GenFundingPkScript creates a redeem script, and its matching p2wsh
// output for the funding transaction.
func GenFundingPkScript(aPub, bPub []byte,
	amt int64) ([]byte, *wire.TxOut, error) {
// As a sanity check, ensure that the passed amount is above zero.
if amt <= 0 {
return nil, nil, fmt.Errorf("can't create FundTx script with " +
"zero, or negative coins")
}
// First, create the 2-of-2 multi-sig script itself.
witnessScript, err := GenMultiSigScript(aPub, bPub)
if err != nil {
return nil, nil, err
}
// With the 2-of-2 script in hand, generate a p2wsh script which pays
// to the funding script.
pkScript, err := WitnessScriptHash(witnessScript)
if err != nil {
return nil, nil, err
}
return witnessScript, wire.NewTxOut(amt, pkScript), nil
}
// GenTaprootFundingScript constructs the taproot-native funding output that
// uses MuSig2 to create a single aggregated key to anchor the channel.
func GenTaprootFundingScript(aPub, bPub *btcec.PublicKey,
amt int64, tapscriptRoot fn.Option[chainhash.Hash]) ([]byte,
*wire.TxOut, error) {
muSig2Opt := musig2.WithBIP86KeyTweak()
tapscriptRoot.WhenSome(func(scriptRoot chainhash.Hash) {
muSig2Opt = musig2.WithTaprootKeyTweak(scriptRoot[:])
})
// Similar to the existing p2wsh funding script, we'll always make sure
// we sort the keys before any major operations. In order to ensure
// that there's no other way this output can be spent, we'll use a BIP
// 86 tweak here during aggregation, unless the user has explicitly
// specified a tapscript root.
combinedKey, _, _, err := musig2.AggregateKeys(
[]*btcec.PublicKey{aPub, bPub}, true, muSig2Opt,
)
if err != nil {
return nil, nil, fmt.Errorf("unable to combine keys: %w", err)
}
// Now that we have the combined key, we can create a taproot pkScript
// from this, and then make the txOut given the amount.
pkScript, err := PayToTaprootScript(combinedKey.FinalKey)
if err != nil {
return nil, nil, fmt.Errorf("unable to make taproot "+
"pkscript: %w", err)
}
txOut := wire.NewTxOut(amt, pkScript)
// For the "witness program" we just return the raw pkScript since the
// output we create can _only_ be spent with a MuSig2 signature.
return pkScript, txOut, nil
}
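// The sketch below (the helper name is an assumption) shows the two ways to
// call GenTaprootFundingScript: a plain BIP 86 channel with no tapscript
// root, and one committing to an explicit script root.
func sketchTaprootFunding(aPub, bPub *btcec.PublicKey, amt int64,
	scriptRoot chainhash.Hash) error {

	// Without a tapscript root the aggregated key is BIP 86 tweaked,
	// making the key path the only spend path.
	if _, _, err := GenTaprootFundingScript(
		aPub, bPub, amt, fn.None[chainhash.Hash](),
	); err != nil {
		return err
	}

	// With a root, the output key commits to the supplied script tree.
	_, _, err := GenTaprootFundingScript(
		aPub, bPub, amt, fn.Some(scriptRoot),
	)
	return err
}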
// SpendMultiSig generates the witness stack required to redeem the 2-of-2 p2wsh
// multi-sig output.
func SpendMultiSig(witnessScript, pubA []byte, sigA Signature,
pubB []byte, sigB Signature) [][]byte {
witness := make([][]byte, 4)
// When spending a p2wsh multi-sig script, rather than an OP_0, we add
// a nil stack element to eat the extra pop.
witness[0] = nil
// When initially generating the witnessScript, we sorted the serialized
// public keys in ascending lexicographical order. So we do a quick
// comparison in order to ensure the signatures appear on the Script
// Virtual Machine stack in the correct order.
if bytes.Compare(pubA, pubB) == 1 {
witness[1] = append(sigB.Serialize(), byte(txscript.SigHashAll))
witness[2] = append(sigA.Serialize(), byte(txscript.SigHashAll))
} else {
witness[1] = append(sigA.Serialize(), byte(txscript.SigHashAll))
witness[2] = append(sigB.Serialize(), byte(txscript.SigHashAll))
}
// Finally, add the witness script as the last witness element.
witness[3] = witnessScript
return witness
}
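// A short sketch (the helper name is an assumption) of assembling a funding
// output spend: the two counterparty signatures plus the shared witness
// script yield the final witness, with SpendMultiSig handling the key-order
// dependent placement of the signatures.
func sketchFundingWitness(witnessScript, pubA, pubB []byte,
	sigA, sigB Signature) wire.TxWitness {

	return wire.TxWitness(SpendMultiSig(
		witnessScript, pubA, sigA, pubB, sigB,
	))
}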
// FindScriptOutputIndex finds the index of the public key script output
// matching 'script'. Additionally, a boolean is returned indicating if a
// matching output was found at all.
//
// NOTE: The search stops after the first matching script is found.
func FindScriptOutputIndex(tx *wire.MsgTx, script []byte) (bool, uint32) {
found := false
index := uint32(0)
for i, txOut := range tx.TxOut {
if bytes.Equal(txOut.PkScript, script) {
found = true
index = uint32(i)
break
}
}
return found, index
}
// Ripemd160H calculates the ripemd160 of the passed byte slice. This is used to
// calculate the intermediate hash for payment pre-images. Payment hashes are
// the result of ripemd160(sha256(paymentPreimage)). As a result, the value
// passed in should be the sha256 of the payment hash.
func Ripemd160H(d []byte) []byte {
h := ripemd160.New()
h.Write(d)
return h.Sum(nil)
}
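// A short sketch (the helper name is an assumption) of the hash chain
// described above: the script embeds ripemd160(payment hash), where the
// payment hash is already sha256(preimage).
func sketchHtlcScriptHash(preimage []byte) []byte {
	paymentHash := sha256.Sum256(preimage)
	return Ripemd160H(paymentHash[:])
}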
// SenderHTLCScript constructs the public key script for an outgoing HTLC
// output payment for the sender's version of the commitment transaction. The
// possible script paths from this output include:
//
// - The sender timing out the HTLC using the second level HTLC timeout
// transaction.
// - The receiver of the HTLC claiming the output on-chain with the payment
// preimage.
// - The receiver of the HTLC sweeping all the funds in the case that a
// revoked commitment transaction bearing this HTLC was broadcast.
//
// If confirmedSpend=true, a 1 OP_CSV check will be added to the non-revocation
// cases, to allow sweeping only after confirmation.
//
// Possible Input Scripts:
//
// SENDR: <0> <sendr sig> <recvr sig> <0> (spend using HTLC timeout transaction)
// RECVR: <recvr sig> <preimage>
// REVOK: <revoke sig> <revoke key>
// * receiver revoke
//
// Offered HTLC Output Script:
//
// OP_DUP OP_HASH160 <revocation key hash160> OP_EQUAL
// OP_IF
// OP_CHECKSIG
// OP_ELSE
// <recv htlc key>
// OP_SWAP OP_SIZE 32 OP_EQUAL
// OP_NOTIF
// OP_DROP 2 OP_SWAP <sender htlc key> 2 OP_CHECKMULTISIG
// OP_ELSE
// OP_HASH160 <ripemd160(payment hash)> OP_EQUALVERIFY
// OP_CHECKSIG
// OP_ENDIF
// [1 OP_CHECKSEQUENCEVERIFY OP_DROP] <- if allowing confirmed
// spend only.
// OP_ENDIF
func SenderHTLCScript(senderHtlcKey, receiverHtlcKey,
revocationKey *btcec.PublicKey, paymentHash []byte,
confirmedSpend bool) ([]byte, error) {
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
OfferedHtlcScriptSizeConfirmed,
))
// The opening operations are used to determine if this is the receiver
// of the HTLC attempting to sweep all the funds due to a contract
// breach. In this case, they'll place the revocation key at the top of
// the stack.
builder.AddOp(txscript.OP_DUP)
builder.AddOp(txscript.OP_HASH160)
builder.AddData(btcutil.Hash160(revocationKey.SerializeCompressed()))
builder.AddOp(txscript.OP_EQUAL)
// If the hash matches, then this is the revocation clause. The output
// can be spent if the check sig operation passes.
builder.AddOp(txscript.OP_IF)
builder.AddOp(txscript.OP_CHECKSIG)
// Otherwise, this may either be the receiver of the HTLC claiming with
// the pre-image, or the sender of the HTLC sweeping the output after
// it has timed out.
builder.AddOp(txscript.OP_ELSE)
// We'll do a bit of set up by pushing the receiver's key on the top of
// the stack. This will be needed later if we decide that this is the
// sender activating the time out clause with the HTLC timeout
// transaction.
builder.AddData(receiverHtlcKey.SerializeCompressed())
// At this point, the top item of the stack is the receiver's key, so we
// use a swap to expose what is either the payment pre-image or a
// signature.
builder.AddOp(txscript.OP_SWAP)
// With the top item swapped, check if it's 32 bytes. If so, then this
// *may* be the payment pre-image.
builder.AddOp(txscript.OP_SIZE)
builder.AddInt64(32)
builder.AddOp(txscript.OP_EQUAL)
// If it isn't then this might be the sender of the HTLC activating the
// time out clause.
builder.AddOp(txscript.OP_NOTIF)
// We'll drop the OP_IF return value off the top of the stack so we can
// reconstruct the multi-sig script used as an off-chain covenant. If
// two valid signatures are provided, then the output will be deemed as
// spendable.
builder.AddOp(txscript.OP_DROP)
builder.AddOp(txscript.OP_2)
builder.AddOp(txscript.OP_SWAP)
builder.AddData(senderHtlcKey.SerializeCompressed())
builder.AddOp(txscript.OP_2)
builder.AddOp(txscript.OP_CHECKMULTISIG)
// Otherwise, then the only other case is that this is the receiver of
// the HTLC sweeping it on-chain with the payment pre-image.
builder.AddOp(txscript.OP_ELSE)
// Hash the top item of the stack and compare it with the hash160 of
// the payment hash, which is already the sha256 of the payment
// pre-image. By using this little trick we're able to save space
// on-chain as the witness includes a 20-byte hash rather than a
// 32-byte hash.
builder.AddOp(txscript.OP_HASH160)
builder.AddData(Ripemd160H(paymentHash))
builder.AddOp(txscript.OP_EQUALVERIFY)
// This checks the receiver's signature so that a third party with
// knowledge of the payment preimage still cannot steal the output.
builder.AddOp(txscript.OP_CHECKSIG)
// Close out the OP_IF statement above.
builder.AddOp(txscript.OP_ENDIF)
// Add 1 block CSV delay if a confirmation is required for the
// non-revocation clauses.
if confirmedSpend {
builder.AddOp(txscript.OP_1)
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
}
// Close out the OP_IF statement at the top of the script.
builder.AddOp(txscript.OP_ENDIF)
return builder.Script()
}
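// The sketch below (the helper name is an assumption) builds the offered
// HTLC script and its p2wsh output script as it would appear on a sender's
// commitment transaction, with the 1 CSV confirmation guard enabled.
func sketchOfferedHtlcPkScript(senderHtlcKey, receiverHtlcKey,
	revocationKey *btcec.PublicKey, paymentHash []byte) ([]byte, error) {

	witnessScript, err := SenderHTLCScript(
		senderHtlcKey, receiverHtlcKey, revocationKey, paymentHash,
		true,
	)
	if err != nil {
		return nil, err
	}

	// The witness script itself never appears in the output; only its
	// sha256 digest does, via the p2wsh wrapper.
	return WitnessScriptHash(witnessScript)
}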
// SenderHtlcSpendRevokeWithKey constructs a valid witness allowing the receiver of an
// HTLC to claim the output with knowledge of the revocation private key in the
// scenario that the sender of the HTLC broadcasts a previously revoked
// commitment transaction. A valid spend requires knowledge of the private key
// that corresponds to their revocation base point and also the private key from
// the per commitment point, and a valid signature under the combined public
// key.
func SenderHtlcSpendRevokeWithKey(signer Signer, signDesc *SignDescriptor,
revokeKey *btcec.PublicKey, sweepTx *wire.MsgTx) (wire.TxWitness, error) {
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// The stack required to sweep a revoke HTLC output consists simply of
// the exact witness stack as one of a regular p2wkh spend. The only
// difference is that the keys used were derived in an adversarial
// manner in order to encode the revocation contract into a sig+key
// pair.
witnessStack := wire.TxWitness(make([][]byte, 3))
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = revokeKey.SerializeCompressed()
witnessStack[2] = signDesc.WitnessScript
return witnessStack, nil
}
// SenderHtlcSpendRevoke constructs a valid witness allowing the receiver of an
// HTLC to claim the output with knowledge of the revocation private key in the
// scenario that the sender of the HTLC broadcasts a previously revoked
// commitment transaction. This method first derives the appropriate revocation
// key, and requires that the provided SignDescriptor has a local revocation
// basepoint and commitment secret in the PubKey and DoubleTweak fields,
// respectively.
func SenderHtlcSpendRevoke(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
revokeKey, err := deriveRevokePubKey(signDesc)
if err != nil {
return nil, err
}
return SenderHtlcSpendRevokeWithKey(signer, signDesc, revokeKey, sweepTx)
}
// IsHtlcSpendRevoke is used to determine if the passed spend is spending a
// HTLC output using the revocation key.
func IsHtlcSpendRevoke(txIn *wire.TxIn, signDesc *SignDescriptor) (
bool, error) {
// For taproot channels, the revocation path only has a single witness,
// as that's the key spend path.
isTaproot := txscript.IsPayToTaproot(signDesc.Output.PkScript)
if isTaproot {
return len(txIn.Witness) == 1, nil
}
revokeKey, err := deriveRevokePubKey(signDesc)
if err != nil {
return false, err
}
if len(txIn.Witness) == 3 &&
bytes.Equal(txIn.Witness[1], revokeKey.SerializeCompressed()) {
return true, nil
}
return false, nil
}
// SenderHtlcSpendRedeem constructs a valid witness allowing the receiver of an
// HTLC to redeem the pending output in the scenario that the sender broadcasts
// their version of the commitment transaction. A valid spend requires
// knowledge of the payment preimage, and a valid signature under the
// receiver's public key.
func SenderHtlcSpendRedeem(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx, paymentPreimage []byte) (wire.TxWitness, error) {
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// The stack required to spend this output is simply the signature
// generated above under the receiver's public key, and the payment
// pre-image.
witnessStack := wire.TxWitness(make([][]byte, 3))
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = paymentPreimage
witnessStack[2] = signDesc.WitnessScript
return witnessStack, nil
}
// SenderHtlcSpendTimeout constructs a valid witness allowing the sender of an
// HTLC to activate the time locked covenant clause of a soon to be expired
// HTLC. This script simply spends the multi-sig output using the
// pre-generated HTLC timeout transaction.
func SenderHtlcSpendTimeout(receiverSig Signature,
receiverSigHash txscript.SigHashType, signer Signer,
signDesc *SignDescriptor, htlcTimeoutTx *wire.MsgTx) (
wire.TxWitness, error) {
sweepSig, err := signer.SignOutputRaw(htlcTimeoutTx, signDesc)
if err != nil {
return nil, err
}
// We place a zero as the first item of the evaluated witness stack in
// order to force Script execution to the HTLC timeout clause. The
// second zero is required to consume the extra pop due to a bug in the
// original OP_CHECKMULTISIG.
witnessStack := wire.TxWitness(make([][]byte, 5))
witnessStack[0] = nil
witnessStack[1] = append(receiverSig.Serialize(), byte(receiverSigHash))
witnessStack[2] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[3] = nil
witnessStack[4] = signDesc.WitnessScript
return witnessStack, nil
}
// SenderHTLCTapLeafTimeout returns the full tapscript leaf for the timeout
// path of the sender HTLC. This is a small script that allows the sender to
// timeout the HTLC after a period of time:
//
// <local_key> OP_CHECKSIGVERIFY
// <remote_key> OP_CHECKSIG
func SenderHTLCTapLeafTimeout(senderHtlcKey,
receiverHtlcKey *btcec.PublicKey) (txscript.TapLeaf, error) {
builder := txscript.NewScriptBuilder()
builder.AddData(schnorr.SerializePubKey(senderHtlcKey))
builder.AddOp(txscript.OP_CHECKSIGVERIFY)
builder.AddData(schnorr.SerializePubKey(receiverHtlcKey))
builder.AddOp(txscript.OP_CHECKSIG)
timeoutLeafScript, err := builder.Script()
if err != nil {
return txscript.TapLeaf{}, err
}
return txscript.NewBaseTapLeaf(timeoutLeafScript), nil
}
// SenderHTLCTapLeafSuccess returns the full tapscript leaf for the success
// path of the sender HTLC. This is a small script that allows the receiver to
// redeem the HTLC with a pre-image:
//
// OP_SIZE 32 OP_EQUALVERIFY OP_HASH160
// <RIPEMD160(payment_hash)> OP_EQUALVERIFY
// <remote_htlcpubkey> OP_CHECKSIG
// 1 OP_CHECKSEQUENCEVERIFY OP_DROP
func SenderHTLCTapLeafSuccess(receiverHtlcKey *btcec.PublicKey,
paymentHash []byte) (txscript.TapLeaf, error) {
builder := txscript.NewScriptBuilder()
// Check that the pre-image is 32 bytes as required.
builder.AddOp(txscript.OP_SIZE)
builder.AddInt64(32)
builder.AddOp(txscript.OP_EQUALVERIFY)
// Check that the specified pre-image matches what we hard code into
// the script.
builder.AddOp(txscript.OP_HASH160)
builder.AddData(Ripemd160H(paymentHash))
builder.AddOp(txscript.OP_EQUALVERIFY)
// Verify the remote party's signature, then make them wait 1 block
// after confirmation to properly sweep.
builder.AddData(schnorr.SerializePubKey(receiverHtlcKey))
builder.AddOp(txscript.OP_CHECKSIG)
builder.AddOp(txscript.OP_1)
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
successLeafScript, err := builder.Script()
if err != nil {
return txscript.TapLeaf{}, err
}
return txscript.NewBaseTapLeaf(successLeafScript), nil
}
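// Example (illustrative sketch, not part of the package API): build both
// sender HTLC tap leaves with throwaway keys and a zero payment hash, then
// print their leaf hashes, which are the values committed to in the
// tapscript merkle tree.
func exampleSenderHTLCTapLeaves() error {
	senderPriv, err := btcec.NewPrivateKey()
	if err != nil {
		return err
	}
	receiverPriv, err := btcec.NewPrivateKey()
	if err != nil {
		return err
	}
	var payHash [32]byte
	timeoutLeaf, err := SenderHTLCTapLeafTimeout(
		senderPriv.PubKey(), receiverPriv.PubKey(),
	)
	if err != nil {
		return err
	}
	successLeaf, err := SenderHTLCTapLeafSuccess(
		receiverPriv.PubKey(), payHash[:],
	)
	if err != nil {
		return err
	}
	fmt.Printf("timeout leaf: %v\nsuccess leaf: %v\n",
		timeoutLeaf.TapHash(), successLeaf.TapHash())
	return nil
}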
// htlcType is an enum value that denotes what type of HTLC script this is.
type htlcType uint8
const (
// htlcLocalIncoming represents an incoming HTLC on the local
// commitment transaction.
htlcLocalIncoming htlcType = iota
// htlcLocalOutgoing represents an outgoing HTLC on the local
// commitment transaction.
htlcLocalOutgoing
// htlcRemoteIncoming represents an incoming HTLC on the remote
// commitment transaction.
htlcRemoteIncoming
// htlcRemoteOutgoing represents an outgoing HTLC on the remote
// commitment transaction.
htlcRemoteOutgoing
)
// HtlcScriptTree holds the taproot output key, as well as the two script path
// leaves that every taproot HTLC script depends on.
type HtlcScriptTree struct {
ScriptTree
// SuccessTapLeaf is the tapleaf for the redemption path.
SuccessTapLeaf txscript.TapLeaf
// TimeoutTapLeaf is the tapleaf for the timeout path.
TimeoutTapLeaf txscript.TapLeaf
// AuxLeaf is an auxiliary leaf that can be used to extend the base
// HTLC script tree with new spend paths, or just as extra commitment
// space. When present, this leaf will always be in the right-most area
// of the tapscript tree.
AuxLeaf AuxTapLeaf
// htlcType is the type of HTLC script this is.
htlcType htlcType
}
// WitnessScriptToSign returns the witness script that we'll use when signing
// for the remote party, and also verifying signatures on our transactions. As
// an example, when we create an outgoing HTLC for the remote party, we want to
// sign the success path for them, so we'll return the success path leaf.
func (h *HtlcScriptTree) WitnessScriptToSign() []byte {
switch h.htlcType {
// For incoming HTLCs on our local commitment, we care about verifying
// the success path.
case htlcLocalIncoming:
return h.SuccessTapLeaf.Script
// For incoming HTLCs on the remote party's commitment, we want to sign
// the timeout path for them.
case htlcRemoteIncoming:
return h.TimeoutTapLeaf.Script
// For outgoing HTLCs on our local commitment, we want to verify the
// timeout path.
case htlcLocalOutgoing:
return h.TimeoutTapLeaf.Script
// For outgoing HTLCs on the remote party's commitment, we want to sign
// the success path for them.
case htlcRemoteOutgoing:
return h.SuccessTapLeaf.Script
default:
panic(fmt.Sprintf("unknown htlc type: %v", h.htlcType))
}
}
// WitnessScriptForPath returns the witness script for the given spending path.
// An error is returned if the path is unknown.
func (h *HtlcScriptTree) WitnessScriptForPath(path ScriptPath) ([]byte, error) {
switch path {
case ScriptPathSuccess:
return h.SuccessTapLeaf.Script, nil
case ScriptPathTimeout:
return h.TimeoutTapLeaf.Script, nil
default:
return nil, fmt.Errorf("unknown script path: %v", path)
}
}
// CtrlBlockForPath returns the control block for the given spending path. For
// script types that don't have a control block, nil is returned.
func (h *HtlcScriptTree) CtrlBlockForPath(
path ScriptPath) (*txscript.ControlBlock, error) {
switch path {
case ScriptPathSuccess:
return lnutils.Ptr(MakeTaprootCtrlBlock(
h.SuccessTapLeaf.Script, h.InternalKey,
h.TapscriptTree,
)), nil
case ScriptPathTimeout:
return lnutils.Ptr(MakeTaprootCtrlBlock(
h.TimeoutTapLeaf.Script, h.InternalKey,
h.TapscriptTree,
)), nil
default:
return nil, fmt.Errorf("unknown script path: %v", path)
}
}
// Tree returns the underlying ScriptTree of the HtlcScriptTree.
func (h *HtlcScriptTree) Tree() ScriptTree {
return h.ScriptTree
}
// A compile time check to ensure HtlcScriptTree implements the
// TapscriptDescriptor interface.
var _ TapscriptDescriptor = (*HtlcScriptTree)(nil)
// senderHtlcTapScriptTree builds the tapscript tree which is used to anchor
// the HTLC key for HTLCs on the sender's commitment.
func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey,
revokeKey *btcec.PublicKey, payHash []byte, hType htlcType,
auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) {
// First, we'll obtain the tap leaves for both the success and timeout
// path.
successTapLeaf, err := SenderHTLCTapLeafSuccess(
receiverHtlcKey, payHash,
)
if err != nil {
return nil, err
}
timeoutTapLeaf, err := SenderHTLCTapLeafTimeout(
senderHtlcKey, receiverHtlcKey,
)
if err != nil {
return nil, err
}
tapLeaves := []txscript.TapLeaf{successTapLeaf, timeoutTapLeaf}
auxLeaf.WhenSome(func(l txscript.TapLeaf) {
tapLeaves = append(tapLeaves, l)
})
// With the two leaves obtained, we'll now make the tapscript tree,
// then obtain the root from that.
tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...)
tapScriptRoot := tapscriptTree.RootNode.TapHash()
// With the tapscript root obtained, we'll tweak the revocation key
// with this value to obtain the key that HTLCs will be sent to.
htlcKey := txscript.ComputeTaprootOutputKey(
revokeKey, tapScriptRoot[:],
)
return &HtlcScriptTree{
ScriptTree: ScriptTree{
TaprootKey: htlcKey,
TapscriptTree: tapscriptTree,
TapscriptRoot: tapScriptRoot[:],
InternalKey: revokeKey,
},
SuccessTapLeaf: successTapLeaf,
TimeoutTapLeaf: timeoutTapLeaf,
AuxLeaf: auxLeaf,
htlcType: hType,
}, nil
}
// SenderHTLCScriptTaproot constructs the taproot witness program (schnorr key)
// for an outgoing HTLC on the sender's version of the commitment transaction.
// This method returns the top level tweaked public key that commits to both
// the script paths. This is also known as an offered HTLC.
//
// The returned key commits to a tapscript tree with two possible paths:
//
// - Timeout path:
// <local_key> OP_CHECKSIGVERIFY
// <remote_key> OP_CHECKSIG
//
// - Success path:
// OP_SIZE 32 OP_EQUALVERIFY
// OP_HASH160 <RIPEMD160(payment_hash)> OP_EQUALVERIFY
// <remote_htlcpubkey> OP_CHECKSIG
// 1 OP_CHECKSEQUENCEVERIFY OP_DROP
//
// The timeout path can be spent with a witness of (sender timeout):
//
// <receiver sig> <local sig> <timeout_script> <control_block>
//
// The success path can be spent with a valid control block, and a witness of
// (receiver redeem):
//
// <receiver sig> <preimage> <success_script> <control_block>
//
// The top level keyspend key is the revocation key, which allows a defender to
// unilaterally spend the created output.
func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey,
revokeKey *btcec.PublicKey, payHash []byte,
whoseCommit lntypes.ChannelParty, auxLeaf AuxTapLeaf) (*HtlcScriptTree,
error) {
var hType htlcType
if whoseCommit.IsLocal() {
hType = htlcLocalOutgoing
} else {
hType = htlcRemoteIncoming
}
// Given all the necessary parameters, we'll return the HTLC script
// tree that includes the top level output script, as well as the two
// tap leaf paths.
return senderHtlcTapScriptTree(
senderHtlcKey, receiverHtlcKey, revokeKey, payHash, hType,
auxLeaf,
)
}
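// Example (illustrative sketch, not part of the package API): construct the
// sender-side taproot HTLC script tree with throwaway keys and a zero payment
// hash, then read out the resulting taproot output key and the witness script
// we'd sign for. lntypes.Local is assumed to denote our own commitment.
func exampleSenderHTLCScriptTaproot() error {
	var keys [3]*btcec.PublicKey
	for i := range keys {
		priv, err := btcec.NewPrivateKey()
		if err != nil {
			return err
		}
		keys[i] = priv.PubKey()
	}
	senderKey, receiverKey, revokeKey := keys[0], keys[1], keys[2]
	// A zero payment hash stands in for a real payment hash here.
	var payHash [32]byte
	scriptTree, err := SenderHTLCScriptTaproot(
		senderKey, receiverKey, revokeKey, payHash[:], lntypes.Local,
		NoneTapLeaf(),
	)
	if err != nil {
		return err
	}
	// The taproot key is what ends up in the on-chain output, while the
	// witness script is the leaf we'd sign when sweeping.
	fmt.Printf("output key: %x\n",
		schnorr.SerializePubKey(scriptTree.TaprootKey))
	fmt.Printf("script to sign: %x\n", scriptTree.WitnessScriptToSign())
	return nil
}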
// maybeAppendSighash appends a sighash type to the end of a signature if
// the sighash type isn't sighash default.
func maybeAppendSighash(sig Signature, sigHash txscript.SigHashType) []byte {
sigBytes := sig.Serialize()
if sigHash == txscript.SigHashDefault {
return sigBytes
}
return append(sigBytes, byte(sigHash))
}
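// Example (illustrative sketch, not part of the package API): for taproot
// spends, a SIGHASH_DEFAULT schnorr signature stays at its bare 64 bytes,
// while any other sighash type gets a trailing type byte appended. The key
// and message below are throwaway values.
func exampleMaybeAppendSighash() error {
	privKey, err := btcec.NewPrivateKey()
	if err != nil {
		return err
	}
	// Sign an all-zero 32-byte digest purely for illustration.
	var digest [32]byte
	sig, err := schnorr.Sign(privKey, digest[:])
	if err != nil {
		return err
	}
	defaultSig := maybeAppendSighash(sig, txscript.SigHashDefault)
	singleSig := maybeAppendSighash(sig, txscript.SigHashSingle)
	// Prints "default: 64 bytes, single: 65 bytes".
	fmt.Printf("default: %d bytes, single: %d bytes\n",
		len(defaultSig), len(singleSig))
	return nil
}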
// SenderHTLCScriptTaprootRedeem creates a valid witness needed to redeem a
// sender taproot HTLC with the pre-image. The returned witness is valid and
// includes the control block required to spend the output. This is the offered
// HTLC claimed by the remote party.
func SenderHTLCScriptTaprootRedeem(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx, preimage []byte, revokeKey *btcec.PublicKey,
tapscriptTree *txscript.IndexedTapScriptTree) (wire.TxWitness, error) {
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// In addition to the signature and the witness/leaf script, we also
// need to make a control block proof using the tapscript tree.
var ctrlBlock []byte
if signDesc.ControlBlock == nil {
successControlBlock := MakeTaprootCtrlBlock(
signDesc.WitnessScript, revokeKey, tapscriptTree,
)
ctrlBytes, err := successControlBlock.ToBytes()
if err != nil {
return nil, err
}
ctrlBlock = ctrlBytes
} else {
ctrlBlock = signDesc.ControlBlock
}
// The final witness stack is:
// <receiver sig> <preimage> <success_script> <control_block>
witnessStack := make(wire.TxWitness, 4)
witnessStack[0] = maybeAppendSighash(sweepSig, signDesc.HashType)
witnessStack[1] = preimage
witnessStack[2] = signDesc.WitnessScript
witnessStack[3] = ctrlBlock
return witnessStack, nil
}
// SenderHTLCScriptTaprootTimeout creates a valid witness needed to timeout an
// HTLC on the sender's commitment transaction. The returned witness is valid
// and includes the control block required to spend the output. This is a
// timeout of the offered HTLC by the sender.
func SenderHTLCScriptTaprootTimeout(receiverSig Signature,
receiverSigHash txscript.SigHashType, signer Signer,
signDesc *SignDescriptor, htlcTimeoutTx *wire.MsgTx,
revokeKey *btcec.PublicKey,
tapscriptTree *txscript.IndexedTapScriptTree) (wire.TxWitness, error) {
sweepSig, err := signer.SignOutputRaw(htlcTimeoutTx, signDesc)
if err != nil {
return nil, err
}
// With the sweep signature obtained, we'll obtain the control block
// proof needed to perform a valid spend for the timeout path.
var ctrlBlockBytes []byte
if signDesc.ControlBlock == nil {
timeoutControlBlock := MakeTaprootCtrlBlock(
signDesc.WitnessScript, revokeKey, tapscriptTree,
)
ctrlBytes, err := timeoutControlBlock.ToBytes()
if err != nil {
return nil, err
}
ctrlBlockBytes = ctrlBytes
} else {
ctrlBlockBytes = signDesc.ControlBlock
}
// The final witness stack is:
// <receiver sig> <local sig> <timeout_script> <control_block>
witnessStack := make(wire.TxWitness, 4)
witnessStack[0] = maybeAppendSighash(receiverSig, receiverSigHash)
witnessStack[1] = maybeAppendSighash(sweepSig, signDesc.HashType)
witnessStack[2] = signDesc.WitnessScript
witnessStack[3] = ctrlBlockBytes
return witnessStack, nil
}
// SenderHTLCScriptTaprootRevoke creates a valid witness needed to spend the
// revocation path of the HTLC. This uses a plain keyspend using the specified
// revocation key.
func SenderHTLCScriptTaprootRevoke(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// The witness stack in this case is pretty simple: we only need to
// specify the signature generated.
witnessStack := make(wire.TxWitness, 1)
witnessStack[0] = maybeAppendSighash(sweepSig, signDesc.HashType)
return witnessStack, nil
}
// ReceiverHTLCScript constructs the public key script for an incoming HTLC
// output payment for the receiver's version of the commitment transaction. The
// possible execution paths from this script include:
// - The receiver of the HTLC uses its second level HTLC transaction to
// advance the state of the HTLC into the delay+claim state.
// - The sender of the HTLC sweeps all the funds of the HTLC as a breached
// commitment was broadcast.
// - The sender of the HTLC sweeps the HTLC on-chain after the timeout period
// of the HTLC has passed.
//
// If confirmedSpend=true, a 1 OP_CSV check will be added to the non-revocation
// cases, to allow sweeping only after confirmation.
//
// Possible Input Scripts:
//
// RECVR: <0> <sender sig> <recvr sig> <preimage> (spend using HTLC success transaction)
// REVOK: <sig> <key>
// SENDR: <sig> 0
//
// Received HTLC Output Script:
//
// OP_DUP OP_HASH160 <revocation key hash160> OP_EQUAL
// OP_IF
// OP_CHECKSIG
// OP_ELSE
// <sendr htlc key>
// OP_SWAP OP_SIZE 32 OP_EQUAL
// OP_IF
// OP_HASH160 <ripemd160(payment hash)> OP_EQUALVERIFY
// 2 OP_SWAP <recvr htlc key> 2 OP_CHECKMULTISIG
// OP_ELSE
// OP_DROP <cltv expiry> OP_CHECKLOCKTIMEVERIFY OP_DROP
// OP_CHECKSIG
// OP_ENDIF
// [1 OP_CHECKSEQUENCEVERIFY OP_DROP] <- if allowing confirmed
// spend only.
// OP_ENDIF
func ReceiverHTLCScript(cltvExpiry uint32, senderHtlcKey,
receiverHtlcKey, revocationKey *btcec.PublicKey,
paymentHash []byte, confirmedSpend bool) ([]byte, error) {
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
AcceptedHtlcScriptSizeConfirmed,
))
// The opening operations are used to determine if this is the sender
// of the HTLC attempting to sweep all the funds due to a contract
// breach. In this case, they'll place the revocation key at the top of
// the stack.
builder.AddOp(txscript.OP_DUP)
builder.AddOp(txscript.OP_HASH160)
builder.AddData(btcutil.Hash160(revocationKey.SerializeCompressed()))
builder.AddOp(txscript.OP_EQUAL)
// If the hash matches, then this is the revocation clause. The output
// can be spent if the check sig operation passes.
builder.AddOp(txscript.OP_IF)
builder.AddOp(txscript.OP_CHECKSIG)
// Otherwise, this may either be the receiver of the HTLC starting the
// claiming process via the second level HTLC success transaction and
// the pre-image, or the sender of the HTLC sweeping the output after
// it has timed out.
builder.AddOp(txscript.OP_ELSE)
// We'll do a bit of set up by pushing the sender's key on the top of
// the stack. This will be needed later if we decide that this is the
// receiver transitioning the output to the claim state using their
// second-level HTLC success transaction.
builder.AddData(senderHtlcKey.SerializeCompressed())
// Atm, the top item of the stack is the sender's key so we use a swap
// to expose what is either the payment pre-image or something else.
builder.AddOp(txscript.OP_SWAP)
// With the top item swapped, check if it's 32 bytes. If so, then this
// *may* be the payment pre-image.
builder.AddOp(txscript.OP_SIZE)
builder.AddInt64(32)
builder.AddOp(txscript.OP_EQUAL)
// If the item on the top of the stack is 32-bytes, then it is the
// proper size, so this indicates that the receiver of the HTLC is
// attempting to claim the output on-chain by transitioning the state
// of the HTLC to delay+claim.
builder.AddOp(txscript.OP_IF)
// Next we'll hash the item on the top of the stack, if it matches the
// payment pre-image, then we'll continue. Otherwise, we'll end the
// script here as this is the invalid payment pre-image.
builder.AddOp(txscript.OP_HASH160)
builder.AddData(Ripemd160H(paymentHash))
builder.AddOp(txscript.OP_EQUALVERIFY)
// If the payment hash matches, then we'll also need to satisfy the
// multi-sig covenant by providing both signatures of the sender and
// receiver. If the covenant is satisfied, then we'll allow the spending of
// this output, but only by the HTLC success transaction.
builder.AddOp(txscript.OP_2)
builder.AddOp(txscript.OP_SWAP)
builder.AddData(receiverHtlcKey.SerializeCompressed())
builder.AddOp(txscript.OP_2)
builder.AddOp(txscript.OP_CHECKMULTISIG)
// Otherwise, this might be the sender of the HTLC attempting to sweep
// it on-chain after the timeout.
builder.AddOp(txscript.OP_ELSE)
// We'll drop the extra item (the output from evaluating the OP_EQUAL
// above) from the stack.
builder.AddOp(txscript.OP_DROP)
// With that item dropped off, we can now enforce the absolute
// lock-time required to timeout the HTLC. If the time has passed, then
// we'll proceed with a checksig to ensure that this is actually the
// sender of the original HTLC.
builder.AddInt64(int64(cltvExpiry))
builder.AddOp(txscript.OP_CHECKLOCKTIMEVERIFY)
builder.AddOp(txscript.OP_DROP)
builder.AddOp(txscript.OP_CHECKSIG)
// Close out the inner if statement.
builder.AddOp(txscript.OP_ENDIF)
// Add 1 block CSV delay for non-revocation clauses if confirmation is
// required.
if confirmedSpend {
builder.AddOp(txscript.OP_1)
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
}
// Close out the outer if statement.
builder.AddOp(txscript.OP_ENDIF)
return builder.Script()
}
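// Example (illustrative sketch, not part of the package API): build the
// accepted-HTLC script twice, with and without the confirmed-spend clause,
// and compare sizes. The extra 1 OP_CHECKSEQUENCEVERIFY OP_DROP suffix adds
// exactly three bytes. All keys, the payment hash, and the expiry height are
// throwaway values.
func exampleReceiverHTLCScript() error {
	var keys [3]*btcec.PublicKey
	for i := range keys {
		priv, err := btcec.NewPrivateKey()
		if err != nil {
			return err
		}
		keys[i] = priv.PubKey()
	}
	var payHash [32]byte
	baseScript, err := ReceiverHTLCScript(
		800_000, keys[0], keys[1], keys[2], payHash[:], false,
	)
	if err != nil {
		return err
	}
	confScript, err := ReceiverHTLCScript(
		800_000, keys[0], keys[1], keys[2], payHash[:], true,
	)
	if err != nil {
		return err
	}
	fmt.Printf("base: %d bytes, confirmed-only: %d bytes\n",
		len(baseScript), len(confScript))
	return nil
}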
// ReceiverHtlcSpendRedeem constructs a valid witness allowing the receiver of
// an HTLC to redeem the conditional payment in the event that their commitment
// transaction is broadcast. This clause transitions the state of the HLTC
// output into the delay+claim state by activating the off-chain covenant bound
// by the 2-of-2 multi-sig output. The HTLC success transaction being signed
// has a relative timelock delay enforced by its sequence number. This delay
// gives the sender of the HTLC enough time to revoke the output if this is a
// breach commitment transaction.
func ReceiverHtlcSpendRedeem(senderSig Signature,
senderSigHash txscript.SigHashType, paymentPreimage []byte,
signer Signer, signDesc *SignDescriptor, htlcSuccessTx *wire.MsgTx) (
wire.TxWitness, error) {
// First, we'll generate a signature for the HTLC success transaction.
// The signDesc should be signing with the public key used as the
// receiver's public key and also the correct single tweak.
sweepSig, err := signer.SignOutputRaw(htlcSuccessTx, signDesc)
if err != nil {
return nil, err
}
// The final witness stack is used to provide the script with the
// payment pre-image, and to execute the multi-sig clause once the
// pre-image matches. We add a nil item at the bottom of the stack in
// order to consume the extra pop within OP_CHECKMULTISIG.
witnessStack := wire.TxWitness(make([][]byte, 5))
witnessStack[0] = nil
witnessStack[1] = append(senderSig.Serialize(), byte(senderSigHash))
witnessStack[2] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[3] = paymentPreimage
witnessStack[4] = signDesc.WitnessScript
return witnessStack, nil
}
// ReceiverHtlcSpendRevokeWithKey constructs a valid witness allowing the
// sender of an HTLC within a previously revoked commitment transaction to
// re-claim the pending funds in the case that the receiver broadcasts this
// revoked commitment transaction.
func ReceiverHtlcSpendRevokeWithKey(signer Signer, signDesc *SignDescriptor,
revokeKey *btcec.PublicKey, sweepTx *wire.MsgTx) (wire.TxWitness, error) {
// First, we'll generate a signature for the sweep transaction. The
// signDesc should be signing with the public key used as the fully
// derived revocation public key and also the correct double tweak
// value.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// We place the serialized revocation key as the second witness item so
// that the script's OP_HASH160 check matches, forcing script execution
// into the HTLC revocation clause.
witnessStack := wire.TxWitness(make([][]byte, 3))
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = revokeKey.SerializeCompressed()
witnessStack[2] = signDesc.WitnessScript
return witnessStack, nil
}
func deriveRevokePubKey(signDesc *SignDescriptor) (*btcec.PublicKey, error) {
if signDesc.KeyDesc.PubKey == nil {
return nil, fmt.Errorf("cannot generate witness with nil " +
"KeyDesc pubkey")
}
// Derive the revocation key using the local revocation base point and
// commitment point.
revokeKey := DeriveRevocationPubkey(
signDesc.KeyDesc.PubKey,
signDesc.DoubleTweak.PubKey(),
)
return revokeKey, nil
}
// ReceiverHtlcSpendRevoke constructs a valid witness allowing the sender of an
// HTLC within a previously revoked commitment transaction to re-claim the
// pending funds in the case that the receiver broadcasts this revoked
// commitment transaction. This method first derives the appropriate revocation
// key, and requires that the provided SignDescriptor has a local revocation
// basepoint and commitment secret in the PubKey and DoubleTweak fields,
// respectively.
func ReceiverHtlcSpendRevoke(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
revokeKey, err := deriveRevokePubKey(signDesc)
if err != nil {
return nil, err
}
return ReceiverHtlcSpendRevokeWithKey(signer, signDesc, revokeKey, sweepTx)
}
// ReceiverHtlcSpendTimeout constructs a valid witness allowing the sender of
// an HTLC to recover the pending funds after an absolute timeout in the
// scenario that the receiver of the HTLC broadcasts their version of the
// commitment transaction. If the caller has already set the lock time on the
// spending transaction, then a value of -1 can be passed for the cltvExpiry
// value.
//
// NOTE: The target input of the passed transaction MUST NOT have a final
// sequence number. Otherwise, the OP_CHECKLOCKTIMEVERIFY check will fail.
func ReceiverHtlcSpendTimeout(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx, cltvExpiry int32) (wire.TxWitness, error) {
// If the caller set a proper timeout value, then we'll apply it
// directly to the transaction.
if cltvExpiry != -1 {
// The HTLC output has an absolute time period before we are
// permitted to recover the pending funds. Therefore we need to
// set the locktime on this sweeping transaction in order to
// pass Script verification.
sweepTx.LockTime = uint32(cltvExpiry)
}
// With the lock time on the transaction set, we'll now generate a
// signature for the sweep transaction. The passed sign descriptor
// should be created using the raw public key of the sender (w/o the
// single tweak applied), and the single tweak set to the proper value
// taking into account the current state's point.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
witnessStack := wire.TxWitness(make([][]byte, 3))
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = nil
witnessStack[2] = signDesc.WitnessScript
return witnessStack, nil
}
// ReceiverHtlcTapLeafTimeout returns the full tapscript leaf for the timeout
// path of an HTLC on the receiver's commitment transaction. This is a small
// script that allows the sender to time out the HTLC after expiry:
//
// <sender_htlcpubkey> OP_CHECKSIG
// 1 OP_CHECKSEQUENCEVERIFY OP_DROP
// <cltv_expiry> OP_CHECKLOCKTIMEVERIFY OP_DROP
func ReceiverHtlcTapLeafTimeout(senderHtlcKey *btcec.PublicKey,
cltvExpiry uint32) (txscript.TapLeaf, error) {
builder := txscript.NewScriptBuilder()
// The first part of the script will verify a signature from the
// sender authorizing the spend (the timeout).
builder.AddData(schnorr.SerializePubKey(senderHtlcKey))
builder.AddOp(txscript.OP_CHECKSIG)
builder.AddOp(txscript.OP_1)
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
// The second portion will ensure that the CLTV expiry on the spending
// transaction is correct.
builder.AddInt64(int64(cltvExpiry))
builder.AddOp(txscript.OP_CHECKLOCKTIMEVERIFY)
builder.AddOp(txscript.OP_DROP)
timeoutLeafScript, err := builder.Script()
if err != nil {
return txscript.TapLeaf{}, err
}
return txscript.NewBaseTapLeaf(timeoutLeafScript), nil
}
// ReceiverHtlcTapLeafSuccess returns the full tapscript leaf for the success
// path for an HTLC on the receiver's commitment transaction. This script
// allows the receiver to redeem an HTLC with knowledge of the preimage:
//
// OP_SIZE 32 OP_EQUALVERIFY OP_HASH160
// <RIPEMD160(payment_hash)> OP_EQUALVERIFY
// <receiver_htlcpubkey> OP_CHECKSIGVERIFY
// <sender_htlcpubkey> OP_CHECKSIG
func ReceiverHtlcTapLeafSuccess(receiverHtlcKey *btcec.PublicKey,
senderHtlcKey *btcec.PublicKey,
paymentHash []byte) (txscript.TapLeaf, error) {
builder := txscript.NewScriptBuilder()
// Check that the pre-image is 32 bytes as required.
builder.AddOp(txscript.OP_SIZE)
builder.AddInt64(32)
builder.AddOp(txscript.OP_EQUALVERIFY)
// Check that the specified pre-image matches what we hard code into
// the script.
builder.AddOp(txscript.OP_HASH160)
builder.AddData(Ripemd160H(paymentHash))
builder.AddOp(txscript.OP_EQUALVERIFY)
// Verify the "2-of-2" multi-sig that requires both parties to sign
// off.
builder.AddData(schnorr.SerializePubKey(receiverHtlcKey))
builder.AddOp(txscript.OP_CHECKSIGVERIFY)
builder.AddData(schnorr.SerializePubKey(senderHtlcKey))
builder.AddOp(txscript.OP_CHECKSIG)
successLeafScript, err := builder.Script()
if err != nil {
return txscript.TapLeaf{}, err
}
return txscript.NewBaseTapLeaf(successLeafScript), nil
}
// receiverHtlcTapScriptTree builds the tapscript tree which is used to anchor
// the HTLC key for HTLCs on the receiver's commitment.
func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey,
revokeKey *btcec.PublicKey, payHash []byte, cltvExpiry uint32,
hType htlcType, auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) {
// First, we'll obtain the tap leaves for both the success and timeout
// path.
successTapLeaf, err := ReceiverHtlcTapLeafSuccess(
receiverHtlcKey, senderHtlcKey, payHash,
)
if err != nil {
return nil, err
}
timeoutTapLeaf, err := ReceiverHtlcTapLeafTimeout(
senderHtlcKey, cltvExpiry,
)
if err != nil {
return nil, err
}
tapLeaves := []txscript.TapLeaf{timeoutTapLeaf, successTapLeaf}
auxLeaf.WhenSome(func(l txscript.TapLeaf) {
tapLeaves = append(tapLeaves, l)
})
// With the two leaves obtained, we'll now make the tapscript tree,
// then obtain the root from that.
tapscriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...)
tapScriptRoot := tapscriptTree.RootNode.TapHash()
// With the tapscript root obtained, we'll tweak the revocation key
// with this value to obtain the key that HTLCs will be sent to.
htlcKey := txscript.ComputeTaprootOutputKey(
revokeKey, tapScriptRoot[:],
)
return &HtlcScriptTree{
ScriptTree: ScriptTree{
TaprootKey: htlcKey,
TapscriptTree: tapscriptTree,
TapscriptRoot: tapScriptRoot[:],
InternalKey: revokeKey,
},
SuccessTapLeaf: successTapLeaf,
TimeoutTapLeaf: timeoutTapLeaf,
AuxLeaf: auxLeaf,
htlcType: hType,
}, nil
}
// ReceiverHTLCScriptTaproot constructs the taproot witness program (schnorr
// key) for an incoming HTLC on the receiver's version of the commitment
// transaction. This method returns the top level tweaked public key that
// commits to both the script paths. From the PoV of the receiver, this is an
// accepted HTLC.
//
// The returned key commits to a tapscript tree with two possible paths:
//
// - The timeout path:
// <remote_htlcpubkey> OP_CHECKSIG
// 1 OP_CHECKSEQUENCEVERIFY OP_DROP
// <cltv_expiry> OP_CHECKLOCKTIMEVERIFY OP_DROP
//
// - Success path:
// OP_SIZE 32 OP_EQUALVERIFY
// OP_HASH160 <RIPEMD160(payment_hash)> OP_EQUALVERIFY
// <local_htlcpubkey> OP_CHECKSIGVERIFY
// <remote_htlcpubkey> OP_CHECKSIG
//
// The timeout path can be spent with a witness of:
// - <sender sig> <timeout_script> <control_block>
//
// The success path can be spent with a witness of:
// - <sender sig> <receiver sig> <preimage> <success_script> <control_block>
//
// The top level keyspend key is the revocation key, which allows a defender to
// unilaterally spend the created output. Both the final output key as well as
// the tap leaf are returned.
func ReceiverHTLCScriptTaproot(cltvExpiry uint32,
senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey,
payHash []byte, whoseCommit lntypes.ChannelParty,
auxLeaf AuxTapLeaf) (*HtlcScriptTree, error) {
var hType htlcType
if whoseCommit.IsLocal() {
hType = htlcLocalIncoming
} else {
hType = htlcRemoteOutgoing
}
// Given all the necessary parameters, we'll return the HTLC script
// tree that includes the top level output script, as well as the two
// tap leaf paths.
return receiverHtlcTapScriptTree(
senderHtlcKey, receiverHtlcKey, revocationKey, payHash,
cltvExpiry, hType, auxLeaf,
)
}
// ReceiverHTLCScriptTaprootRedeem creates a valid witness needed to redeem a
// receiver taproot HTLC with the pre-image. The returned witness is valid and
// includes the control block required to spend the output.
func ReceiverHTLCScriptTaprootRedeem(senderSig Signature,
senderSigHash txscript.SigHashType, paymentPreimage []byte,
signer Signer, signDesc *SignDescriptor,
htlcSuccessTx *wire.MsgTx, revokeKey *btcec.PublicKey,
tapscriptTree *txscript.IndexedTapScriptTree) (wire.TxWitness, error) {
// First, we'll generate a signature for the HTLC success transaction.
// The signDesc should be signing with the public key used as the
// receiver's public key and also the correct single tweak.
sweepSig, err := signer.SignOutputRaw(htlcSuccessTx, signDesc)
if err != nil {
return nil, err
}
// In addition to the signature and the witness/leaf script, we also
// need to make a control block proof using the tapscript tree.
var ctrlBlock []byte
if signDesc.ControlBlock == nil {
redeemControlBlock := MakeTaprootCtrlBlock(
signDesc.WitnessScript, revokeKey, tapscriptTree,
)
ctrlBytes, err := redeemControlBlock.ToBytes()
if err != nil {
return nil, err
}
ctrlBlock = ctrlBytes
} else {
ctrlBlock = signDesc.ControlBlock
}
// The final witness stack is:
// * <sender sig> <receiver sig> <preimage> <success_script>
// <control_block>
witnessStack := wire.TxWitness(make([][]byte, 5))
witnessStack[0] = maybeAppendSighash(senderSig, senderSigHash)
witnessStack[1] = maybeAppendSighash(sweepSig, signDesc.HashType)
witnessStack[2] = paymentPreimage
witnessStack[3] = signDesc.WitnessScript
witnessStack[4] = ctrlBlock
return witnessStack, nil
}
// ReceiverHTLCScriptTaprootTimeout creates a valid witness needed to timeout
// an HTLC on the receiver's commitment transaction after the timeout has
// elapsed.
func ReceiverHTLCScriptTaprootTimeout(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx, cltvExpiry int32, revokeKey *btcec.PublicKey,
tapscriptTree *txscript.IndexedTapScriptTree) (wire.TxWitness, error) {
// If the caller set a proper timeout value, then we'll apply it
// directly to the transaction.
//
// TODO(roasbeef): helper func
if cltvExpiry != -1 {
// The HTLC output has an absolute time period before we are
// permitted to recover the pending funds. Therefore we need to
// set the locktime on this sweeping transaction in order to
// pass Script verification.
sweepTx.LockTime = uint32(cltvExpiry)
}
// With the lock time on the transaction set, we'll now generate a
// signature for the sweep transaction. The passed sign descriptor
// should be created using the raw public key of the sender (w/o the
// single tweak applied), and the single tweak set to the proper value
// taking into account the current state's point.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// In addition to the signature and the witness/leaf script, we also
// need to make a control block proof using the tapscript tree.
var ctrlBlock []byte
if signDesc.ControlBlock == nil {
timeoutControlBlock := MakeTaprootCtrlBlock(
signDesc.WitnessScript, revokeKey, tapscriptTree,
)
ctrlBlock, err = timeoutControlBlock.ToBytes()
if err != nil {
return nil, err
}
} else {
ctrlBlock = signDesc.ControlBlock
}
// The final witness is pretty simple, we just need to present a valid
// signature for the script, and then provide the control block.
witnessStack := make(wire.TxWitness, 3)
witnessStack[0] = maybeAppendSighash(sweepSig, signDesc.HashType)
witnessStack[1] = signDesc.WitnessScript
witnessStack[2] = ctrlBlock
return witnessStack, nil
}
// ReceiverHTLCScriptTaprootRevoke creates a valid witness needed to spend the
// revocation path of the HTLC from the PoV of the sender (offerer) of the
// HTLC. This uses a plain keyspend using the specified revocation key.
func ReceiverHTLCScriptTaprootRevoke(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// The witness stack in this case is pretty simple: we only need to
// specify the signature generated.
witnessStack := make(wire.TxWitness, 1)
witnessStack[0] = maybeAppendSighash(sweepSig, signDesc.HashType)
return witnessStack, nil
}
// SecondLevelHtlcScript is the uniform script that's used as the output for
// the second-level HTLC transactions. The second level transactions act as a
// sort of covenant, ensuring that a 2-of-2 multi-sig output can only be
// spent in a particular way, and to a particular output.
//
// Possible Input Scripts:
//
// - To revoke an HTLC output that has been transitioned to the claim+delay
// state:
// <revoke sig> 1
//
// - To claim an HTLC output, either with a pre-image or due to a timeout:
// <delay sig> 0
//
// Output Script:
//
// OP_IF
// <revoke key>
// OP_ELSE
// <delay in blocks>
// OP_CHECKSEQUENCEVERIFY
// OP_DROP
// <delay key>
// OP_ENDIF
// OP_CHECKSIG
//
// TODO(roasbeef): possible renames for second-level
// - transition?
// - covenant output
func SecondLevelHtlcScript(revocationKey, delayKey *btcec.PublicKey,
csvDelay uint32) ([]byte, error) {
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
ToLocalScriptSize,
))
// If the revocation clause of this script is to be executed,
// the spender will push a 1, forcing us to hit the true clause of this
// if statement.
builder.AddOp(txscript.OP_IF)
// If this is the revocation case, then we'll push the revocation
// public key on the stack.
builder.AddData(revocationKey.SerializeCompressed())
// Otherwise, this is either the sender or receiver of the HTLC
// attempting to claim the HTLC output.
builder.AddOp(txscript.OP_ELSE)
// In order to give the other party time to execute the revocation
// clause above, we require a relative timeout to pass before the
// output can be spent.
builder.AddInt64(int64(csvDelay))
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
// If the relative timelock passes, then we'll add the delay key to the
// stack to ensure that we properly authenticate the spending party.
builder.AddData(delayKey.SerializeCompressed())
// Close out the if statement.
builder.AddOp(txscript.OP_ENDIF)
// In either case, we'll ensure that only either the party possessing
// the revocation private key, or the delay private key is able to
// spend this output.
builder.AddOp(txscript.OP_CHECKSIG)
return builder.Script()
}
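// Example (illustrative sketch, not part of the package API): generate the
// second-level output script for a pair of throwaway keys with a 144-block
// (roughly one day) CSV delay.
func exampleSecondLevelHtlcScript() ([]byte, error) {
	revokePriv, err := btcec.NewPrivateKey()
	if err != nil {
		return nil, err
	}
	delayPriv, err := btcec.NewPrivateKey()
	if err != nil {
		return nil, err
	}
	// The resulting script is what backs the output of an HTLC
	// success/timeout transaction in the base (non-taproot) channel type.
	return SecondLevelHtlcScript(
		revokePriv.PubKey(), delayPriv.PubKey(), 144,
	)
}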
// TODO(roasbeef): move all taproot stuff to new file?
// TaprootSecondLevelTapLeaf constructs the tap leaf used as the sole script
// path for a second level HTLC spend.
//
// The final script used is:
//
// <local_delay_key> OP_CHECKSIG
// <to_self_delay> OP_CHECKSEQUENCEVERIFY OP_DROP
func TaprootSecondLevelTapLeaf(delayKey *btcec.PublicKey,
csvDelay uint32) (txscript.TapLeaf, error) {
builder := txscript.NewScriptBuilder()
// Ensure the proper party can sign for this output.
builder.AddData(schnorr.SerializePubKey(delayKey))
builder.AddOp(txscript.OP_CHECKSIG)
// Assuming the above passes, then we'll now ensure that the CSV delay
// has been upheld, dropping the int we pushed on. If the sig above is
// valid, then a 1 will be left on the stack.
builder.AddInt64(int64(csvDelay))
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
secondLevelLeafScript, err := builder.Script()
if err != nil {
return txscript.TapLeaf{}, err
}
return txscript.NewBaseTapLeaf(secondLevelLeafScript), nil
}
// SecondLevelHtlcTapscriptTree constructs the indexed tapscript tree needed to
// generate the tap tweak to create the final output and also control block.
func SecondLevelHtlcTapscriptTree(delayKey *btcec.PublicKey, csvDelay uint32,
auxLeaf AuxTapLeaf) (*txscript.IndexedTapScriptTree, error) {
// First grab the second level leaf script we need to create the top
// level output.
secondLevelTapLeaf, err := TaprootSecondLevelTapLeaf(delayKey, csvDelay)
if err != nil {
return nil, err
}
tapLeaves := []txscript.TapLeaf{secondLevelTapLeaf}
auxLeaf.WhenSome(func(l txscript.TapLeaf) {
tapLeaves = append(tapLeaves, l)
})
// Now that we have the second level script, we can create the
// tapscript tree that commits to all of the leaves.
return txscript.AssembleTaprootScriptTree(tapLeaves...), nil
}
// TaprootSecondLevelHtlcScript is the uniform script that's used as the output
// for the second-level HTLC transaction. The second level transaction acts as
// an off-chain 2-of-2 covenant that can only be spent in a particular way and
// to a particular output.
//
// Possible Input Scripts:
// - revocation sig
// - <local_delay_sig>
//
// The main script path lets the broadcaster spend after a delay:
//
// <local_delay_key> OP_CHECKSIG
// <to_self_delay> OP_CHECKSEQUENCEVERIFY OP_DROP
//
// The keyspend path requires knowledge of the top level revocation private key.
func TaprootSecondLevelHtlcScript(revokeKey, delayKey *btcec.PublicKey,
csvDelay uint32, auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) {
// First, we'll make the tapscript tree that commits to the redemption
// path.
tapScriptTree, err := SecondLevelHtlcTapscriptTree(
delayKey, csvDelay, auxLeaf,
)
if err != nil {
return nil, err
}
tapScriptRoot := tapScriptTree.RootNode.TapHash()
// With the tapscript root obtained, we'll tweak the revocation key
// with this value to obtain the key that the second level spend will
// create.
redemptionKey := txscript.ComputeTaprootOutputKey(
revokeKey, tapScriptRoot[:],
)
return redemptionKey, nil
}
// SecondLevelScriptTree is a tapscript tree used to spend the second level
// HTLC output after the CSV delay has passed.
type SecondLevelScriptTree struct {
ScriptTree
// SuccessTapLeaf is the tapleaf for the redemption path.
SuccessTapLeaf txscript.TapLeaf
// AuxLeaf is an optional leaf that can be used to extend the script
// tree.
AuxLeaf AuxTapLeaf
}
// TaprootSecondLevelScriptTree constructs the tapscript tree used to spend the
// second level HTLC output.
func TaprootSecondLevelScriptTree(revokeKey, delayKey *btcec.PublicKey,
csvDelay uint32, auxLeaf AuxTapLeaf) (*SecondLevelScriptTree, error) {
// First, we'll make the tapscript tree that commits to the redemption
// path.
tapScriptTree, err := SecondLevelHtlcTapscriptTree(
delayKey, csvDelay, auxLeaf,
)
if err != nil {
return nil, err
}
// With the tree constructed, we can make the pkscript which is the
// taproot output key itself.
tapScriptRoot := tapScriptTree.RootNode.TapHash()
outputKey := txscript.ComputeTaprootOutputKey(
revokeKey, tapScriptRoot[:],
)
return &SecondLevelScriptTree{
ScriptTree: ScriptTree{
TaprootKey: outputKey,
TapscriptTree: tapScriptTree,
TapscriptRoot: tapScriptRoot[:],
InternalKey: revokeKey,
},
SuccessTapLeaf: tapScriptTree.LeafMerkleProofs[0].TapLeaf,
AuxLeaf: auxLeaf,
}, nil
}
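// Example (illustrative sketch, not part of the package API): assemble the
// second-level script tree with throwaway keys, then derive the control
// block needed to spend via the delay path.
func exampleTaprootSecondLevelScriptTree() error {
	revokePriv, err := btcec.NewPrivateKey()
	if err != nil {
		return err
	}
	delayPriv, err := btcec.NewPrivateKey()
	if err != nil {
		return err
	}
	scriptTree, err := TaprootSecondLevelScriptTree(
		revokePriv.PubKey(), delayPriv.PubKey(), 144, NoneTapLeaf(),
	)
	if err != nil {
		return err
	}
	// The delay and success paths share the same leaf here, so either
	// ScriptPath constant yields the same control block.
	ctrlBlock, err := scriptTree.CtrlBlockForPath(ScriptPathDelay)
	if err != nil {
		return err
	}
	ctrlBytes, err := ctrlBlock.ToBytes()
	if err != nil {
		return err
	}
	fmt.Printf("output key: %x, control block: %d bytes\n",
		schnorr.SerializePubKey(scriptTree.TaprootKey), len(ctrlBytes))
	return nil
}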
// WitnessScriptToSign returns the witness script that we'll use when signing
// for the remote party, and also verifying signatures on our transactions. As
// an example, when we create an outgoing HTLC for the remote party, we want to
// sign their success path.
func (s *SecondLevelScriptTree) WitnessScriptToSign() []byte {
return s.SuccessTapLeaf.Script
}
// WitnessScriptForPath returns the witness script for the given spending path.
// An error is returned if the path is unknown.
func (s *SecondLevelScriptTree) WitnessScriptForPath(
path ScriptPath) ([]byte, error) {
switch path {
case ScriptPathDelay:
fallthrough
case ScriptPathSuccess:
return s.SuccessTapLeaf.Script, nil
default:
return nil, fmt.Errorf("unknown script path: %v", path)
}
}
// CtrlBlockForPath returns the control block for the given spending path. For
// script types that don't have a control block, nil is returned.
func (s *SecondLevelScriptTree) CtrlBlockForPath(
path ScriptPath) (*txscript.ControlBlock, error) {
switch path {
case ScriptPathDelay:
fallthrough
case ScriptPathSuccess:
return lnutils.Ptr(MakeTaprootCtrlBlock(
s.SuccessTapLeaf.Script, s.InternalKey,
s.TapscriptTree,
)), nil
default:
return nil, fmt.Errorf("unknown script path: %v", path)
}
}
// Tree returns the underlying ScriptTree of the SecondLevelScriptTree.
func (s *SecondLevelScriptTree) Tree() ScriptTree {
return s.ScriptTree
}
// A compile time check to ensure SecondLevelScriptTree implements the
// TapscriptDescriptor interface.
var _ TapscriptDescriptor = (*SecondLevelScriptTree)(nil)
// TaprootHtlcSpendRevoke spends a second-level HTLC output via the revocation
// path. This uses the top level keyspend path to redeem the contested output.
//
// The passed SignDescriptor MUST have the proper witness script and also the
// proper top-level tweak derived from the tapscript tree for the second level
// output.
func TaprootHtlcSpendRevoke(signer Signer, signDesc *SignDescriptor,
revokeTx *wire.MsgTx) (wire.TxWitness, error) {
// We don't need any special modifications to the transaction as this
// is just sweeping a revoked HTLC output. So we'll generate a regular
// schnorr signature.
sweepSig, err := signer.SignOutputRaw(revokeTx, signDesc)
if err != nil {
return nil, err
}
// The witness stack in this case is pretty simple: we only need to
// specify the signature generated.
witnessStack := make(wire.TxWitness, 1)
witnessStack[0] = maybeAppendSighash(sweepSig, signDesc.HashType)
return witnessStack, nil
}
// TaprootHtlcSpendSuccess spends a second-level HTLC output via the redemption
// path. This should be used to sweep funds after the pre-image is known.
//
// NOTE: The caller MUST set the txn version, sequence number, and sign
// descriptor's sig hash cache before invocation.
func TaprootHtlcSpendSuccess(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx, revokeKey *btcec.PublicKey,
tapscriptTree *txscript.IndexedTapScriptTree) (wire.TxWitness, error) {
// First, we'll generate the sweep signature based on the populated
// sign desc. This should give us a valid schnorr signature for the
// sole script path leaf.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
var ctrlBlock []byte
if signDesc.ControlBlock == nil {
// Now that we have the sweep signature, we'll construct the
// control block needed to spend the script path.
redeemControlBlock := MakeTaprootCtrlBlock(
signDesc.WitnessScript, revokeKey, tapscriptTree,
)
ctrlBlock, err = redeemControlBlock.ToBytes()
if err != nil {
return nil, err
}
} else {
ctrlBlock = signDesc.ControlBlock
}
// Now that we have the redeem control block, we can construct the
// final witness needed to spend the script:
//
// <success sig> <success script> <control_block>
witnessStack := make(wire.TxWitness, 3)
witnessStack[0] = maybeAppendSighash(sweepSig, signDesc.HashType)
witnessStack[1] = signDesc.WitnessScript
witnessStack[2] = ctrlBlock
return witnessStack, nil
}
// LeaseSecondLevelHtlcScript is the uniform script that's used as the output
// for the second-level HTLC transactions. The second level transaction acts as
// a sort of covenant, ensuring that a 2-of-2 multi-sig output can only be
// spent in a particular way, and to a particular output.
//
// Possible Input Scripts:
//
// - To revoke an HTLC output that has been transitioned to the claim+delay
// state:
// <revoke sig> 1
//
// - To claim an HTLC output, either with a pre-image or due to a timeout:
// <delay sig> 0
//
// Output Script:
//
// OP_IF
// <revoke key>
// OP_ELSE
// <lease maturity in blocks>
// OP_CHECKLOCKTIMEVERIFY
// OP_DROP
// <delay in blocks>
// OP_CHECKSEQUENCEVERIFY
// OP_DROP
// <delay key>
// OP_ENDIF
// OP_CHECKSIG.
func LeaseSecondLevelHtlcScript(revocationKey, delayKey *btcec.PublicKey,
csvDelay, cltvExpiry uint32) ([]byte, error) {
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
ToLocalScriptSize + LeaseWitnessScriptSizeOverhead,
))
// If the revocation clause of this script is to be executed,
// the spender will push a 1, forcing us to hit the true clause of this
// if statement.
builder.AddOp(txscript.OP_IF)
// If this is the revocation case, then we'll push the revocation
// public key on the stack.
builder.AddData(revocationKey.SerializeCompressed())
// Otherwise, this is either the sender or receiver of the HTLC
// attempting to claim the HTLC output.
builder.AddOp(txscript.OP_ELSE)
// The channel initiator always has the additional channel lease
// expiration constraint for outputs that pay to them which must be
// satisfied.
builder.AddInt64(int64(cltvExpiry))
builder.AddOp(txscript.OP_CHECKLOCKTIMEVERIFY)
builder.AddOp(txscript.OP_DROP)
// In order to give the other party time to execute the revocation
// clause above, we require a relative timeout to pass before the
// output can be spent.
builder.AddInt64(int64(csvDelay))
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
// If the relative timelock passes, then we'll add the delay key to the
// stack to ensure that we properly authenticate the spending party.
builder.AddData(delayKey.SerializeCompressed())
// Close out the if statement.
builder.AddOp(txscript.OP_ENDIF)
// In either case, we'll ensure that only either the party possessing
// the revocation private key, or the delay private key is able to
// spend this output.
builder.AddOp(txscript.OP_CHECKSIG)
return builder.Script()
}
// HtlcSpendSuccess spends a second-level HTLC output. This function is to be
// used by the sender of an HTLC to claim the output after a relative timeout
// or the receiver of the HTLC to claim on-chain with the pre-image.
func HtlcSpendSuccess(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx, csvDelay uint32) (wire.TxWitness, error) {
// We're required to wait a relative period of time before we can sweep
// the output in order to allow the other party to contest our claim of
// validity to this version of the commitment transaction.
sweepTx.TxIn[0].Sequence = LockTimeToSequence(false, csvDelay)
// Finally, OP_CSV requires that the version of the transaction
// spending a pkscript with OP_CSV within it *must* be >= 2.
sweepTx.Version = 2
// As we mutated the transaction, we'll re-calculate the sighashes for
// this instance.
signDesc.SigHashes = NewTxSigHashesV0Only(sweepTx)
// With the proper sequence and version set, we'll now sign the timeout
// transaction using the passed signed descriptor. In order to generate
// a valid signature, then signDesc should be using the base delay
// public key, and the proper single tweak bytes.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// We set a zero as the first element of the witness stack (ignoring the
// witness script), in order to force execution to the second portion
// of the if clause.
witnessStack := wire.TxWitness(make([][]byte, 3))
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = nil
witnessStack[2] = signDesc.WitnessScript
return witnessStack, nil
}
// HtlcSpendRevoke spends a second-level HTLC output. This function is to be
// used by the sender or receiver of an HTLC to claim the HTLC after a revoked
// commitment transaction was broadcast.
func HtlcSpendRevoke(signer Signer, signDesc *SignDescriptor,
revokeTx *wire.MsgTx) (wire.TxWitness, error) {
// We don't need any special modifications to the transaction as this
// is just sweeping a revoked HTLC output. So we'll generate a regular
// witness signature.
sweepSig, err := signer.SignOutputRaw(revokeTx, signDesc)
if err != nil {
return nil, err
}
// We set a one as the first element of the witness stack (ignoring the
// witness script), in order to force execution to the revocation
// clause in the second level HTLC script.
witnessStack := wire.TxWitness(make([][]byte, 3))
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = []byte{1}
witnessStack[2] = signDesc.WitnessScript
return witnessStack, nil
}
// HtlcSecondLevelSpend exposes the public witness generation function for
// spending an HTLC success transaction, either due to an expiring time lock or
// having had the payment preimage. This method is able to spend any
// second-level HTLC transaction, assuming the caller sets the locktime or
// seqno properly.
//
// NOTE: The caller MUST set the txn version, sequence number, and sign
// descriptor's sig hash cache before invocation.
func HtlcSecondLevelSpend(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
// With the proper sequence and version set, we'll now sign the timeout
// transaction using the passed signed descriptor. In order to generate
// a valid signature, then signDesc should be using the base delay
// public key, and the proper single tweak bytes.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// We set a zero as the first element of the witness stack (ignoring the
// witness script), in order to force execution to the second portion
// of the if clause.
witnessStack := wire.TxWitness(make([][]byte, 3))
witnessStack[0] = append(sweepSig.Serialize(), byte(txscript.SigHashAll))
witnessStack[1] = nil
witnessStack[2] = signDesc.WitnessScript
return witnessStack, nil
}
// LockTimeToSequence converts the passed relative locktime to a sequence
// number in accordance with BIP-68.
// See: https://github.com/bitcoin/bips/blob/master/bip-0068.mediawiki
// - (Compatibility)
func LockTimeToSequence(isSeconds bool, locktime uint32) uint32 {
if !isSeconds {
// The locktime is to be expressed in confirmations.
return locktime
}
// Set the 22nd bit which indicates the lock time is in seconds, then
// shift the locktime over by 9 since the time granularity is in
// 512-second intervals (2^9). This results in a max lock-time of
// 33,554,431 seconds, or 1.06 years.
return SequenceLockTimeSeconds | (locktime >> 9)
}
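// Example (illustrative sketch, not part of the package API): a block-based
// relative locktime passes through unchanged, while a time-based one is
// quantized into 512-second units with the type bit (bit 22) set.
func exampleLockTimeToSequence() {
	// 144 blocks encodes directly as 144.
	blockSeq := LockTimeToSequence(false, 144)
	// 3600 seconds becomes 3600>>9 = 7 units of 512s, OR'd with the
	// seconds flag.
	timeSeq := LockTimeToSequence(true, 3600)
	fmt.Printf("blocks: %d, seconds: %#x\n", blockSeq, timeSeq)
}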
// CommitScriptToSelf constructs the public key script for the output on the
// commitment transaction paying to the "owner" of said commitment transaction.
// If the other party learns of the preimage to the revocation hash, then they
// can claim all the settled funds in the channel, plus the unsettled funds.
//
// Possible Input Scripts:
//
// REVOKE: <sig> 1
// SENDRSWEEP: <sig> <emptyvector>
//
// Output Script:
//
// OP_IF
// <revokeKey>
// OP_ELSE
// <numRelativeBlocks> OP_CHECKSEQUENCEVERIFY OP_DROP
// <selfKey>
// OP_ENDIF
// OP_CHECKSIG
func CommitScriptToSelf(csvTimeout uint32, selfKey, revokeKey *btcec.PublicKey) ([]byte, error) {
// This script is spendable under two conditions: either the
// 'csvTimeout' has passed and we can redeem our funds, or they can
// produce a valid signature with the revocation public key. The
// revocation public key will *only* be known to the other party if we
// have divulged the revocation hash, allowing them to homomorphically
// derive the proper private key which corresponds to the revoke public
// key.
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
ToLocalScriptSize,
))
builder.AddOp(txscript.OP_IF)
// If a valid signature using the revocation key is presented, then
// allow an immediate spend provided the proper signature.
builder.AddData(revokeKey.SerializeCompressed())
builder.AddOp(txscript.OP_ELSE)
// Otherwise, we can re-claim our funds after a CSV delay of
// 'csvTimeout' timeout blocks, and a valid signature.
builder.AddInt64(int64(csvTimeout))
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
builder.AddData(selfKey.SerializeCompressed())
builder.AddOp(txscript.OP_ENDIF)
// Finally, we'll validate the signature against the public key that's
// left on the top of the stack.
builder.AddOp(txscript.OP_CHECKSIG)
return builder.Script()
}
// CommitScriptTree holds the taproot output key (in this case the revocation
// key, or a NUMs point for the remote output) along with the tapscript leaf
// that can spend the output after a delay.
type CommitScriptTree struct {
ScriptTree
// SettleLeaf is the leaf used to settle the output after the delay.
SettleLeaf txscript.TapLeaf
// RevocationLeaf is the leaf used to spend the output with the
// revocation key signature.
RevocationLeaf txscript.TapLeaf
// AuxLeaf is an auxiliary leaf that can be used to extend the base
// commitment script tree with new spend paths, or just as extra
// commitment space. When present, this leaf will always be in the
// left-most or right-most area of the tapscript tree.
AuxLeaf AuxTapLeaf
}
// A compile time check to ensure CommitScriptTree implements the
// TapscriptDescriptor interface.
var _ TapscriptDescriptor = (*CommitScriptTree)(nil)
// WitnessScriptToSign returns the witness script that we'll use when signing
// for the remote party, and also verifying signatures on our transactions. As
// an example, when we create an outgoing HTLC for the remote party, we want to
// sign their success path.
func (c *CommitScriptTree) WitnessScriptToSign() []byte {
// TODO(roasbeef): abstraction leak here? always dependent
return nil
}
// WitnessScriptForPath returns the witness script for the given spending path.
// An error is returned if the path is unknown.
func (c *CommitScriptTree) WitnessScriptForPath(
path ScriptPath) ([]byte, error) {
switch path {
// For the commitment output, the delay and success path are the same,
// so we'll fall through here to success.
case ScriptPathDelay:
fallthrough
case ScriptPathSuccess:
return c.SettleLeaf.Script, nil
case ScriptPathRevocation:
return c.RevocationLeaf.Script, nil
default:
return nil, fmt.Errorf("unknown script path: %v", path)
}
}
// CtrlBlockForPath returns the control block for the given spending path. For
// script types that don't have a control block, nil is returned.
func (c *CommitScriptTree) CtrlBlockForPath(
path ScriptPath) (*txscript.ControlBlock, error) {
switch path {
case ScriptPathDelay:
fallthrough
case ScriptPathSuccess:
return lnutils.Ptr(MakeTaprootCtrlBlock(
c.SettleLeaf.Script, c.InternalKey,
c.TapscriptTree,
)), nil
case ScriptPathRevocation:
return lnutils.Ptr(MakeTaprootCtrlBlock(
c.RevocationLeaf.Script, c.InternalKey,
c.TapscriptTree,
)), nil
default:
return nil, fmt.Errorf("unknown script path: %v", path)
}
}
// Tree returns the underlying ScriptTree of the CommitScriptTree.
func (c *CommitScriptTree) Tree() ScriptTree {
return c.ScriptTree
}
// NewLocalCommitScriptTree returns a new CommitScript tree that can be used to
// create and spend the commitment output for the local party.
func NewLocalCommitScriptTree(csvTimeout uint32, selfKey,
revokeKey *btcec.PublicKey, auxLeaf AuxTapLeaf) (*CommitScriptTree,
error) {
// First, we'll need to construct the tapLeaf that'll be our delay CSV
// clause.
delayScript, err := TaprootLocalCommitDelayScript(csvTimeout, selfKey)
if err != nil {
return nil, err
}
// Next, we'll need to construct the revocation path, which is just a
// simple checksig script.
revokeScript, err := TaprootLocalCommitRevokeScript(selfKey, revokeKey)
if err != nil {
return nil, err
}
// With both scripts computed, we'll now create a tapscript tree with
// the two leaves, and then obtain a root from that.
delayTapLeaf := txscript.NewBaseTapLeaf(delayScript)
revokeTapLeaf := txscript.NewBaseTapLeaf(revokeScript)
tapLeaves := []txscript.TapLeaf{delayTapLeaf, revokeTapLeaf}
auxLeaf.WhenSome(func(l txscript.TapLeaf) {
tapLeaves = append(tapLeaves, l)
})
tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...)
tapScriptRoot := tapScriptTree.RootNode.TapHash()
// Now that we have our root, we can arrive at the final output script
// by tweaking the internal key with this root.
toLocalOutputKey := txscript.ComputeTaprootOutputKey(
&TaprootNUMSKey, tapScriptRoot[:],
)
return &CommitScriptTree{
ScriptTree: ScriptTree{
TaprootKey: toLocalOutputKey,
TapscriptTree: tapScriptTree,
TapscriptRoot: tapScriptRoot[:],
InternalKey: &TaprootNUMSKey,
},
SettleLeaf: delayTapLeaf,
RevocationLeaf: revokeTapLeaf,
AuxLeaf: auxLeaf,
}, nil
}
// TaprootLocalCommitDelayScript builds the tap leaf with the CSV delay script
// for the to-local output.
func TaprootLocalCommitDelayScript(csvTimeout uint32,
selfKey *btcec.PublicKey) ([]byte, error) {
builder := txscript.NewScriptBuilder()
builder.AddData(schnorr.SerializePubKey(selfKey))
builder.AddOp(txscript.OP_CHECKSIG)
builder.AddInt64(int64(csvTimeout))
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
return builder.Script()
}
// TaprootLocalCommitRevokeScript builds the tap leaf with the revocation path
// for the to-local output.
func TaprootLocalCommitRevokeScript(selfKey, revokeKey *btcec.PublicKey) (
[]byte, error) {
builder := txscript.NewScriptBuilder()
builder.AddData(schnorr.SerializePubKey(selfKey))
builder.AddOp(txscript.OP_DROP)
builder.AddData(schnorr.SerializePubKey(revokeKey))
builder.AddOp(txscript.OP_CHECKSIG)
return builder.Script()
}
// TaprootCommitScriptToSelf creates the taproot witness program that commits
// to the revocation (script path) and delay path (script path) in a single
// taproot output key. Both the delay script and the revocation script are part
// of the tapscript tree to ensure that the internal key (the local delay key)
// is always revealed. This ensures that a 3rd party can always sweep the set
// of anchor outputs.
//
// For the delay path we have the following tapscript leaf script:
//
// <local_delayedpubkey> OP_CHECKSIG
// <to_self_delay> OP_CHECKSEQUENCEVERIFY OP_DROP
//
// This can then be spent with just:
//
// <local_delayedsig> <to_delay_script> <delay_control_block>
//
// Where the to_delay_script is listed above, and the delay_control_block
// computed as:
//
// delay_control_block = (output_key_y_parity | 0xc0) || taproot_nums_key
//
// The revocation path is simply:
//
// <local_delayedpubkey> OP_DROP
// <revocationkey> OP_CHECKSIG
//
// The revocation path can be spent with a control block similar to the above
// (but contains the hash of the other script), and with the following witness:
//
// <revocation_sig>
//
// We use a noop data push to ensure that the local public key is also revealed
// on chain, which enables the anchor output to be swept.
func TaprootCommitScriptToSelf(csvTimeout uint32,
selfKey, revokeKey *btcec.PublicKey) (*btcec.PublicKey, error) {
commitScriptTree, err := NewLocalCommitScriptTree(
csvTimeout, selfKey, revokeKey, NoneTapLeaf(),
)
if err != nil {
return nil, err
}
return commitScriptTree.TaprootKey, nil
}
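// The function below is an illustrative sketch (not part of the original
// file): it shows how the to-local output key returned above can be turned
// into a spendable segwit v1 pkScript. The 144-block CSV delay and the
// freshly generated keys are arbitrary stand-ins for values that would
// normally come from the channel's key ring.
func exampleToSelfOutputScript() ([]byte, error) {
selfPriv, err := btcec.NewPrivateKey()
if err != nil {
return nil, err
}
revokePriv, err := btcec.NewPrivateKey()
if err != nil {
return nil, err
}
// Derive the combined taproot output key for the to-local output.
outputKey, err := TaprootCommitScriptToSelf(
144, selfPriv.PubKey(), revokePriv.PubKey(),
)
if err != nil {
return nil, err
}
// Encode the output key as a v1 witness program (p2tr).
return txscript.PayToTaprootScript(outputKey)
}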
// MakeTaprootCtrlBlock takes a leaf script, the internal key (usually the
// revoke key), and a script tree and creates a valid control block for a spend
// of the leaf.
func MakeTaprootCtrlBlock(leafScript []byte, internalKey *btcec.PublicKey,
scriptTree *txscript.IndexedTapScriptTree) txscript.ControlBlock {
tapLeafHash := txscript.NewBaseTapLeaf(leafScript).TapHash()
scriptIdx := scriptTree.LeafProofIndex[tapLeafHash]
settleMerkleProof := scriptTree.LeafMerkleProofs[scriptIdx]
return settleMerkleProof.ToControlBlock(internalKey)
}
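// Hedged usage sketch (variable names assumed): given a CommitScriptTree
// built by NewLocalCommitScriptTree, the control block for the revocation
// path can be derived directly, mirroring what CtrlBlockForPath does
// internally:
//
// ctrlBlock := MakeTaprootCtrlBlock(
// tree.RevocationLeaf.Script, tree.InternalKey, tree.TapscriptTree,
// )
// ctrlBytes, err := ctrlBlock.ToBytes()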
// TaprootCommitSpendSuccess constructs a valid witness allowing a node to
// sweep the settled taproot output after the delay has passed for a force
// close.
func TaprootCommitSpendSuccess(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx,
scriptTree *txscript.IndexedTapScriptTree) (wire.TxWitness, error) {
// First, we'll need to construct a valid control block to execute the
// leaf script for sweep settlement.
//
// TODO(roasbeef): make into closure instead? only need revoke key and
// scriptTree to make the ctrl block -- then a default version that would
// take from the sign desc?
var ctrlBlockBytes []byte
if signDesc.ControlBlock == nil {
settleControlBlock := MakeTaprootCtrlBlock(
signDesc.WitnessScript, &TaprootNUMSKey, scriptTree,
)
ctrlBytes, err := settleControlBlock.ToBytes()
if err != nil {
return nil, err
}
ctrlBlockBytes = ctrlBytes
} else {
ctrlBlockBytes = signDesc.ControlBlock
}
// With the control block created, we'll now generate the signature we
// need to authorize the spend.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// The final witness stack will be:
//
// <sweep sig> <sweep script> <control block>
witnessStack := make(wire.TxWitness, 3)
witnessStack[0] = maybeAppendSighash(sweepSig, signDesc.HashType)
witnessStack[1] = signDesc.WitnessScript
witnessStack[2] = ctrlBlockBytes
return witnessStack, nil
}
// TaprootCommitSpendRevoke constructs a valid witness allowing a node to sweep
// the revoked taproot output of a malicious peer.
func TaprootCommitSpendRevoke(signer Signer, signDesc *SignDescriptor,
revokeTx *wire.MsgTx,
scriptTree *txscript.IndexedTapScriptTree) (wire.TxWitness, error) {
// First, we'll need to construct a valid control block to execute the
// leaf script for revocation path.
var ctrlBlockBytes []byte
if signDesc.ControlBlock == nil {
revokeCtrlBlock := MakeTaprootCtrlBlock(
signDesc.WitnessScript, &TaprootNUMSKey, scriptTree,
)
revokeBytes, err := revokeCtrlBlock.ToBytes()
if err != nil {
return nil, err
}
ctrlBlockBytes = revokeBytes
} else {
ctrlBlockBytes = signDesc.ControlBlock
}
// With the control block created, we'll now generate the signature we
// need to authorize the spend.
revokeSig, err := signer.SignOutputRaw(revokeTx, signDesc)
if err != nil {
return nil, err
}
// The final witness stack will be:
//
// <revoke sig> <revoke script> <control block>
witnessStack := make(wire.TxWitness, 3)
witnessStack[0] = maybeAppendSighash(revokeSig, signDesc.HashType)
witnessStack[1] = signDesc.WitnessScript
witnessStack[2] = ctrlBlockBytes
return witnessStack, nil
}
// LeaseCommitScriptToSelf constructs the public key script for the output on the
// commitment transaction paying to the "owner" of said commitment transaction.
// If the other party learns of the preimage to the revocation hash, then they
// can claim all the settled funds in the channel, plus the unsettled funds.
//
// Possible Input Scripts:
//
// REVOKE: <sig> 1
// SENDRSWEEP: <sig> <emptyvector>
//
// Output Script:
//
// OP_IF
// <revokeKey>
// OP_ELSE
// <absoluteLeaseExpiry> OP_CHECKLOCKTIMEVERIFY OP_DROP
// <numRelativeBlocks> OP_CHECKSEQUENCEVERIFY OP_DROP
// <selfKey>
// OP_ENDIF
// OP_CHECKSIG
func LeaseCommitScriptToSelf(selfKey, revokeKey *btcec.PublicKey,
csvTimeout, leaseExpiry uint32) ([]byte, error) {
// This script is spendable under two conditions: either the
// 'csvTimeout' has passed and we can redeem our funds, or they can
// produce a valid signature with the revocation public key. The
// revocation public key will *only* be known to the other party if we
// have divulged the revocation hash, allowing them to homomorphically
// derive the proper private key which corresponds to the revoke public
// key.
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
ToLocalScriptSize + LeaseWitnessScriptSizeOverhead,
))
builder.AddOp(txscript.OP_IF)
// If a valid signature using the revocation key is presented, then
// allow an immediate spend provided the proper signature.
builder.AddData(revokeKey.SerializeCompressed())
builder.AddOp(txscript.OP_ELSE)
// Otherwise, we can re-claim our funds once the CLTV lease maturity has
// been met, along with the CSV delay of 'csvTimeout' timeout blocks, and
// a valid signature.
builder.AddInt64(int64(leaseExpiry))
builder.AddOp(txscript.OP_CHECKLOCKTIMEVERIFY)
builder.AddOp(txscript.OP_DROP)
builder.AddInt64(int64(csvTimeout))
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
builder.AddData(selfKey.SerializeCompressed())
builder.AddOp(txscript.OP_ENDIF)
// Finally, we'll validate the signature against the public key that's
// left on the top of the stack.
builder.AddOp(txscript.OP_CHECKSIG)
return builder.Script()
}
// CommitSpendTimeout constructs a valid witness allowing the owner of a
// particular commitment transaction to spend the output returning settled
// funds back to themselves after a relative block timeout. In order to
// properly spend the transaction, the target input's sequence number should be
// set accordingly based off of the target relative block timeout within the
// redeem script. Additionally, OP_CSV requires that the version of the
// transaction spending a pkscript with OP_CSV within it *must* be >= 2.
func CommitSpendTimeout(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
// Ensure the transaction version supports the validation of sequence
// locks and CSV semantics.
if sweepTx.Version < 2 {
return nil, fmt.Errorf("version of passed transaction MUST "+
"be >= 2, not %v", sweepTx.Version)
}
// With the sequence number in place, we're now able to properly sign
// off on the sweep transaction.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// Place an empty byte as the first item in the evaluated witness stack
// to force script execution to the timeout spend clause. We need to
// place an empty byte in order to ensure our script is still valid
// from the PoV of nodes that are enforcing minimal OP_IF/OP_NOTIF.
witnessStack := wire.TxWitness(make([][]byte, 3))
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = nil
witnessStack[2] = signDesc.WitnessScript
return witnessStack, nil
}
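// Illustrative caller-side sketch (assumed variable names): before calling
// CommitSpendTimeout, the sweep transaction must be version 2 and the
// spending input's sequence must encode the relative delay from the redeem
// script, otherwise OP_CSV will fail:
//
// sweepTx.Version = 2
// sweepTx.TxIn[0].Sequence = csvDelay
// witness, err := CommitSpendTimeout(signer, signDesc, sweepTx)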
// CommitSpendRevoke constructs a valid witness allowing a node to sweep the
// settled output of a malicious counterparty who broadcasts a revoked
// commitment transaction.
//
// NOTE: The passed SignDescriptor should include the raw (untweaked)
// revocation base public key of the receiver and also the proper double tweak
// value based on the commitment secret of the revoked commitment.
func CommitSpendRevoke(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// Place a 1 as the first item in the evaluated witness stack to
// force script execution to the revocation clause.
witnessStack := wire.TxWitness(make([][]byte, 3))
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = []byte{1}
witnessStack[2] = signDesc.WitnessScript
return witnessStack, nil
}
// CommitSpendNoDelay constructs a valid witness allowing a node to spend their
// settled no-delay output on the counterparty's commitment transaction. If the
// tweakless field is true, then we'll omit the step where we tweak the pubkey
// with a random set of bytes, and use it directly in the witness stack.
//
// NOTE: The passed SignDescriptor should include the raw (untweaked) public
// key of the receiver and also the proper single tweak value based on the
// current commitment point.
func CommitSpendNoDelay(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx, tweakless bool) (wire.TxWitness, error) {
if signDesc.KeyDesc.PubKey == nil {
return nil, fmt.Errorf("cannot generate witness with nil " +
"KeyDesc pubkey")
}
// This is just a regular p2wkh spend which looks something like:
// * witness: <sig> <pubkey>
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// Finally, we'll manually craft the witness. The witness here is the
// exact same as a regular p2wkh witness, depending on the value of the
// tweakless bool.
witness := make([][]byte, 2)
witness[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
switch tweakless {
// If we're tweaking the key, then we use the tweaked public key as the
// last item in the witness stack, which was originally used to create
// the pkScript we're spending.
case false:
witness[1] = TweakPubKeyWithTweak(
signDesc.KeyDesc.PubKey, signDesc.SingleTweak,
).SerializeCompressed()
// Otherwise, we can just use the raw pubkey, since there's no random
// value to be combined.
case true:
witness[1] = signDesc.KeyDesc.PubKey.SerializeCompressed()
}
return witness, nil
}
// CommitScriptUnencumbered constructs the public key script on the commitment
// transaction paying to the "other" party. The constructed output is a normal
// p2wkh output spendable immediately, requiring no contestation period.
func CommitScriptUnencumbered(key *btcec.PublicKey) ([]byte, error) {
// This script goes to the "other" party, and is spendable immediately.
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
P2WPKHSize,
))
builder.AddOp(txscript.OP_0)
builder.AddData(btcutil.Hash160(key.SerializeCompressed()))
return builder.Script()
}
// CommitScriptToRemoteConfirmed constructs the script for the output on the
// commitment transaction paying to the remote party of said commitment
// transaction. The money can only be spent after one confirmation.
//
// Possible Input Scripts:
//
// SWEEP: <sig>
//
// Output Script:
//
// <key> OP_CHECKSIGVERIFY
// 1 OP_CHECKSEQUENCEVERIFY
func CommitScriptToRemoteConfirmed(key *btcec.PublicKey) ([]byte, error) {
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
ToRemoteConfirmedScriptSize,
))
// Only the given key can spend the output.
builder.AddData(key.SerializeCompressed())
builder.AddOp(txscript.OP_CHECKSIGVERIFY)
// Check that it has one confirmation.
builder.AddOp(txscript.OP_1)
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
return builder.Script()
}
// NewRemoteCommitScriptTree constructs a new script tree for the remote party
// to sweep their funds after a hard coded 1 block delay.
func NewRemoteCommitScriptTree(remoteKey *btcec.PublicKey,
auxLeaf AuxTapLeaf) (*CommitScriptTree, error) {
// First, construct the remote party's tapscript they'll use to sweep
// their outputs.
builder := txscript.NewScriptBuilder()
builder.AddData(schnorr.SerializePubKey(remoteKey))
builder.AddOp(txscript.OP_CHECKSIG)
builder.AddOp(txscript.OP_1)
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_DROP)
remoteScript, err := builder.Script()
if err != nil {
return nil, err
}
tapLeaf := txscript.NewBaseTapLeaf(remoteScript)
tapLeaves := []txscript.TapLeaf{tapLeaf}
auxLeaf.WhenSome(func(l txscript.TapLeaf) {
tapLeaves = append(tapLeaves, l)
})
// With this script constructed, we'll map that into a tapLeaf, then
// make a new tapscript root from that.
tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaves...)
tapScriptRoot := tapScriptTree.RootNode.TapHash()
// Now that we have our root, we can arrive at the final output script
// by tweaking the internal key with this root.
toRemoteOutputKey := txscript.ComputeTaprootOutputKey(
&TaprootNUMSKey, tapScriptRoot[:],
)
return &CommitScriptTree{
ScriptTree: ScriptTree{
TaprootKey: toRemoteOutputKey,
TapscriptTree: tapScriptTree,
TapscriptRoot: tapScriptRoot[:],
InternalKey: &TaprootNUMSKey,
},
SettleLeaf: tapLeaf,
AuxLeaf: auxLeaf,
}, nil
}
// TaprootCommitScriptToRemote constructs a taproot witness program for the
// output on the commitment transaction for the remote party. For the top level
// key spend, we'll use a NUMs key to ensure that only the script path can be
// taken. Using a set NUMs key here also means that recovery solutions can scan
// the chain given knowledge of the public key for the remote party. We then
// commit to a single tapscript leaf that holds the normal CSV 1 delay
// script.
//
// Our single tapleaf will use the following script:
//
// <remotepubkey> OP_CHECKSIG
// 1 OP_CHECKSEQUENCEVERIFY OP_DROP
func TaprootCommitScriptToRemote(remoteKey *btcec.PublicKey,
auxLeaf AuxTapLeaf) (*btcec.PublicKey, error) {
commitScriptTree, err := NewRemoteCommitScriptTree(remoteKey, auxLeaf)
if err != nil {
return nil, err
}
return commitScriptTree.TaprootKey, nil
}
// TaprootCommitRemoteSpend allows the remote party to sweep their output into
// their wallet after an enforced 1 block delay.
func TaprootCommitRemoteSpend(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx,
scriptTree *txscript.IndexedTapScriptTree) (wire.TxWitness, error) {
// First, we'll need to construct a valid control block to execute the
// leaf script for sweep settlement.
var ctrlBlockBytes []byte
if signDesc.ControlBlock == nil {
settleControlBlock := MakeTaprootCtrlBlock(
signDesc.WitnessScript, &TaprootNUMSKey, scriptTree,
)
ctrlBytes, err := settleControlBlock.ToBytes()
if err != nil {
return nil, err
}
ctrlBlockBytes = ctrlBytes
} else {
ctrlBlockBytes = signDesc.ControlBlock
}
// With the control block created, we'll now generate the signature we
// need to authorize the spend.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// The final witness stack will be:
//
// <sweep sig> <sweep script> <control block>
witnessStack := make(wire.TxWitness, 3)
witnessStack[0] = maybeAppendSighash(sweepSig, signDesc.HashType)
witnessStack[1] = signDesc.WitnessScript
witnessStack[2] = ctrlBlockBytes
return witnessStack, nil
}
// LeaseCommitScriptToRemoteConfirmed constructs the script for the output on
// the commitment transaction paying to the remote party of said commitment
// transaction. The money can only be spent after one confirmation.
//
// Possible Input Scripts:
//
// SWEEP: <sig>
//
// Output Script:
//
// <key> OP_CHECKSIGVERIFY
// <lease maturity in blocks> OP_CHECKLOCKTIMEVERIFY OP_DROP
// 1 OP_CHECKSEQUENCEVERIFY
func LeaseCommitScriptToRemoteConfirmed(key *btcec.PublicKey,
leaseExpiry uint32) ([]byte, error) {
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(45))
// Only the given key can spend the output.
builder.AddData(key.SerializeCompressed())
builder.AddOp(txscript.OP_CHECKSIGVERIFY)
// The channel initiator always has the additional channel lease
// expiration constraint for outputs that pay to them, which must be
// satisfied.
builder.AddInt64(int64(leaseExpiry))
builder.AddOp(txscript.OP_CHECKLOCKTIMEVERIFY)
builder.AddOp(txscript.OP_DROP)
// Check that it has one confirmation.
builder.AddOp(txscript.OP_1)
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
return builder.Script()
}
// CommitSpendToRemoteConfirmed constructs a valid witness allowing a node to
// spend their settled output on the counterparty's commitment transaction when
// it has one confirmation. This is used for the anchor channel type. The
// spending key will always be non-tweaked for this output type.
func CommitSpendToRemoteConfirmed(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
if signDesc.KeyDesc.PubKey == nil {
return nil, fmt.Errorf("cannot generate witness with nil " +
"KeyDesc pubkey")
}
// Similar to the non-delayed output, only a signature is needed.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// Finally, we'll manually craft the witness. The witness here is the
// signature and the redeem script.
witnessStack := make([][]byte, 2)
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = signDesc.WitnessScript
return witnessStack, nil
}
// CommitScriptAnchor constructs the script for the anchor output spendable by
// the given key immediately, or by anyone after 16 confirmations.
//
// Possible Input Scripts:
//
// By owner: <sig>
// By anyone (after 16 conf): <emptyvector>
//
// Output Script:
//
// <funding_pubkey> OP_CHECKSIG OP_IFDUP
// OP_NOTIF
// OP_16 OP_CSV
// OP_ENDIF
func CommitScriptAnchor(key *btcec.PublicKey) ([]byte, error) {
builder := txscript.NewScriptBuilder(txscript.WithScriptAllocSize(
AnchorScriptSize,
))
// Spend immediately with key.
builder.AddData(key.SerializeCompressed())
builder.AddOp(txscript.OP_CHECKSIG)
// Duplicate the value if true, since it will be consumed by the NOTIF.
builder.AddOp(txscript.OP_IFDUP)
// Otherwise spendable by anyone after 16 confirmations.
builder.AddOp(txscript.OP_NOTIF)
builder.AddOp(txscript.OP_16)
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
builder.AddOp(txscript.OP_ENDIF)
return builder.Script()
}
// AnchorScriptTree holds all the contents needed to sweep a taproot anchor
// output on chain.
type AnchorScriptTree struct {
ScriptTree
// SweepLeaf is the leaf used to settle the output after the delay.
SweepLeaf txscript.TapLeaf
}
// NewAnchorScriptTree makes a new script tree for an anchor output with the
// passed anchor key.
func NewAnchorScriptTree(
anchorKey *btcec.PublicKey) (*AnchorScriptTree, error) {
// The main script used is just an OP_16 CSV (anyone can sweep after 16
// blocks).
builder := txscript.NewScriptBuilder()
builder.AddOp(txscript.OP_16)
builder.AddOp(txscript.OP_CHECKSEQUENCEVERIFY)
anchorScript, err := builder.Script()
if err != nil {
return nil, err
}
// With the script, we can make our sole leaf, then derive the root
// from that.
tapLeaf := txscript.NewBaseTapLeaf(anchorScript)
tapScriptTree := txscript.AssembleTaprootScriptTree(tapLeaf)
tapScriptRoot := tapScriptTree.RootNode.TapHash()
// Now that we have our root, we can arrive at the final output script
// by tweaking the internal key with this root.
anchorOutputKey := txscript.ComputeTaprootOutputKey(
anchorKey, tapScriptRoot[:],
)
return &AnchorScriptTree{
ScriptTree: ScriptTree{
TaprootKey: anchorOutputKey,
TapscriptTree: tapScriptTree,
TapscriptRoot: tapScriptRoot[:],
InternalKey: anchorKey,
},
SweepLeaf: tapLeaf,
}, nil
}
// WitnessScriptToSign returns the witness script that we'll use when signing
// for the remote party, and also verifying signatures on our transactions. As
// an example, when we create an outgoing HTLC for the remote party, we want to
// sign their success path.
func (a *AnchorScriptTree) WitnessScriptToSign() []byte {
return a.SweepLeaf.Script
}
// WitnessScriptForPath returns the witness script for the given spending path.
// An error is returned if the path is unknown.
func (a *AnchorScriptTree) WitnessScriptForPath(
path ScriptPath) ([]byte, error) {
switch path {
case ScriptPathDelay:
fallthrough
case ScriptPathSuccess:
return a.SweepLeaf.Script, nil
default:
return nil, fmt.Errorf("unknown script path: %v", path)
}
}
// CtrlBlockForPath returns the control block for the given spending path. For
// script types that don't have a control block, nil is returned.
func (a *AnchorScriptTree) CtrlBlockForPath(
path ScriptPath) (*txscript.ControlBlock, error) {
switch path {
case ScriptPathDelay:
fallthrough
case ScriptPathSuccess:
return lnutils.Ptr(MakeTaprootCtrlBlock(
a.SweepLeaf.Script, a.InternalKey,
a.TapscriptTree,
)), nil
default:
return nil, fmt.Errorf("unknown script path: %v", path)
}
}
// Tree returns the underlying ScriptTree of the AnchorScriptTree.
func (a *AnchorScriptTree) Tree() ScriptTree {
return a.ScriptTree
}
// A compile time check to ensure AnchorScriptTree implements the
// TapscriptDescriptor interface.
var _ TapscriptDescriptor = (*AnchorScriptTree)(nil)
// TaprootOutputKeyAnchor returns the segwit v1 (taproot) witness program that
// encodes the anchor output spending conditions: the passed key can be used
// for keyspend, with the OP_CSV 16 clause living within an internal tapscript
// leaf.
//
// Spend paths:
// - Key spend: <key_signature>
// - Script spend: OP_16 CSV <control_block>
func TaprootOutputKeyAnchor(key *btcec.PublicKey) (*btcec.PublicKey, error) {
anchorScriptTree, err := NewAnchorScriptTree(key)
if err != nil {
return nil, err
}
return anchorScriptTree.TaprootKey, nil
}
// TaprootAnchorSpend constructs a valid witness allowing a node to sweep their
// anchor output.
func TaprootAnchorSpend(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
// For this spend type, we only need a single signature which'll be a
// keyspend using the anchor private key.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// The witness stack in this case is pretty simple: we only need to
// specify the signature generated.
witnessStack := make(wire.TxWitness, 1)
witnessStack[0] = maybeAppendSighash(sweepSig, signDesc.HashType)
return witnessStack, nil
}
// TaprootAnchorSpendAny constructs a valid witness allowing anyone to sweep
// the anchor output after 16 blocks.
func TaprootAnchorSpendAny(anchorKey *btcec.PublicKey) (wire.TxWitness, error) {
anchorScriptTree, err := NewAnchorScriptTree(anchorKey)
if err != nil {
return nil, err
}
// For this spend, the only thing we need to do is create a valid
// control block. Other than that, there are no restrictions on how the
// output can be spent.
scriptTree := anchorScriptTree.TapscriptTree
sweepLeaf := anchorScriptTree.SweepLeaf
sweepIdx := scriptTree.LeafProofIndex[sweepLeaf.TapHash()]
sweepMerkleProof := scriptTree.LeafMerkleProofs[sweepIdx]
sweepControlBlock := sweepMerkleProof.ToControlBlock(anchorKey)
// The final witness stack will be:
//
// <sweep script> <control block>
witnessStack := make(wire.TxWitness, 2)
witnessStack[0] = sweepLeaf.Script
witnessStack[1], err = sweepControlBlock.ToBytes()
if err != nil {
return nil, err
}
return witnessStack, nil
}
// CommitSpendAnchor constructs a valid witness allowing a node to spend their
// anchor output on the commitment transaction using their funding key. This is
// used for the anchor channel type.
func CommitSpendAnchor(signer Signer, signDesc *SignDescriptor,
sweepTx *wire.MsgTx) (wire.TxWitness, error) {
if signDesc.KeyDesc.PubKey == nil {
return nil, fmt.Errorf("cannot generate witness with nil " +
"KeyDesc pubkey")
}
// Create a signature.
sweepSig, err := signer.SignOutputRaw(sweepTx, signDesc)
if err != nil {
return nil, err
}
// The witness here is just a signature and the redeem script.
witnessStack := make([][]byte, 2)
witnessStack[0] = append(sweepSig.Serialize(), byte(signDesc.HashType))
witnessStack[1] = signDesc.WitnessScript
return witnessStack, nil
}
// CommitSpendAnchorAnyone constructs a witness allowing anyone to spend the
// anchor output after it has gotten 16 confirmations. Since no signing is
// required, only knowledge of the redeem script is necessary to spend it.
func CommitSpendAnchorAnyone(script []byte) (wire.TxWitness, error) {
// The witness here is just the redeem script.
witnessStack := make([][]byte, 2)
witnessStack[0] = nil
witnessStack[1] = script
return witnessStack, nil
}
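// The function below is an illustrative sketch (not in the original file)
// combining CommitScriptAnchor and CommitSpendAnchorAnyone: it builds an
// anchor redeem script for a fresh key and produces the anyone-can-spend
// witness. A real spend must also wait 16 blocks and set the spending
// input's sequence to 16 so that OP_CHECKSEQUENCEVERIFY passes.
func exampleAnyoneSweepsAnchor() (wire.TxWitness, error) {
anchorPriv, err := btcec.NewPrivateKey()
if err != nil {
return nil, err
}
// Build the anchor redeem script for the funding key.
script, err := CommitScriptAnchor(anchorPriv.PubKey())
if err != nil {
return nil, err
}
// After 16 confirmations, only the script itself is required.
return CommitSpendAnchorAnyone(script)
}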
// SingleTweakBytes computes the set of bytes we call the single tweak. The purpose
// of the single tweak is to randomize all regular delay and payment base
// points. To do this, we generate a hash that binds the commitment point to
// the pay/delay base point. The end result is that the basePoint is
// tweaked as follows:
//
// - key = basePoint + sha256(commitPoint || basePoint)*G
func SingleTweakBytes(commitPoint, basePoint *btcec.PublicKey) []byte {
h := sha256.New()
h.Write(commitPoint.SerializeCompressed())
h.Write(basePoint.SerializeCompressed())
return h.Sum(nil)
}
// TweakPubKey tweaks a public base point given a per commitment point. The per
// commitment point is a unique point on our target curve for each commitment
// transaction. When tweaking a local base point for use in a remote commitment
// transaction, the remote party's current per commitment point is to be used.
// The opposite applies when tweaking remote keys. Precisely, the following
// operation is used to "tweak" public keys:
//
// tweakPub := basePoint + sha256(commitPoint || basePoint) * G
// := G*k + sha256(commitPoint || basePoint)*G
// := G*(k + sha256(commitPoint || basePoint))
//
// Therefore, if a party possesses the value k, the private key of the base
// point, then they are able to derive the proper private key for the
// revokeKey by computing:
//
// revokePriv := k + sha256(commitPoint || basePoint) mod N
//
// Where N is the order of the sub-group.
//
// The rationale for tweaking all public keys used within the commitment
// contracts is to ensure that all keys are properly delinearized to avoid any
// funny business when jointly collaborating to compute public and private
// keys. Additionally, the use of the per commitment point ensures that each
// commitment state houses a unique set of keys which is useful when creating
// blinded channel outsourcing protocols.
//
// TODO(roasbeef): should be using double-scalar mult here
func TweakPubKey(basePoint, commitPoint *btcec.PublicKey) *btcec.PublicKey {
tweakBytes := SingleTweakBytes(commitPoint, basePoint)
return TweakPubKeyWithTweak(basePoint, tweakBytes)
}
// TweakPubKeyWithTweak is the same as the TweakPubKey function, except that
// it accepts the raw tweak bytes directly rather than the commitment point.
func TweakPubKeyWithTweak(pubKey *btcec.PublicKey,
tweakBytes []byte) *btcec.PublicKey {
var (
pubKeyJacobian btcec.JacobianPoint
tweakJacobian btcec.JacobianPoint
resultJacobian btcec.JacobianPoint
)
tweakKey, _ := btcec.PrivKeyFromBytes(tweakBytes)
btcec.ScalarBaseMultNonConst(&tweakKey.Key, &tweakJacobian)
pubKey.AsJacobian(&pubKeyJacobian)
btcec.AddNonConst(&pubKeyJacobian, &tweakJacobian, &resultJacobian)
resultJacobian.ToAffine()
return btcec.NewPublicKey(&resultJacobian.X, &resultJacobian.Y)
}
// TweakPrivKey tweaks the private key of a public base point given a per
// commitment point. The per commitment secret is the revealed revocation
// secret for the commitment state in question. This private key will only need
// to be generated in the case that a channel counter party broadcasts a
// revoked state. Precisely, the following operation is used to derive a
// tweaked private key:
//
// - tweakPriv := basePriv + sha256(commitPoint || basePub) mod N
//
// Where N is the order of the sub-group.
func TweakPrivKey(basePriv *btcec.PrivateKey,
commitTweak []byte) *btcec.PrivateKey {
// tweakInt := sha256(commitPoint || basePub)
tweakScalar := new(btcec.ModNScalar)
tweakScalar.SetByteSlice(commitTweak)
tweakScalar.Add(&basePriv.Key)
return &btcec.PrivateKey{Key: *tweakScalar}
}
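// The sketch below (a hypothetical helper, not part of the original file)
// verifies the tweak homomorphism described above: tweaking the private key
// and then deriving its public key must land on the same point as tweaking
// the public base point directly.
func exampleVerifySingleTweak() (bool, error) {
basePriv, err := btcec.NewPrivateKey()
if err != nil {
return false, err
}
commitPriv, err := btcec.NewPrivateKey()
if err != nil {
return false, err
}
commitPoint := commitPriv.PubKey()
// Tweak the public base point with the commitment point.
tweakedPub := TweakPubKey(basePriv.PubKey(), commitPoint)
// Tweak the private key with the same single tweak bytes.
tweak := SingleTweakBytes(commitPoint, basePriv.PubKey())
tweakedPriv := TweakPrivKey(basePriv, tweak)
// Both derivations must agree on the resulting public key.
return tweakedPriv.PubKey().IsEqual(tweakedPub), nil
}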
// DeriveRevocationPubkey derives the revocation public key given the
// counterparty's commitment key, and revocation preimage derived via a
// pseudo-random-function. In the event that we (for some reason) broadcast a
// revoked commitment transaction, then if the other party knows the revocation
// preimage, they'll be able to derive the private key corresponding to the
// revocation public key by exploiting the homomorphism in the elliptic
// curve group:
// - https://en.wikipedia.org/wiki/Group_homomorphism#Homomorphisms_of_abelian_groups
//
// The derivation is performed as follows:
//
// revokeKey := revokeBase * sha256(revocationBase || commitPoint) +
// commitPoint * sha256(commitPoint || revocationBase)
//
// := G*(revokeBasePriv * sha256(revocationBase || commitPoint)) +
// G*(commitSecret * sha256(commitPoint || revocationBase))
//
// := G*(revokeBasePriv * sha256(revocationBase || commitPoint) +
// commitSecret * sha256(commitPoint || revocationBase))
//
// Therefore, once we divulge the revocation secret, the remote peer is able to
// compute the proper private key for the revokeKey by computing:
//
// revokePriv := (revokeBasePriv * sha256(revocationBase || commitPoint)) +
// (commitSecret * sha256(commitPoint || revocationBase)) mod N
//
// Where N is the order of the sub-group.
func DeriveRevocationPubkey(revokeBase,
commitPoint *btcec.PublicKey) *btcec.PublicKey {
// R = revokeBase * sha256(revocationBase || commitPoint)
revokeTweakBytes := SingleTweakBytes(revokeBase, commitPoint)
revokeTweakScalar := new(btcec.ModNScalar)
revokeTweakScalar.SetByteSlice(revokeTweakBytes)
var (
revokeBaseJacobian btcec.JacobianPoint
rJacobian btcec.JacobianPoint
)
revokeBase.AsJacobian(&revokeBaseJacobian)
btcec.ScalarMultNonConst(
revokeTweakScalar, &revokeBaseJacobian, &rJacobian,
)
// C = commitPoint * sha256(commitPoint || revocationBase)
commitTweakBytes := SingleTweakBytes(commitPoint, revokeBase)
commitTweakScalar := new(btcec.ModNScalar)
commitTweakScalar.SetByteSlice(commitTweakBytes)
var (
commitPointJacobian btcec.JacobianPoint
cJacobian btcec.JacobianPoint
)
commitPoint.AsJacobian(&commitPointJacobian)
btcec.ScalarMultNonConst(
commitTweakScalar, &commitPointJacobian, &cJacobian,
)
// Now that we have both derived points, we add them together to obtain
// the revocation public key.
//
// P = R + C
var resultJacobian btcec.JacobianPoint
btcec.AddNonConst(&rJacobian, &cJacobian, &resultJacobian)
resultJacobian.ToAffine()
return btcec.NewPublicKey(&resultJacobian.X, &resultJacobian.Y)
}
// DeriveRevocationPrivKey derives the revocation private key given a node's
// commitment private key, and the preimage to a previously seen revocation
// hash. Using this derived private key, a node is able to claim the output
// within the commitment transaction of a node in the case that they broadcast
// a previously revoked commitment transaction.
//
// The private key is derived as follows:
//
// revokePriv := (revokeBasePriv * sha256(revocationBase || commitPoint)) +
// (commitSecret * sha256(commitPoint || revocationBase)) mod N
//
// Where N is the order of the sub-group.
func DeriveRevocationPrivKey(revokeBasePriv *btcec.PrivateKey,
commitSecret *btcec.PrivateKey) *btcec.PrivateKey {
// r = sha256(revokeBasePub || commitPoint)
revokeTweakBytes := SingleTweakBytes(
revokeBasePriv.PubKey(), commitSecret.PubKey(),
)
revokeTweakScalar := new(btcec.ModNScalar)
revokeTweakScalar.SetByteSlice(revokeTweakBytes)
// c = sha256(commitPoint || revokeBasePub)
commitTweakBytes := SingleTweakBytes(
commitSecret.PubKey(), revokeBasePriv.PubKey(),
)
commitTweakScalar := new(btcec.ModNScalar)
commitTweakScalar.SetByteSlice(commitTweakBytes)
// Finally to derive the revocation secret key we'll perform the
// following operation:
//
// k = (revokeBasePriv * r) + (commitSecret * c) mod N
//
// This works since:
// P = (G*a)*b + (G*c)*d
// P = G*(a*b) + G*(c*d)
// P = G*(a*b + c*d)
revokeHalfPriv := revokeTweakScalar.Mul(&revokeBasePriv.Key)
commitHalfPriv := commitTweakScalar.Mul(&commitSecret.Key)
revocationPriv := revokeHalfPriv.Add(commitHalfPriv)
return &btcec.PrivateKey{Key: *revocationPriv}
}
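// As a hedged, self-contained check (a hypothetical helper, not in the
// original file), the public and private revocation derivations above can
// be cross-validated: the public key of DeriveRevocationPrivKey must equal
// DeriveRevocationPubkey applied to the corresponding public points.
func exampleVerifyRevocationDerivation() (bool, error) {
revokeBasePriv, err := btcec.NewPrivateKey()
if err != nil {
return false, err
}
commitSecret, err := btcec.NewPrivateKey()
if err != nil {
return false, err
}
// Public-side derivation from the two public points.
revokePub := DeriveRevocationPubkey(
revokeBasePriv.PubKey(), commitSecret.PubKey(),
)
// Private-side derivation from the two secrets.
revokePriv := DeriveRevocationPrivKey(revokeBasePriv, commitSecret)
return revokePriv.PubKey().IsEqual(revokePub), nil
}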
// ComputeCommitmentPoint generates a commitment point given a commitment
// secret. The commitment point for each state is used to randomize each key in
// the key-ring and is also used as a tweak to derive new public+private keys
// for the state.
func ComputeCommitmentPoint(commitSecret []byte) *btcec.PublicKey {
_, pubKey := btcec.PrivKeyFromBytes(commitSecret)
return pubKey
}
// ScriptIsOpReturn returns true if the passed script is an OP_RETURN script.
//
// Lifted from the txscript package:
// https://github.com/btcsuite/btcd/blob/cc26860b40265e1332cca8748c5dbaf3c81cc094/txscript/standard.go#L493-L526.
//
//nolint:ll
func ScriptIsOpReturn(script []byte) bool {
// A null script is of the form:
// OP_RETURN <optional data>
//
// Thus, it can either be a single OP_RETURN or an OP_RETURN followed by
// a data push up to MaxDataCarrierSize bytes.
// The script can't possibly be a null data script if it doesn't start
// with OP_RETURN. Fail fast to avoid more work below.
if len(script) < 1 || script[0] != txscript.OP_RETURN {
return false
}
// Single OP_RETURN.
if len(script) == 1 {
return true
}
// OP_RETURN followed by data push up to MaxDataCarrierSize bytes.
tokenizer := txscript.MakeScriptTokenizer(0, script[1:])
return tokenizer.Next() && tokenizer.Done() &&
(txscript.IsSmallInt(tokenizer.Opcode()) ||
tokenizer.Opcode() <= txscript.OP_PUSHDATA4) &&
len(tokenizer.Data()) <= txscript.MaxDataCarrierSize
}
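// A minimal sketch (hypothetical helper, not part of the original file)
// exercising ScriptIsOpReturn with a canonical OP_RETURN data push built
// via the txscript builder.
func exampleScriptIsOpReturn() (bool, error) {
builder := txscript.NewScriptBuilder()
builder.AddOp(txscript.OP_RETURN)
builder.AddData([]byte("lnd"))
script, err := builder.Script()
if err != nil {
return false, err
}
// Expected to be true: OP_RETURN followed by a small data push.
return ScriptIsOpReturn(script), nil
}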
package input
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/keychain"
)
var (
// ErrTweakOverdose signals a SignDescriptor is invalid because both of its
// SingleTweak and DoubleTweak are non-nil.
ErrTweakOverdose = errors.New("sign descriptor should only have one tweak")
)
// SignDescriptor houses the necessary information required to successfully
// sign a given segwit output. This struct is used by the Signer interface in
// order to gain access to critical data needed to generate a valid signature.
type SignDescriptor struct {
// KeyDesc is a descriptor that precisely describes *which* key to use
// for signing. This may provide the raw public key directly, or
// require the Signer to re-derive the key according to the populated
// derivation path.
KeyDesc keychain.KeyDescriptor
// SingleTweak is a scalar value that will be added to the private key
// corresponding to the above public key to obtain the private key to
// be used to sign this input. This value is typically derived via the
// following computation:
//
// * derivedKey = privkey + sha256(perCommitmentPoint || pubKey) mod N
//
// NOTE: If this value is nil, then the input can be signed using only
// the above public key. Either a SingleTweak should be set or a
// DoubleTweak, not both.
SingleTweak []byte
// DoubleTweak is a private key that will be used in combination with
// its corresponding private key to derive the private key that is to
// be used to sign the target input. Within the Lightning protocol,
// this value is typically the commitment secret from a previously
// revoked commitment transaction. This value is used in combination
// with two hash values and the original private key to derive the
// private key to be used when signing.
//
// * k = (privKey*sha256(pubKey || tweakPub) +
// tweakPriv*sha256(tweakPub || pubKey)) mod N
//
// NOTE: If this value is nil, then the input can be signed using only
// the above public key. Either a SingleTweak should be set or a
// DoubleTweak, not both.
DoubleTweak *btcec.PrivateKey
// TapTweak is a 32-byte value that will be used to derive a taproot
// output public key (or the corresponding private key) from an
// internal key and this tweak. The transformation applied is:
// * outputKey = internalKey +
// tagged_hash("tapTweak", internalKey || tapTweak)
//
// When attempting to sign an output derived via BIP 86, then this
// field should be an empty byte array.
//
// When attempting to sign for the key spend path of an output key that
// commits to an actual script tree, the script root should be used.
TapTweak []byte
// WitnessScript is the full script required to properly redeem the
// output. This field should be set to the full script if a p2wsh
// output is being signed. For p2wkh it should be set to the hashed
// script (PkScript).
WitnessScript []byte
// SignMethod specifies how the input should be signed. Depending on the
// selected method, either the TapTweak, WitnessScript or both need to
// be specified.
SignMethod SignMethod
// Output is the target output which should be signed. The PkScript and
// Value fields within the output should be properly populated,
// otherwise an invalid signature may be generated.
Output *wire.TxOut
// HashType is the target sighash type that should be used when
// generating the final sighash, and signature.
HashType txscript.SigHashType
// SigHashes is the pre-computed sighash midstate to be used when
// generating the final sighash for signing.
SigHashes *txscript.TxSigHashes
// PrevOutputFetcher is an interface that can return the output
// information on all UTXOs that are being spent in this transaction.
// This MUST be set when spending Taproot outputs.
PrevOutputFetcher txscript.PrevOutputFetcher
// ControlBlock is a fully serialized control block that contains the
// merkle proof necessary to spend a taproot output. This may
// optionally be set if the SignMethod is
// input.TaprootScriptSpendSignMethod, in which case this should be an
// inclusion proof for the WitnessScript.
ControlBlock []byte
// InputIndex is the target input within the transaction that should be
// signed.
InputIndex int
}
// SignMethod defines the different ways a signer can sign, given a specific
// input.
type SignMethod uint8
const (
// WitnessV0SignMethod denotes that a SegWit v0 (p2wkh, np2wkh, p2wsh)
// input script should be signed.
WitnessV0SignMethod SignMethod = 0
// TaprootKeySpendBIP0086SignMethod denotes that a SegWit v1 (p2tr)
// input should be signed by using the BIP0086 method (commit to
// internal key only).
TaprootKeySpendBIP0086SignMethod SignMethod = 1
// TaprootKeySpendSignMethod denotes that a SegWit v1 (p2tr)
// input should be signed by using a given taproot hash to commit to in
// addition to the internal key.
TaprootKeySpendSignMethod SignMethod = 2
// TaprootScriptSpendSignMethod denotes that a SegWit v1 (p2tr) input
// should be spent using the script path and that a specific leaf script
// should be signed for.
TaprootScriptSpendSignMethod SignMethod = 3
)
// String returns a human-readable representation of the signing method.
func (s SignMethod) String() string {
switch s {
case WitnessV0SignMethod:
return "witness_v0"
case TaprootKeySpendBIP0086SignMethod:
return "taproot_key_spend_bip86"
case TaprootKeySpendSignMethod:
return "taproot_key_spend"
case TaprootScriptSpendSignMethod:
return "taproot_script_spend"
default:
return fmt.Sprintf("unknown<%d>", s)
}
}
// PkScriptCompatible returns true if the given public key script is compatible
// with the sign method.
func (s SignMethod) PkScriptCompatible(pkScript []byte) bool {
switch s {
// SegWit v0 can be p2wkh, np2wkh, p2wsh.
case WitnessV0SignMethod:
return txscript.IsPayToWitnessPubKeyHash(pkScript) ||
txscript.IsPayToWitnessScriptHash(pkScript) ||
txscript.IsPayToScriptHash(pkScript)
case TaprootKeySpendBIP0086SignMethod, TaprootKeySpendSignMethod,
TaprootScriptSpendSignMethod:
return txscript.IsPayToTaproot(pkScript)
default:
return false
}
}
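// Hedged sketch (hypothetical helper, not in the original file):
// PkScriptCompatible only inspects the script's form, so a p2tr script
// built from any 32-byte x-only key is compatible with all taproot sign
// methods. The fresh key below merely stands in for a real output key.
func exampleSignMethodCompat() (bool, error) {
keyPriv, err := btcec.NewPrivateKey()
if err != nil {
return false, err
}
pkScript, err := txscript.PayToTaprootScript(keyPriv.PubKey())
if err != nil {
return false, err
}
return TaprootKeySpendBIP0086SignMethod.PkScriptCompatible(pkScript), nil
}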
// WriteSignDescriptor serializes a SignDescriptor struct into the passed
// io.Writer stream.
//
// NOTE: We assume the SigHashes and InputIndex fields haven't been assigned
// yet, since that is usually done just before broadcast by the witness
// generator.
func WriteSignDescriptor(w io.Writer, sd *SignDescriptor) error {
err := binary.Write(w, binary.BigEndian, sd.KeyDesc.Family)
if err != nil {
return err
}
err = binary.Write(w, binary.BigEndian, sd.KeyDesc.Index)
if err != nil {
return err
}
err = binary.Write(w, binary.BigEndian, sd.KeyDesc.PubKey != nil)
if err != nil {
return err
}
if sd.KeyDesc.PubKey != nil {
serializedPubKey := sd.KeyDesc.PubKey.SerializeCompressed()
if err := wire.WriteVarBytes(w, 0, serializedPubKey); err != nil {
return err
}
}
if err := wire.WriteVarBytes(w, 0, sd.SingleTweak); err != nil {
return err
}
var doubleTweakBytes []byte
if sd.DoubleTweak != nil {
doubleTweakBytes = sd.DoubleTweak.Serialize()
}
if err := wire.WriteVarBytes(w, 0, doubleTweakBytes); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, sd.WitnessScript); err != nil {
return err
}
if err := writeTxOut(w, sd.Output); err != nil {
return err
}
var scratch [4]byte
binary.BigEndian.PutUint32(scratch[:], uint32(sd.HashType))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
return nil
}
// ReadSignDescriptor deserializes a SignDescriptor struct from the passed
// io.Reader stream.
func ReadSignDescriptor(r io.Reader, sd *SignDescriptor) error {
err := binary.Read(r, binary.BigEndian, &sd.KeyDesc.Family)
if err != nil {
return err
}
err = binary.Read(r, binary.BigEndian, &sd.KeyDesc.Index)
if err != nil {
return err
}
var hasKey bool
err = binary.Read(r, binary.BigEndian, &hasKey)
if err != nil {
return err
}
if hasKey {
pubKeyBytes, err := wire.ReadVarBytes(r, 0, 34, "pubkey")
if err != nil {
return err
}
sd.KeyDesc.PubKey, err = btcec.ParsePubKey(pubKeyBytes)
if err != nil {
return err
}
}
singleTweak, err := wire.ReadVarBytes(r, 0, 32, "singleTweak")
if err != nil {
return err
}
// Serializing a SignDescriptor with a nil-valued SingleTweak results
// in deserializing a zero-length slice. Since a nil-valued SingleTweak
// has special meaning and a zero-length slice for a SingleTweak is
// invalid, we can use the zero-length slice as the flag for a
// nil-valued SingleTweak.
if len(singleTweak) == 0 {
sd.SingleTweak = nil
} else {
sd.SingleTweak = singleTweak
}
doubleTweakBytes, err := wire.ReadVarBytes(r, 0, 32, "doubleTweak")
if err != nil {
return err
}
// Serializing a SignDescriptor with a nil-valued DoubleTweak results
// in deserializing a zero-length slice. Since a nil-valued DoubleTweak
// has special meaning and a zero-length slice for a DoubleTweak is
// invalid, we can use the zero-length slice as the flag for a
// nil-valued DoubleTweak.
if len(doubleTweakBytes) == 0 {
sd.DoubleTweak = nil
} else {
sd.DoubleTweak, _ = btcec.PrivKeyFromBytes(doubleTweakBytes)
}
// Only one tweak should ever be set, fail if both are present.
if sd.SingleTweak != nil && sd.DoubleTweak != nil {
return ErrTweakOverdose
}
witnessScript, err := wire.ReadVarBytes(r, 0, 500, "witnessScript")
if err != nil {
return err
}
sd.WitnessScript = witnessScript
txOut := &wire.TxOut{}
if err := readTxOut(r, txOut); err != nil {
return err
}
sd.Output = txOut
var hashType [4]byte
if _, err := io.ReadFull(r, hashType[:]); err != nil {
return err
}
sd.HashType = txscript.SigHashType(binary.BigEndian.Uint32(hashType[:]))
return nil
}
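// Hedged round-trip sketch (assumed names; a bytes import is assumed for
// this example only): a SignDescriptor can be serialized to and restored
// from any in-memory buffer.
//
// var buf bytes.Buffer
// if err := WriteSignDescriptor(&buf, sd); err != nil { ... }
// var restored SignDescriptor
// if err := ReadSignDescriptor(&buf, &restored); err != nil { ... }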
package input
import (
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/lightningnetwork/lnd/lntypes"
)
const (
// witnessScaleFactor determines the level of "discount" witness data
// receives compared to "base" data. A scale factor of 4 denotes that
// witness data costs 1/4 as much as regular non-witness data. Value copied
// here for convenience.
witnessScaleFactor = blockchain.WitnessScaleFactor
// The weight, which is different from the size (see BIP-141), is
// calculated as:
// Weight = 4 * BaseSize + WitnessSize
// BaseSize - size of the transaction without witness data (bytes).
// WitnessSize - witness size (bytes).
// Weight - the metric for determining the weight of the transaction.
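// As a worked illustration of the formula (the numbers follow from the
// constants defined below): a one-input, one-output p2wpkh spend has
// BaseSize = BaseTxSize + 1 + InputSize + 1 + P2WKHOutputSize
// = 8 + 1 + 41 + 1 + 31 = 82 bytes,
// WitnessSize = WitnessHeaderSize + P2WKHWitnessSize = 2 + 109 = 111 bytes,
// so Weight = 4*82 + 111 = 439 weight units.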
// P2WPKHSize 22 bytes
// - OP_0: 1 byte
// - OP_DATA: 1 byte (PublicKeyHASH160 length)
// - PublicKeyHASH160: 20 bytes
P2WPKHSize = 1 + 1 + 20
// NestedP2WPKHSize 23 bytes
// - OP_DATA: 1 byte (P2WPKHSize)
// - P2WPKHWitnessProgram: 22 bytes
NestedP2WPKHSize = 1 + P2WPKHSize
// P2WSHSize 34 bytes
// - OP_0: 1 byte
// - OP_DATA: 1 byte (WitnessScriptSHA256 length)
// - WitnessScriptSHA256: 32 bytes
P2WSHSize = 1 + 1 + 32
// NestedP2WSHSize 35 bytes
// - OP_DATA: 1 byte (P2WSHSize)
// - P2WSHWitnessProgram: 34 bytes
NestedP2WSHSize = 1 + P2WSHSize
// UnknownWitnessSize 42 bytes
// - OP_x: 1 byte
// - OP_DATA: 1 byte (max-size length)
// - max-size: 40 bytes
UnknownWitnessSize = 1 + 1 + 40
// BaseOutputSize 9 bytes
// - value: 8 bytes
// - var_int: 1 byte (pkscript_length)
BaseOutputSize = 8 + 1
// P2PKHSize 25 bytes.
P2PKHSize = 25
// P2PKHOutputSize 34 bytes
// - value: 8 bytes
// - var_int: 1 byte (pkscript_length)
// - pkscript (p2pkh): 25 bytes
P2PKHOutputSize = BaseOutputSize + P2PKHSize
// P2WKHOutputSize 31 bytes
// - value: 8 bytes
// - var_int: 1 byte (pkscript_length)
// - pkscript (p2wpkh): 22 bytes
P2WKHOutputSize = BaseOutputSize + P2WPKHSize
// P2WSHOutputSize 43 bytes
// - value: 8 bytes
// - var_int: 1 byte (pkscript_length)
// - pkscript (p2wsh): 34 bytes
P2WSHOutputSize = BaseOutputSize + P2WSHSize
// P2SHSize 23 bytes.
P2SHSize = 23
// P2SHOutputSize 32 bytes
// - value: 8 bytes
// - var_int: 1 byte (pkscript_length)
// - pkscript (p2sh): 23 bytes
P2SHOutputSize = BaseOutputSize + P2SHSize
// P2TRSize 34 bytes
// - OP_1: 1 byte
// - OP_DATA: 1 byte (x-only public key length)
// - x-only public key: 32 bytes
P2TRSize = 34
// P2TROutputSize 43 bytes
// - value: 8 bytes
// - var_int: 1 byte (pkscript_length)
// - pkscript (p2tr): 34 bytes
P2TROutputSize = BaseOutputSize + P2TRSize
// P2PKHScriptSigSize 108 bytes
// - OP_DATA: 1 byte (signature length)
// - signature
// - OP_DATA: 1 byte (pubkey length)
// - pubkey
P2PKHScriptSigSize = 1 + 73 + 1 + 33
// P2WKHWitnessSize 109 bytes
// - number_of_witness_elements: 1 byte
// - signature_length: 1 byte
// - signature
// - pubkey_length: 1 byte
// - pubkey
P2WKHWitnessSize = 1 + 1 + 73 + 1 + 33
// MultiSigSize 71 bytes
// - OP_2: 1 byte
// - OP_DATA: 1 byte (pubKeyAlice length)
// - pubKeyAlice: 33 bytes
// - OP_DATA: 1 byte (pubKeyBob length)
// - pubKeyBob: 33 bytes
// - OP_2: 1 byte
// - OP_CHECKMULTISIG: 1 byte
MultiSigSize = 1 + 1 + 33 + 1 + 33 + 1 + 1
// MultiSigWitnessSize 222 bytes
// - NumberOfWitnessElements: 1 byte
// - NilLength: 1 byte
// - sigAliceLength: 1 byte
// - sigAlice: 73 bytes
// - sigBobLength: 1 byte
// - sigBob: 73 bytes
// - WitnessScriptLength: 1 byte
// - WitnessScript (MultiSig)
MultiSigWitnessSize = 1 + 1 + 1 + 73 + 1 + 73 + 1 + MultiSigSize
// InputSize 41 bytes
// - PreviousOutPoint:
// - Hash: 32 bytes
// - Index: 4 bytes
// - OP_DATA: 1 byte (ScriptSigLength)
// - ScriptSig: 0 bytes
// - Witness <---- we use "Witness" instead of "ScriptSig" for
// transaction validation, but "Witness" is stored
// separately and the weight cost of its size is
// smaller, so we separate the calculation of
// ordinary data from witness data.
// - Sequence: 4 bytes
InputSize = 32 + 4 + 1 + 4
// FundingInputSize 41 bytes
// FundingInputSize represents the size of an input to a funding
// transaction, and is equivalent to the size of a standard segwit input
// as calculated above.
FundingInputSize = InputSize
// CommitmentDelayOutput 43 bytes
// - Value: 8 bytes
// - VarInt: 1 byte (PkScript length)
// - PkScript (P2WSH)
CommitmentDelayOutput = 8 + 1 + P2WSHSize
// TaprootCommitmentOutput 43 bytes
// - Value: 8 bytes
// - VarInt: 1 byte (PkScript length)
// - PkScript (P2TR)
TaprootCommitmentOutput = 8 + 1 + P2TRSize
// CommitmentKeyHashOutput 31 bytes
// - Value: 8 bytes
// - VarInt: 1 byte (PkScript length)
// - PkScript (P2WPKH)
CommitmentKeyHashOutput = 8 + 1 + P2WPKHSize
// CommitmentAnchorOutput 43 bytes
// - Value: 8 bytes
// - VarInt: 1 byte (PkScript length)
// - PkScript (P2WSH)
CommitmentAnchorOutput = 8 + 1 + P2WSHSize
// TaprootCommitmentAnchorOutput 43 bytes
// - Value: 8 bytes
// - VarInt: 1 byte (PkScript length)
// - PkScript (P2TR)
TaprootCommitmentAnchorOutput = 8 + 1 + P2TRSize
// HTLCSize 43 bytes
// - Value: 8 bytes
// - VarInt: 1 byte (PkScript length)
// - PkScript (P2WSH)
HTLCSize = 8 + 1 + P2WSHSize
// WitnessHeaderSize 2 bytes
// - Flag: 1 byte
// - Marker: 1 byte
WitnessHeaderSize = 1 + 1
// BaseTxSize 8 bytes
// - Version: 4 bytes
// - LockTime: 4 bytes
BaseTxSize = 4 + 4
// BaseCommitmentTxSize 125 + 43 * num-htlc-outputs bytes
// - Version: 4 bytes
// - WitnessHeader <---- part of the witness data
// - CountTxIn: 1 byte
// - TxIn: 41 bytes
// FundingInput
// - CountTxOut: 1 byte
// - TxOut: 74 + 43 * num-htlc-outputs bytes
// OutputPayingToThem,
// OutputPayingToUs,
// ....HTLCOutputs...
// - LockTime: 4 bytes
BaseCommitmentTxSize = 4 + 1 + FundingInputSize + 1 +
CommitmentDelayOutput + CommitmentKeyHashOutput + 4
// BaseCommitmentTxWeight 500 weight
BaseCommitmentTxWeight = witnessScaleFactor * BaseCommitmentTxSize
// WitnessCommitmentTxWeight 224 weight
WitnessCommitmentTxWeight = WitnessHeaderSize + MultiSigWitnessSize
// BaseAnchorCommitmentTxSize 225 + 43 * num-htlc-outputs bytes
// - Version: 4 bytes
// - WitnessHeader <---- part of the witness data
// - CountTxIn: 1 byte
// - TxIn: 41 bytes
// FundingInput
// - CountTxOut: 3 byte
// - TxOut: 4*43 + 43 * num-htlc-outputs bytes
// OutputPayingToThem,
// OutputPayingToUs,
// AnchorPayingToThem,
// AnchorPayingToUs,
// ....HTLCOutputs...
// - LockTime: 4 bytes
BaseAnchorCommitmentTxSize = 4 + 1 + FundingInputSize + 3 +
2*CommitmentDelayOutput + 2*CommitmentAnchorOutput + 4
// BaseAnchorCommitmentTxWeight 900 weight.
BaseAnchorCommitmentTxWeight = witnessScaleFactor * BaseAnchorCommitmentTxSize
// BaseTaprootCommitmentTxWeight (225 + 43 * num-htlc-outputs) * 4 weight
// - Version: 4 bytes
// - WitnessHeader <---- part of the witness data
// - CountTxIn: 1 byte
// - TxIn: 41 bytes
// FundingInput
// - CountTxOut: 3 byte
// - TxOut: 172 + 43 * num-htlc-outputs bytes
// OutputPayingToThem,
// OutputPayingToUs,
// ....HTLCOutputs...
// - LockTime: 4 bytes
BaseTaprootCommitmentTxWeight = (4 + 1 + FundingInputSize + 3 +
2*TaprootCommitmentOutput + 2*TaprootCommitmentAnchorOutput +
4) * witnessScaleFactor
// CommitWeight 724 weight.
CommitWeight = BaseCommitmentTxWeight + WitnessCommitmentTxWeight
// AnchorCommitWeight 1124 weight.
AnchorCommitWeight = BaseAnchorCommitmentTxWeight + WitnessCommitmentTxWeight
// TaprootCommitWeight 968 weight.
TaprootCommitWeight = (BaseTaprootCommitmentTxWeight +
WitnessHeaderSize + TaprootKeyPathWitnessSize)
// HTLCWeight 172 weight.
HTLCWeight = witnessScaleFactor * HTLCSize
// HtlcTimeoutWeight 663 weight
// HtlcTimeoutWeight is the weight of the HTLC timeout transaction
// which will transition an outgoing HTLC to the delay-and-claim state.
HtlcTimeoutWeight = 663
// TaprootHtlcTimeoutWeight is the total weight of the taproot HTLC
// timeout transaction.
TaprootHtlcTimeoutWeight = 645
// HtlcSuccessWeight 703 weight
// HtlcSuccessWeight is the weight of the HTLC success transaction
// which will transition an incoming HTLC to the delay-and-claim state.
HtlcSuccessWeight = 703
// TaprootHtlcSuccessWeight is the total weight of the taproot HTLC
// success transaction.
TaprootHtlcSuccessWeight = 705
// HtlcConfirmedScriptOverhead 3 bytes
// HtlcConfirmedScriptOverhead is the extra length of an HTLC script
// that requires confirmation before it can be spent. These extra bytes
// are a result of the extra CSV check.
HtlcConfirmedScriptOverhead = 3
// HtlcTimeoutWeightConfirmed 666 weight
// HtlcTimeoutWeightConfirmed is the weight of the HTLC timeout
// transaction which will transition an outgoing HTLC to the
// delay-and-claim state, for the confirmed HTLC outputs. It is 3 bytes
// larger because of the additional CSV check in the input script.
HtlcTimeoutWeightConfirmed = HtlcTimeoutWeight + HtlcConfirmedScriptOverhead
// HtlcSuccessWeightConfirmed 706 weight
// HtlcSuccessWeightConfirmed is the weight of the HTLC success
// transaction which will transition an incoming HTLC to the
// delay-and-claim state, for the confirmed HTLC outputs. It is 3 bytes
// larger because of the additional CSV check in the input script.
HtlcSuccessWeightConfirmed = HtlcSuccessWeight + HtlcConfirmedScriptOverhead
// MaxHTLCNumber 966
// MaxHTLCNumber is the maximum number HTLCs which can be included in a
// commitment transaction. This limit was chosen such that, in the case
// of a contract breach, the punishment transaction is able to sweep
// all the HTLCs yet still remain below the widely used standard
// weight limits.
MaxHTLCNumber = 966
// ToLocalScriptSize 79 bytes
// - OP_IF: 1 byte
// - OP_DATA: 1 byte
// - revoke_key: 33 bytes
// - OP_ELSE: 1 byte
// - OP_DATA: 1 byte
// - csv_delay: 4 bytes
// - OP_CHECKSEQUENCEVERIFY: 1 byte
// - OP_DROP: 1 byte
// - OP_DATA: 1 byte
// - delay_key: 33 bytes
// - OP_ENDIF: 1 byte
// - OP_CHECKSIG: 1 byte
ToLocalScriptSize = 1 + 1 + 33 + 1 + 1 + 4 + 1 + 1 + 1 + 33 + 1 + 1
// LeaseWitnessScriptSizeOverhead represents the size overhead in bytes
// of the witness scripts used within script enforced lease commitments.
// This overhead results from the additional CLTV clause required to
// spend.
//
// - OP_DATA: 1 byte
// - lease_expiry: 4 bytes
// - OP_CHECKLOCKTIMEVERIFY: 1 byte
// - OP_DROP: 1 byte
LeaseWitnessScriptSizeOverhead = 1 + 4 + 1 + 1
// ToLocalTimeoutWitnessSize 156 bytes
// - number_of_witness_elements: 1 byte
// - local_delay_sig_length: 1 byte
// - local_delay_sig: 73 bytes
// - zero_length: 1 byte
// - witness_script_length: 1 byte
// - witness_script (to_local_script)
ToLocalTimeoutWitnessSize = 1 + 1 + 73 + 1 + 1 + ToLocalScriptSize
// ToLocalPenaltyWitnessSize 157 bytes
// - number_of_witness_elements: 1 byte
// - revocation_sig_length: 1 byte
// - revocation_sig: 73 bytes
// - OP_TRUE_length: 1 byte
// - OP_TRUE: 1 byte
// - witness_script_length: 1 byte
// - witness_script (to_local_script)
ToLocalPenaltyWitnessSize = 1 + 1 + 73 + 1 + 1 + 1 + ToLocalScriptSize
// ToRemoteConfirmedScriptSize 37 bytes
// - OP_DATA: 1 byte
// - to_remote_key: 33 bytes
// - OP_CHECKSIGVERIFY: 1 byte
// - OP_1: 1 byte
// - OP_CHECKSEQUENCEVERIFY: 1 byte
ToRemoteConfirmedScriptSize = 1 + 33 + 1 + 1 + 1
// ToRemoteConfirmedWitnessSize 113 bytes
// - number_of_witness_elements: 1 byte
// - sig_length: 1 byte
// - sig: 73 bytes
// - witness_script_length: 1 byte
// - witness_script (to_remote_delayed_script)
ToRemoteConfirmedWitnessSize = 1 + 1 + 73 + 1 + ToRemoteConfirmedScriptSize
// AcceptedHtlcScriptSize 140 bytes
// - OP_DUP: 1 byte
// - OP_HASH160: 1 byte
// - OP_DATA: 1 byte (RIPEMD160(SHA256(revocationkey)) length)
// - RIPEMD160(SHA256(revocationkey)): 20 bytes
// - OP_EQUAL: 1 byte
// - OP_IF: 1 byte
// - OP_CHECKSIG: 1 byte
// - OP_ELSE: 1 byte
// - OP_DATA: 1 byte (remotekey length)
// - remotekey: 33 bytes
// - OP_SWAP: 1 byte
// - OP_SIZE: 1 byte
// - OP_DATA: 1 byte (32 length)
// - 32: 1 byte
// - OP_EQUAL: 1 byte
// - OP_IF: 1 byte
// - OP_HASH160: 1 byte
// - OP_DATA: 1 byte (RIPEMD160(payment_hash) length)
// - RIPEMD160(payment_hash): 20 bytes
// - OP_EQUALVERIFY: 1 byte
// - 2: 1 byte
// - OP_SWAP: 1 byte
// - OP_DATA: 1 byte (localkey length)
// - localkey: 33 bytes
// - 2: 1 byte
// - OP_CHECKMULTISIG: 1 byte
// - OP_ELSE: 1 byte
// - OP_DROP: 1 byte
// - OP_DATA: 1 byte (cltv_expiry length)
// - cltv_expiry: 4 bytes
// - OP_CHECKLOCKTIMEVERIFY: 1 byte
// - OP_DROP: 1 byte
// - OP_CHECKSIG: 1 byte
// - OP_ENDIF: 1 byte
// - OP_1: 1 byte // These 3 extra bytes are only
// - OP_CSV: 1 byte // present for the confirmed
// - OP_DROP: 1 byte // HTLC script types.
// - OP_ENDIF: 1 byte
AcceptedHtlcScriptSize = 3*1 + 20 + 5*1 + 33 + 8*1 + 20 + 4*1 +
33 + 5*1 + 4 + 5*1
// AcceptedHtlcScriptSizeConfirmed 143 bytes.
AcceptedHtlcScriptSizeConfirmed = AcceptedHtlcScriptSize +
HtlcConfirmedScriptOverhead
// AcceptedHtlcTimeoutWitnessSize 217 bytes
// - number_of_witness_elements: 1 byte
// - sender_sig_length: 1 byte
// - sender_sig: 73 bytes
// - nil_length: 1 byte
// - witness_script_length: 1 byte
// - witness_script: (accepted_htlc_script)
AcceptedHtlcTimeoutWitnessSize = 1 + 1 + 73 + 1 + 1 + AcceptedHtlcScriptSize
// AcceptedHtlcTimeoutWitnessSizeConfirmed 220 bytes.
AcceptedHtlcTimeoutWitnessSizeConfirmed = 1 + 1 + 73 + 1 + 1 +
AcceptedHtlcScriptSizeConfirmed
// AcceptedHtlcPenaltyWitnessSize 250 bytes
// - number_of_witness_elements: 1 byte
// - revocation_sig_length: 1 byte
// - revocation_sig: 73 bytes
// - revocation_key_length: 1 byte
// - revocation_key: 33 bytes
// - witness_script_length: 1 byte
// - witness_script (accepted_htlc_script)
AcceptedHtlcPenaltyWitnessSize = 1 + 1 + 73 + 1 + 33 + 1 + AcceptedHtlcScriptSize
// AcceptedHtlcPenaltyWitnessSizeConfirmed 253 bytes.
AcceptedHtlcPenaltyWitnessSizeConfirmed = 1 + 1 + 73 + 1 + 33 + 1 +
AcceptedHtlcScriptSizeConfirmed
// AcceptedHtlcSuccessWitnessSize 324 bytes
// - number_of_witness_elements: 1 byte
// - nil_length: 1 byte
// - sig_alice_length: 1 byte
// - sig_alice: 73 bytes
// - sig_bob_length: 1 byte
// - sig_bob: 73 bytes
// - preimage_length: 1 byte
// - preimage: 32 bytes
// - witness_script_length: 1 byte
// - witness_script (accepted_htlc_script)
//
// Input to second level success tx, spending non-delayed HTLC output.
AcceptedHtlcSuccessWitnessSize = 1 + 1 + 1 + 73 + 1 + 73 + 1 + 32 + 1 +
AcceptedHtlcScriptSize
// AcceptedHtlcSuccessWitnessSizeConfirmed 327 bytes
//
// Input to second level success tx, spending 1 CSV delayed HTLC output.
AcceptedHtlcSuccessWitnessSizeConfirmed = 1 + 1 + 1 + 73 + 1 + 73 + 1 + 32 + 1 +
AcceptedHtlcScriptSizeConfirmed
// OfferedHtlcScriptSize 133 bytes
// - OP_DUP: 1 byte
// - OP_HASH160: 1 byte
// - OP_DATA: 1 byte (RIPEMD160(SHA256(revocationkey)) length)
// - RIPEMD160(SHA256(revocationkey)): 20 bytes
// - OP_EQUAL: 1 byte
// - OP_IF: 1 byte
// - OP_CHECKSIG: 1 byte
// - OP_ELSE: 1 byte
// - OP_DATA: 1 byte (remotekey length)
// - remotekey: 33 bytes
// - OP_SWAP: 1 byte
// - OP_SIZE: 1 byte
// - OP_DATA: 1 byte (32 length)
// - 32: 1 byte
// - OP_EQUAL: 1 byte
// - OP_NOTIF: 1 byte
// - OP_DROP: 1 byte
// - 2: 1 byte
// - OP_SWAP: 1 byte
// - OP_DATA: 1 byte (localkey length)
// - localkey: 33 bytes
// - 2: 1 byte
// - OP_CHECKMULTISIG: 1 byte
// - OP_ELSE: 1 byte
// - OP_HASH160: 1 byte
// - OP_DATA: 1 byte (RIPEMD160(payment_hash) length)
// - RIPEMD160(payment_hash): 20 bytes
// - OP_EQUALVERIFY: 1 byte
// - OP_CHECKSIG: 1 byte
// - OP_ENDIF: 1 byte
// - OP_1: 1 byte // These 3 extra bytes are only
// - OP_CSV: 1 byte // present for the confirmed
// - OP_DROP: 1 byte // HTLC script types.
// - OP_ENDIF: 1 byte
OfferedHtlcScriptSize = 3*1 + 20 + 5*1 + 33 + 10*1 + 33 + 5*1 + 20 + 4*1
// OfferedHtlcScriptSizeConfirmed 136 bytes.
OfferedHtlcScriptSizeConfirmed = OfferedHtlcScriptSize +
HtlcConfirmedScriptOverhead
// OfferedHtlcSuccessWitnessSize 242 bytes
// - number_of_witness_elements: 1 byte
// - receiver_sig_length: 1 byte
// - receiver_sig: 73 bytes
// - payment_preimage_length: 1 byte
// - payment_preimage: 32 bytes
// - witness_script_length: 1 byte
// - witness_script (offered_htlc_script)
OfferedHtlcSuccessWitnessSize = 1 + 1 + 73 + 1 + 32 + 1 + OfferedHtlcScriptSize
// OfferedHtlcSuccessWitnessSizeConfirmed 245 bytes.
OfferedHtlcSuccessWitnessSizeConfirmed = 1 + 1 + 73 + 1 + 32 + 1 +
OfferedHtlcScriptSizeConfirmed
// OfferedHtlcTimeoutWitnessSize 285 bytes
// - number_of_witness_elements: 1 byte
// - nil_length: 1 byte
// - sig_alice_length: 1 byte
// - sig_alice: 73 bytes
// - sig_bob_length: 1 byte
// - sig_bob: 73 bytes
// - nil_length: 1 byte
// - witness_script_length: 1 byte
// - witness_script (offered_htlc_script)
//
// Input to second level timeout tx, spending non-delayed HTLC output.
OfferedHtlcTimeoutWitnessSize = 1 + 1 + 1 + 73 + 1 + 73 + 1 + 1 +
OfferedHtlcScriptSize
// OfferedHtlcTimeoutWitnessSizeConfirmed 288 bytes
//
// Input to second level timeout tx, spending 1 CSV delayed HTLC output.
OfferedHtlcTimeoutWitnessSizeConfirmed = 1 + 1 + 1 + 73 + 1 + 73 + 1 + 1 +
OfferedHtlcScriptSizeConfirmed
// OfferedHtlcPenaltyWitnessSize 243 bytes
// - number_of_witness_elements: 1 byte
// - revocation_sig_length: 1 byte
// - revocation_sig: 73 bytes
// - revocation_key_length: 1 byte
// - revocation_key: 33 bytes
// - witness_script_length: 1 byte
// - witness_script (offered_htlc_script)
OfferedHtlcPenaltyWitnessSize = 1 + 1 + 73 + 1 + 33 + 1 + OfferedHtlcScriptSize
// OfferedHtlcPenaltyWitnessSizeConfirmed 246 bytes.
OfferedHtlcPenaltyWitnessSizeConfirmed = 1 + 1 + 73 + 1 + 33 + 1 +
OfferedHtlcScriptSizeConfirmed
// AnchorScriptSize 40 bytes
// - pubkey_length: 1 byte
// - pubkey: 33 bytes
// - OP_CHECKSIG: 1 byte
// - OP_IFDUP: 1 byte
// - OP_NOTIF: 1 byte
// - OP_16: 1 byte
// - OP_CSV 1 byte
// - OP_ENDIF: 1 byte
AnchorScriptSize = 1 + 33 + 6*1
// AnchorWitnessSize 116 bytes
// - number_of_witness_elements: 1 byte
// - signature_length: 1 byte
// - signature: 73 bytes
// - witness_script_length: 1 byte
// - witness_script (anchor_script)
AnchorWitnessSize = 1 + 1 + 73 + 1 + AnchorScriptSize
// TaprootSignatureWitnessSize 65 bytes
// - sigLength: 1 byte
// - sig: 64 bytes
TaprootSignatureWitnessSize = 1 + 64
// TaprootKeyPathWitnessSize 66 bytes
// - NumberOfWitnessElements: 1 byte
// - sigLength: 1 byte
// - sig: 64 bytes
TaprootKeyPathWitnessSize = 1 + TaprootSignatureWitnessSize
// TaprootKeyPathCustomSighashWitnessSize 67 bytes
// - NumberOfWitnessElements: 1 byte
// - sigLength: 1 byte
// - sig: 64 bytes
// - sighashFlag: 1 byte
TaprootKeyPathCustomSighashWitnessSize = TaprootKeyPathWitnessSize + 1
// TaprootBaseControlBlockWitnessSize 33 bytes
// - leafVersionAndParity: 1 byte
// - schnorrPubKey: 32 byte
TaprootBaseControlBlockWitnessSize = 33
// TaprootToLocalScriptSize 41 bytes
// - OP_DATA: 1 byte (pub key len)
// - local_key: 32 bytes
// - OP_CHECKSIG: 1 byte
// - OP_DATA: 1 byte (csv delay)
// - csv_delay: 4 bytes (worst case estimate)
// - OP_CSV: 1 byte
// - OP_DROP: 1 byte
TaprootToLocalScriptSize = 41
// TaprootToLocalWitnessSize: 175 bytes
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
// - script_len: 1 byte
// - taproot_to_local_script_size: 41 bytes
// - ctrl_block_len: 1 byte
// - base_control_block_size: 33 bytes
// - sibling_merkle_hash: 32 bytes
TaprootToLocalWitnessSize = 1 + 1 + 65 + 1 + TaprootToLocalScriptSize +
1 + TaprootBaseControlBlockWitnessSize + 32
// TaprootToLocalRevokeScriptSize: 68 bytes
// - OP_DATA: 1 byte
// - local key: 32 bytes
// - OP_DROP: 1 byte
// - OP_DATA: 1 byte
// - revocation key: 32 bytes
// - OP_CHECKSIG: 1 byte
TaprootToLocalRevokeScriptSize = 1 + 32 + 1 + 1 + 32 + 1
// TaprootToLocalRevokeWitnessSize: 202 bytes
// - NumberOfWitnessElements: 1 byte
// - sigLength: 1 byte
// - sweep sig: 65 bytes
// - script len: 1 byte
// - revocation script size: 68 bytes
// - ctrl block size: 1 byte
// - base control block: 33 bytes
// - merkle proof: 32 bytes
TaprootToLocalRevokeWitnessSize = (1 + 1 + 65 + 1 +
TaprootToLocalRevokeScriptSize + 1 + 33 + 32)
// TaprootToRemoteScriptSize 37 bytes
// - OP_DATA: 1 byte
// - remote key: 32 bytes
// - OP_CHECKSIG: 1 byte
// - OP_1: 1 byte
// - OP_CHECKSEQUENCEVERIFY: 1 byte
// - OP_DROP: 1 byte
TaprootToRemoteScriptSize = 1 + 32 + 1 + 1 + 1 + 1
// TaprootToRemoteWitnessSize: 139 bytes
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
// - script_len: 1 byte
// - taproot_to_remote_script_size: 37 bytes
// - ctrl_block_len: 1 byte
// - base_control_block_size: 33 bytes
TaprootToRemoteWitnessSize = (1 + 1 + 65 + 1 +
TaprootToRemoteScriptSize + 1 +
TaprootBaseControlBlockWitnessSize)
// TaprootAnchorWitnessSize: 67 bytes
//
// In this case, we use the custom sighash size to give the most
// pessimistic estimate.
TaprootAnchorWitnessSize = TaprootKeyPathCustomSighashWitnessSize
// TaprootSecondLevelHtlcScriptSize: 41 bytes
// - OP_DATA: 1 byte (pub key len)
// - local_key: 32 bytes
// - OP_CHECKSIG: 1 byte
// - OP_DATA: 1 byte (csv delay)
// - csv_delay: 4 bytes (worst case)
// - OP_CSV: 1 byte
// - OP_DROP: 1 byte
TaprootSecondLevelHtlcScriptSize = 1 + 32 + 1 + 1 + 4 + 1 + 1
// TaprootSecondLevelHtlcWitnessSize: 143 bytes
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
// - script_len: 1 byte
// - taproot_second_level_htlc_script_size: 41 bytes
// - ctrl_block_len: 1 byte
// - base_control_block_size: 33 bytes
TaprootSecondLevelHtlcWitnessSize = 1 + 1 + 65 + 1 +
TaprootSecondLevelHtlcScriptSize + 1 +
TaprootBaseControlBlockWitnessSize
// TaprootSecondLevelRevokeWitnessSize
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
//nolint:ll
TaprootSecondLevelRevokeWitnessSize = TaprootKeyPathCustomSighashWitnessSize
// TaprootAcceptedRevokeWitnessSize:
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
//nolint:ll
TaprootAcceptedRevokeWitnessSize = TaprootKeyPathCustomSighashWitnessSize
// TaprootOfferedRevokeWitnessSize:
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
TaprootOfferedRevokeWitnessSize = TaprootKeyPathCustomSighashWitnessSize
// TaprootHtlcOfferedRemoteTimeoutScriptSize: 44 bytes
// - OP_DATA: 1 byte (pub key len)
// - local_key: 32 bytes
// - OP_CHECKSIG: 1 byte
// - OP_1: 1 byte
// - OP_DROP: 1 byte
// - OP_CHECKSEQUENCEVERIFY: 1 byte
// - OP_DATA: 1 byte (cltv_expiry length)
// - cltv_expiry: 4 bytes
// - OP_CHECKLOCKTIMEVERIFY: 1 byte
// - OP_DROP: 1 byte
TaprootHtlcOfferedRemoteTimeoutScriptSize = (1 + 32 + 1 + 1 + 1 + 1 +
1 + 4 + 1 + 1)
// TaprootHtlcOfferedRemoteTimeoutWitnessSize: 178 bytes
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
// - script_len: 1 byte
// - taproot_offered_htlc_script_size: 44 bytes
// - ctrl_block_len: 1 byte
// - base_control_block_size: 33 bytes
// - sibling_merkle_proof: 32 bytes
TaprootHtlcOfferedRemoteTimeoutWitnessSize = 1 + 1 + 65 + 1 +
TaprootHtlcOfferedRemoteTimeoutScriptSize + 1 +
TaprootBaseControlBlockWitnessSize + 32
// TaprootHtlcOfferedLocalTimeoutScriptSize: 68 bytes
// - OP_DATA: 1 byte (pub key len)
// - local_key: 32 bytes
// - OP_CHECKSIGVERIFY: 1 byte
// - OP_DATA: 1 byte (pub key len)
// - remote_key: 32 bytes
// - OP_CHECKSIG: 1 byte
TaprootHtlcOfferedLocalTimeoutScriptSize = 1 + 32 + 1 + 1 + 32 + 1
// TaprootOfferedLocalTimeoutWitnessSize: 268 bytes
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
// - script_len: 1 byte
// - taproot_offered_htlc_script_timeout_size: 68 bytes
// - ctrl_block_len: 1 byte
// - base_control_block_size: 33 bytes
// - sibling_merkle_proof: 32 bytes
TaprootOfferedLocalTimeoutWitnessSize = 1 + 1 + 65 + 1 + 65 + 1 +
TaprootHtlcOfferedLocalTimeoutScriptSize + 1 +
TaprootBaseControlBlockWitnessSize + 32
// TaprootHtlcAcceptedRemoteSuccessScriptSize: 64 bytes
// - OP_SIZE: 1 byte
// - OP_DATA: 1 byte
// - 32: 1 byte
// - OP_EQUALVERIFY: 1 byte
// - OP_HASH160: 1 byte
// - OP_DATA: 1 byte (RIPEMD160(payment_hash) length)
// - RIPEMD160(payment_hash): 20 bytes
// - OP_EQUALVERIFY: 1 byte
// - OP_DATA: 1 byte (pub key len)
// - remote_key: 32 bytes
// - OP_CHECKSIG: 1 byte
// - OP_1: 1 byte
// - OP_CSV: 1 byte
// - OP_DROP: 1 byte
TaprootHtlcAcceptedRemoteSuccessScriptSize = 1 + 1 + 1 + 1 + 1 + 1 +
1 + 20 + 1 + 32 + 1 + 1 + 1 + 1
// TaprootHtlcAcceptedRemoteSuccessWitnessSize: 231 bytes
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
// - payment_preimage_length: 1 byte
// - payment_preimage: 32 bytes
// - script_len: 1 byte
// - taproot_accepted_htlc_script_success_size: 64 bytes
// - ctrl_block_len: 1 byte
// - base_control_block_size: 33 bytes
// - sibling_merkle_proof: 32 bytes
TaprootHtlcAcceptedRemoteSuccessWitnessSize = 1 + 1 + 65 + 1 + 32 + 1 +
TaprootHtlcAcceptedRemoteSuccessScriptSize + 1 +
TaprootBaseControlBlockWitnessSize + 32
// TaprootHtlcAcceptedLocalSuccessScriptSize: 95 bytes
// - OP_SIZE: 1 byte
// - OP_DATA: 1 byte
// - 32: 1 byte
// - OP_EQUALVERIFY: 1 byte
// - OP_HASH160: 1 byte
// - OP_DATA: 1 byte (RIPEMD160(payment_hash) length)
// - RIPEMD160(payment_hash): 20 bytes
// - OP_EQUALVERIFY: 1 byte
// - OP_DATA: 1 byte (pub key len)
// - local_key: 32 bytes
// - OP_CHECKSIGVERIFY: 1 byte
// - OP_DATA: 1 byte (pub key len)
// - remote_key: 32 bytes
// - OP_CHECKSIG: 1 byte
TaprootHtlcAcceptedLocalSuccessScriptSize = 1 + 1 + 1 + 1 + 1 + 1 +
20 + 1 + 1 + 32 + 1 + 1 + 32 + 1
// TaprootHtlcAcceptedLocalSuccessWitnessSize: 328 bytes
// - number_of_witness_elements: 1 byte
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
// - sig_len: 1 byte
// - sweep_sig: 65 bytes (worst case w/o sighash default)
// - payment_preimage_length: 1 byte
// - payment_preimage: 32 bytes
// - script_len: 1 byte
// - taproot_accepted_htlc_script_success_size: 95 bytes
// - ctrl_block_len: 1 byte
// - base_control_block_size: 33 bytes
// - sibling_merkle_proof: 32 bytes
TaprootHtlcAcceptedLocalSuccessWitnessSize = 1 + 1 + 65 + 1 + 65 + 1 +
32 + 1 + TaprootHtlcAcceptedLocalSuccessScriptSize + 1 +
TaprootBaseControlBlockWitnessSize + 32
)
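// checkStaticSizes is an illustrative sanity check (a sketch, not part of
// the package API) asserting a few of the byte totals stated in the
// comments above. Each constant is a plain sum of its listed script
// components.
func checkStaticSizes() {
// 41-byte to_local leaf, 65-byte sweep sig, 33-byte control block and
// a 32-byte sibling merkle hash, plus the length prefixes.
if TaprootToLocalWitnessSize != 175 {
panic("unexpected taproot to_local witness size")
}
// 68-byte revocation leaf inside a 202-byte witness.
if TaprootToLocalRevokeWitnessSize != 202 {
panic("unexpected taproot revocation witness size")
}
}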
// EstimateCommitTxWeight estimates the weight of a commitment transaction
// from the precalculated weight of the base transaction, the witness data
// needed to spend the funding output, and the HTLC weight multiplied by
// the HTLC count.
func EstimateCommitTxWeight(count int, prediction bool) int64 {
// Make prediction about the size of commitment transaction with
// additional HTLC.
if prediction {
count++
}
htlcWeight := int64(count * HTLCWeight)
baseWeight := int64(BaseCommitmentTxWeight)
witnessWeight := int64(WitnessCommitmentTxWeight)
// TODO(roasbeef): need taproot modifier? also no anchor so wrong?
return htlcWeight + baseWeight + witnessWeight
}
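// fitsWeightBudget is a usage sketch (a hypothetical helper, illustrative
// only): it reports whether the commitment transaction would stay under
// the given weight budget if one more HTLC were added. Passing true for
// prediction makes EstimateCommitTxWeight count the extra HTLC itself.
func fitsWeightBudget(numHTLCs int, budget int64) bool {
return EstimateCommitTxWeight(numHTLCs, true) <= budget
}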
// TxWeightEstimator is able to calculate weight estimates for transactions
// based on the input and output types. For purposes of estimation, all
// signatures are assumed to be of the maximum possible size, 73 bytes. Each
// method of the estimator returns an instance with the estimate applied,
// which allows callers to chain the methods into a single expression.
type TxWeightEstimator struct {
hasWitness bool
inputCount uint32
outputCount uint32
inputSize lntypes.VByte
inputWitnessSize lntypes.WeightUnit
outputSize lntypes.VByte
}
// AddP2PKHInput updates the weight estimate to account for an additional input
// spending a P2PKH output.
func (twe *TxWeightEstimator) AddP2PKHInput() *TxWeightEstimator {
twe.inputSize += InputSize + P2PKHScriptSigSize
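// Even a non-witness input occupies one byte in the witness data of a
// segwit transaction: the 0x00 element count of its empty witness
// stack, which is what this increment accounts for.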
twe.inputWitnessSize++
twe.inputCount++
return twe
}
// AddP2WKHInput updates the weight estimate to account for an additional input
// spending a native P2WKH output.
func (twe *TxWeightEstimator) AddP2WKHInput() *TxWeightEstimator {
twe.AddWitnessInput(P2WKHWitnessSize)
return twe
}
// AddWitnessInput updates the weight estimate to account for an additional
// input spending a native pay-to-witness output. This accepts the total size
// of the witness as a parameter.
func (twe *TxWeightEstimator) AddWitnessInput(
witnessSize lntypes.WeightUnit) *TxWeightEstimator {
twe.inputSize += InputSize
twe.inputWitnessSize += witnessSize
twe.inputCount++
twe.hasWitness = true
return twe
}
// AddTapscriptInput updates the weight estimate to account for an additional
// input spending a segwit v1 pay-to-taproot output using the script path. This
// accepts the total size of the witness for the script leaf that is executed
// and adds the size of the control block to the total witness size.
//
// NOTE: The leaf witness size must be calculated without the byte that accounts
// for the number of witness elements, only the total size of all elements on
// the stack that are consumed by the revealed script should be counted.
func (twe *TxWeightEstimator) AddTapscriptInput(
leafWitnessSize lntypes.WeightUnit,
tapscript *waddrmgr.Tapscript) *TxWeightEstimator {
// We add 1 byte for the total number of witness elements.
controlBlockWitnessSize := 1 + TaprootBaseControlBlockWitnessSize +
// 1 byte for the length of the element plus the element itself.
1 + len(tapscript.RevealedScript) +
1 + len(tapscript.ControlBlock.InclusionProof)
twe.inputSize += InputSize
twe.inputWitnessSize += leafWitnessSize + lntypes.WeightUnit(
controlBlockWitnessSize,
)
twe.inputCount++
twe.hasWitness = true
return twe
}
// AddTaprootKeySpendInput updates the weight estimate to account for an
// additional input spending a segwit v1 pay-to-taproot output using the key
// spend path. This accepts the sighash type being used since that has an
// influence on the total size of the signature.
func (twe *TxWeightEstimator) AddTaprootKeySpendInput(
hashType txscript.SigHashType) *TxWeightEstimator {
twe.inputSize += InputSize
if hashType == txscript.SigHashDefault {
twe.inputWitnessSize += TaprootKeyPathWitnessSize
} else {
twe.inputWitnessSize += TaprootKeyPathCustomSighashWitnessSize
}
twe.inputCount++
twe.hasWitness = true
return twe
}
// AddNestedP2WKHInput updates the weight estimate to account for an additional
// input spending a P2SH output with a nested P2WKH redeem script.
func (twe *TxWeightEstimator) AddNestedP2WKHInput() *TxWeightEstimator {
twe.inputSize += InputSize + NestedP2WPKHSize
twe.inputWitnessSize += P2WKHWitnessSize
twe.inputCount++
twe.hasWitness = true
return twe
}
// AddNestedP2WSHInput updates the weight estimate to account for an additional
// input spending a P2SH output with a nested P2WSH redeem script.
func (twe *TxWeightEstimator) AddNestedP2WSHInput(
witnessSize lntypes.WeightUnit) *TxWeightEstimator {
twe.inputSize += InputSize + NestedP2WSHSize
twe.inputWitnessSize += witnessSize
twe.inputCount++
twe.hasWitness = true
return twe
}
// AddTxOutput adds a known TxOut to the weight estimator.
func (twe *TxWeightEstimator) AddTxOutput(txOut *wire.TxOut) *TxWeightEstimator {
twe.outputSize += lntypes.VByte(txOut.SerializeSize())
twe.outputCount++
return twe
}
// AddP2PKHOutput updates the weight estimate to account for an additional P2PKH
// output.
func (twe *TxWeightEstimator) AddP2PKHOutput() *TxWeightEstimator {
twe.outputSize += P2PKHOutputSize
twe.outputCount++
return twe
}
// AddP2WKHOutput updates the weight estimate to account for an additional
// native P2WKH output.
func (twe *TxWeightEstimator) AddP2WKHOutput() *TxWeightEstimator {
twe.outputSize += P2WKHOutputSize
twe.outputCount++
return twe
}
// AddP2WSHOutput updates the weight estimate to account for an additional
// native P2WSH output.
func (twe *TxWeightEstimator) AddP2WSHOutput() *TxWeightEstimator {
twe.outputSize += P2WSHOutputSize
twe.outputCount++
return twe
}
// AddP2TROutput updates the weight estimate to account for an additional native
// SegWit v1 P2TR output.
func (twe *TxWeightEstimator) AddP2TROutput() *TxWeightEstimator {
twe.outputSize += P2TROutputSize
twe.outputCount++
return twe
}
// AddP2SHOutput updates the weight estimate to account for an additional P2SH
// output.
func (twe *TxWeightEstimator) AddP2SHOutput() *TxWeightEstimator {
twe.outputSize += P2SHOutputSize
twe.outputCount++
return twe
}
// AddOutput estimates the weight of an output based on the pkScript.
func (twe *TxWeightEstimator) AddOutput(pkScript []byte) *TxWeightEstimator {
twe.outputSize += BaseOutputSize + lntypes.VByte(len(pkScript))
twe.outputCount++
return twe
}
// Weight gets the estimated weight of the transaction.
func (twe *TxWeightEstimator) Weight() lntypes.WeightUnit {
inputCount := wire.VarIntSerializeSize(uint64(twe.inputCount))
outputCount := wire.VarIntSerializeSize(uint64(twe.outputCount))
txSizeStripped := BaseTxSize + lntypes.VByte(inputCount) +
twe.inputSize + lntypes.VByte(outputCount) + twe.outputSize
weight := lntypes.WeightUnit(txSizeStripped * witnessScaleFactor)
if twe.hasWitness {
weight += WitnessHeaderSize + twe.inputWitnessSize
}
return weight
}
// VSize gets the estimated virtual size of the transaction, in vbytes.
func (twe *TxWeightEstimator) VSize() int {
// A tx's vsize is 1/4 of the weight, rounded up.
return int(twe.Weight().ToVB())
}
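// sweepWeightAndVSize is a usage sketch (a hypothetical helper) showing
// the chaining style described above: a single P2WKH input swept into one
// P2TR output. The weight is the stripped size times four plus the
// witness bytes, and the vsize is that weight divided by four, rounded
// up.
func sweepWeightAndVSize() (lntypes.WeightUnit, int) {
var estimator TxWeightEstimator
estimator.AddP2WKHInput().AddP2TROutput()
return estimator.Weight(), estimator.VSize()
}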
package input
import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/lightningnetwork/lnd/fn/v2"
)
const (
// PubKeyFormatCompressedOdd is the identifier prefix byte for a public
// key whose Y coordinate is odd when serialized in the compressed
// format per section 2.3.4 of
// [SEC1](https://secg.org/sec1-v2.pdf#subsubsection.2.3.4).
// This is copied from the github.com/decred/dcrd/dcrec/secp256k1/v4
// package to avoid referencing it directly (and accidentally pulling in
// incompatible crypto primitives).
PubKeyFormatCompressedOdd byte = 0x03
)
// AuxTapLeaf is a type alias for an optional tapscript leaf that may be added
// to the tapscript tree of HTLC and commitment outputs.
type AuxTapLeaf = fn.Option[txscript.TapLeaf]
// NoneTapLeaf returns an empty optional tapscript leaf.
func NoneTapLeaf() AuxTapLeaf {
return fn.None[txscript.TapLeaf]()
}
// HtlcIndex represents the monotonically increasing counter that is used to
// identify HTLCs created by a peer.
type HtlcIndex = uint64
// HtlcAuxLeaf is a type that represents an auxiliary leaf for an HTLC output.
// An HTLC may have up to two aux leaves: one for the output on the commitment
// transaction, and one for the second level HTLC.
type HtlcAuxLeaf struct {
AuxTapLeaf
// SecondLevelLeaf is the auxiliary leaf for the second level HTLC
// success or timeout transaction.
SecondLevelLeaf AuxTapLeaf
}
// HtlcAuxLeaves is a type alias for a map of optional tapscript leaves.
type HtlcAuxLeaves = map[HtlcIndex]HtlcAuxLeaf
// NewTxSigHashesV0Only returns a new txscript.TxSigHashes instance that will
// only calculate the sighash midstate values for segwit v0 inputs and can
// therefore never be used for transactions that want to spend segwit v1
// (taproot) inputs.
func NewTxSigHashesV0Only(tx *wire.MsgTx) *txscript.TxSigHashes {
// The canned output fetcher returns a wire.TxOut instance with the
// given pk script and amount. We can get away with nil since the first
// thing the TxSigHashes constructor checks is the length of the pk
// script and whether it matches taproot output script length. If the
// length doesn't match it assumes v0 inputs only.
nilFetcher := txscript.NewCannedPrevOutputFetcher(nil, 0)
return txscript.NewTxSigHashes(tx, nilFetcher)
}
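// attachV0SigHashes is a usage sketch (a hypothetical helper): when a
// transaction is known to spend only segwit v0 inputs, the cheap v0-only
// midstate cache can be shared across all of its sign descriptors.
func attachV0SigHashes(tx *wire.MsgTx, descs []*SignDescriptor) {
sigHashes := NewTxSigHashesV0Only(tx)
for _, desc := range descs {
// Every descriptor reuses the same midstate cache.
desc.SigHashes = sigHashes
}
}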
// MultiPrevOutFetcher returns a txscript.MultiPrevOutFetcher for the given set
// of inputs.
func MultiPrevOutFetcher(inputs []Input) (*txscript.MultiPrevOutFetcher, error) {
fetcher := txscript.NewMultiPrevOutFetcher(nil)
for _, inp := range inputs {
op := inp.OutPoint()
desc := inp.SignDesc()
if op == EmptyOutPoint {
return nil, fmt.Errorf("missing input outpoint")
}
if desc == nil || desc.Output == nil {
return nil, fmt.Errorf("missing input utxo information")
}
fetcher.AddPrevOut(op, desc.Output)
}
return fetcher, nil
}
// TapscriptFullTree creates a waddrmgr.Tapscript for the given internal key and
// tree leaves.
func TapscriptFullTree(internalKey *btcec.PublicKey,
allTreeLeaves ...txscript.TapLeaf) *waddrmgr.Tapscript {
tree := txscript.AssembleTaprootScriptTree(allTreeLeaves...)
rootHash := tree.RootNode.TapHash()
tapKey := txscript.ComputeTaprootOutputKey(internalKey, rootHash[:])
var outputKeyYIsOdd bool
if tapKey.SerializeCompressed()[0] == PubKeyFormatCompressedOdd {
outputKeyYIsOdd = true
}
return &waddrmgr.Tapscript{
Type: waddrmgr.TapscriptTypeFullTree,
ControlBlock: &txscript.ControlBlock{
InternalKey: internalKey,
OutputKeyYIsOdd: outputKeyYIsOdd,
LeafVersion: txscript.BaseLeafVersion,
},
Leaves: allTreeLeaves,
}
}
// TapscriptPartialReveal creates a waddrmgr.Tapscript for the given internal
// key and revealed script.
func TapscriptPartialReveal(internalKey *btcec.PublicKey,
revealedLeaf txscript.TapLeaf,
inclusionProof []byte) *waddrmgr.Tapscript {
controlBlock := &txscript.ControlBlock{
InternalKey: internalKey,
LeafVersion: txscript.BaseLeafVersion,
InclusionProof: inclusionProof,
}
rootHash := controlBlock.RootHash(revealedLeaf.Script)
tapKey := txscript.ComputeTaprootOutputKey(internalKey, rootHash)
if tapKey.SerializeCompressed()[0] == PubKeyFormatCompressedOdd {
controlBlock.OutputKeyYIsOdd = true
}
return &waddrmgr.Tapscript{
Type: waddrmgr.TapscriptTypePartialReveal,
ControlBlock: controlBlock,
RevealedScript: revealedLeaf.Script,
}
}
// TapscriptRootHashOnly creates a waddrmgr.Tapscript for the given internal key
// and root hash.
func TapscriptRootHashOnly(internalKey *btcec.PublicKey,
rootHash []byte) *waddrmgr.Tapscript {
controlBlock := &txscript.ControlBlock{
InternalKey: internalKey,
}
tapKey := txscript.ComputeTaprootOutputKey(internalKey, rootHash)
if tapKey.SerializeCompressed()[0] == PubKeyFormatCompressedOdd {
controlBlock.OutputKeyYIsOdd = true
}
return &waddrmgr.Tapscript{
Type: waddrmgr.TaprootKeySpendRootHash,
ControlBlock: controlBlock,
RootHash: rootHash,
}
}
// TapscriptFullKeyOnly creates a waddrmgr.Tapscript for the given full Taproot
// key.
func TapscriptFullKeyOnly(taprootKey *btcec.PublicKey) *waddrmgr.Tapscript {
return &waddrmgr.Tapscript{
Type: waddrmgr.TaprootFullKeyOnly,
FullOutputKey: taprootKey,
}
}
// PayToTaprootScript creates a new script to pay to a version 1 (taproot)
// witness program. The passed public key will be serialized as an x-only key
// to create the witness program.
func PayToTaprootScript(taprootKey *btcec.PublicKey) ([]byte, error) {
builder := txscript.NewScriptBuilder()
builder.AddOp(txscript.OP_1)
builder.AddData(schnorr.SerializePubKey(taprootKey))
return builder.Script()
}
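// buildTwoLeafTaprootScript is a sketch (a hypothetical helper) tying the
// helpers above together: assemble a two-leaf tapscript tree for the
// given internal key, derive the taproot output key that commits to it,
// and produce the version 1 witness program paying to that key.
func buildTwoLeafTaprootScript(internalKey *btcec.PublicKey,
leafA, leafB txscript.TapLeaf) ([]byte, error) {
// Mirror what TapscriptFullTree does internally to obtain the
// committed output key.
tree := txscript.AssembleTaprootScriptTree(leafA, leafB)
rootHash := tree.RootNode.TapHash()
taprootKey := txscript.ComputeTaprootOutputKey(internalKey, rootHash[:])
return PayToTaprootScript(taprootKey)
}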
package input
import (
"bytes"
"encoding/hex"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/keychain"
)
var (
// For simplicity a single priv key controls all of our test outputs.
testWalletPrivKey = []byte{
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
}
// We're alice :)
bobsPrivKey = []byte{
0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda,
0x63, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17,
0xd, 0xe7, 0x95, 0xe4, 0xb7, 0x25, 0xb8, 0x4d,
0x1e, 0xb, 0x4c, 0xfd, 0x9e, 0xc5, 0x8c, 0xe9,
}
// Use a hard-coded HD seed.
testHdSeed = chainhash.Hash{
0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab,
0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4,
0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9,
0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53,
}
)
// MockSigner is a simple implementation of the Signer interface. Each one has
// a set of private keys in a slice and can sign messages using the appropriate
// one.
type MockSigner struct {
Privkeys []*btcec.PrivateKey
NetParams *chaincfg.Params
*MusigSessionManager
}
// NewMockSigner returns a new instance of the MockSigner given a set of
// backing private keys.
func NewMockSigner(privKeys []*btcec.PrivateKey,
netParams *chaincfg.Params) *MockSigner {
signer := &MockSigner{
Privkeys: privKeys,
NetParams: netParams,
}
keyFetcher := func(*keychain.KeyDescriptor) (*btcec.PrivateKey, error) {
return signer.Privkeys[0], nil
}
signer.MusigSessionManager = NewMusigSessionManager(keyFetcher)
return signer
}
// SignOutputRaw generates a signature for the passed transaction according to
// the data within the passed SignDescriptor.
func (m *MockSigner) SignOutputRaw(tx *wire.MsgTx,
signDesc *SignDescriptor) (Signature, error) {
pubkey := signDesc.KeyDesc.PubKey
switch {
case signDesc.SingleTweak != nil:
pubkey = TweakPubKeyWithTweak(pubkey, signDesc.SingleTweak)
case signDesc.DoubleTweak != nil:
pubkey = DeriveRevocationPubkey(pubkey, signDesc.DoubleTweak.PubKey())
}
hash160 := btcutil.Hash160(pubkey.SerializeCompressed())
privKey := m.findKey(hash160, signDesc.SingleTweak, signDesc.DoubleTweak)
if privKey == nil {
return nil, fmt.Errorf("mock signer does not have key")
}
// In case of a taproot output any signature is always a Schnorr
// signature, based on the new tapscript sighash algorithm.
if txscript.IsPayToTaproot(signDesc.Output.PkScript) {
sigHashes := txscript.NewTxSigHashes(
tx, signDesc.PrevOutputFetcher,
)
// Are we spending a script path or the key path? The API is
// slightly different, so we need to account for that to get
// the raw signature.
var (
rawSig []byte
err error
)
switch signDesc.SignMethod {
case TaprootKeySpendBIP0086SignMethod,
TaprootKeySpendSignMethod:
// This function tweaks the private key using the tap
// tweak supplied in the sign descriptor.
rawSig, err = txscript.RawTxInTaprootSignature(
tx, sigHashes, signDesc.InputIndex,
signDesc.Output.Value, signDesc.Output.PkScript,
signDesc.TapTweak, signDesc.HashType,
privKey,
)
if err != nil {
return nil, err
}
case TaprootScriptSpendSignMethod:
leaf := txscript.TapLeaf{
LeafVersion: txscript.BaseLeafVersion,
Script: signDesc.WitnessScript,
}
rawSig, err = txscript.RawTxInTapscriptSignature(
tx, sigHashes, signDesc.InputIndex,
signDesc.Output.Value, signDesc.Output.PkScript,
leaf, signDesc.HashType, privKey,
)
if err != nil {
return nil, err
}
}
// The signature returned above might have a sighash flag
// attached if a non-default type was used. We'll slice this
// off if it exists to ensure we can properly parse the raw
// signature.
sig, err := schnorr.ParseSignature(
rawSig[:schnorr.SignatureSize],
)
if err != nil {
return nil, err
}
return sig, nil
}
sig, err := txscript.RawTxInWitnessSignature(
tx, signDesc.SigHashes, signDesc.InputIndex,
signDesc.Output.Value, signDesc.WitnessScript,
signDesc.HashType, privKey,
)
if err != nil {
return nil, err
}
return ecdsa.ParseDERSignature(sig[:len(sig)-1])
}
// ComputeInputScript generates a complete InputScript for the passed
// transaction with the signature as defined within the passed
// SignDescriptor. This method should be capable of generating the proper
// input script for both regular p2wkh outputs and p2wkh outputs nested
// within a regular p2sh output.
func (m *MockSigner) ComputeInputScript(tx *wire.MsgTx, signDesc *SignDescriptor) (*Script, error) {
scriptType, addresses, _, err := txscript.ExtractPkScriptAddrs(
signDesc.Output.PkScript, m.NetParams)
if err != nil {
return nil, err
}
switch scriptType {
case txscript.PubKeyHashTy:
privKey := m.findKey(addresses[0].ScriptAddress(), signDesc.SingleTweak,
signDesc.DoubleTweak)
if privKey == nil {
return nil, fmt.Errorf("mock signer does not have key for "+
"address %v", addresses[0])
}
sigScript, err := txscript.SignatureScript(
tx, signDesc.InputIndex, signDesc.Output.PkScript,
txscript.SigHashAll, privKey, true,
)
if err != nil {
return nil, err
}
return &Script{SigScript: sigScript}, nil
case txscript.WitnessV0PubKeyHashTy:
privKey := m.findKey(addresses[0].ScriptAddress(), signDesc.SingleTweak,
signDesc.DoubleTweak)
if privKey == nil {
return nil, fmt.Errorf("mock signer does not have key for "+
"address %v", addresses[0])
}
witnessScript, err := txscript.WitnessSignature(tx, signDesc.SigHashes,
signDesc.InputIndex, signDesc.Output.Value,
signDesc.Output.PkScript, txscript.SigHashAll, privKey, true)
if err != nil {
return nil, err
}
return &Script{Witness: witnessScript}, nil
default:
return nil, fmt.Errorf("unexpected script type: %v", scriptType)
}
}
// findKey searches through all stored private keys and returns one
// corresponding to the hashed pubkey if it can be found. The public key may
// either correspond directly to the private key or to the private key with a
// tweak applied.
func (m *MockSigner) findKey(needleHash160 []byte, singleTweak []byte,
doubleTweak *btcec.PrivateKey) *btcec.PrivateKey {
for _, privkey := range m.Privkeys {
// First check whether public key is directly derived from
// private key.
hash160 := btcutil.Hash160(privkey.PubKey().SerializeCompressed())
if bytes.Equal(hash160, needleHash160) {
return privkey
}
// Otherwise check if public key is derived from tweaked
// private key.
switch {
case singleTweak != nil:
privkey = TweakPrivKey(privkey, singleTweak)
case doubleTweak != nil:
privkey = DeriveRevocationPrivKey(privkey, doubleTweak)
default:
continue
}
hash160 = btcutil.Hash160(privkey.PubKey().SerializeCompressed())
if bytes.Equal(hash160, needleHash160) {
return privkey
}
}
return nil
}
// pubkeyFromHex parses a Bitcoin public key from a hex encoded string.
func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) {
bytes, err := hex.DecodeString(keyHex)
if err != nil {
return nil, err
}
return btcec.ParsePubKey(bytes)
}
// privkeyFromHex parses a Bitcoin private key from a hex encoded string.
func privkeyFromHex(keyHex string) (*btcec.PrivateKey, error) {
bytes, err := hex.DecodeString(keyHex)
if err != nil {
return nil, err
}
key, _ := btcec.PrivKeyFromBytes(bytes)
return key, nil
}
// pubkeyToHex serializes a Bitcoin public key to a hex encoded string.
func pubkeyToHex(key *btcec.PublicKey) string {
return hex.EncodeToString(key.SerializeCompressed())
}
// privkeyToHex serializes a Bitcoin private key to a hex encoded string.
func privkeyToHex(key *btcec.PrivateKey) string {
return hex.EncodeToString(key.Serialize())
}
package input
import (
"encoding/binary"
"io"
"github.com/btcsuite/btcd/wire"
)
// writeTxOut serializes a wire.TxOut struct into the passed io.Writer stream.
func writeTxOut(w io.Writer, txo *wire.TxOut) error {
var scratch [8]byte
binary.BigEndian.PutUint64(scratch[:], uint64(txo.Value))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, txo.PkScript); err != nil {
return err
}
return nil
}
// readTxOut deserializes a wire.TxOut struct from the passed io.Reader stream.
func readTxOut(r io.Reader, txo *wire.TxOut) error {
var scratch [8]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return err
}
value := int64(binary.BigEndian.Uint64(scratch[:]))
pkScript, err := wire.ReadVarBytes(r, 0, 80, "pkScript")
if err != nil {
return err
}
*txo = wire.TxOut{
Value: value,
PkScript: pkScript,
}
return nil
}
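// roundTripTxOut is a sketch (a hypothetical helper, assuming the
// standard library bytes package is imported) demonstrating that the two
// helpers above are inverses for outputs whose pkScript fits the 80-byte
// read limit used by readTxOut.
func roundTripTxOut(txo *wire.TxOut) (*wire.TxOut, error) {
var buf bytes.Buffer
if err := writeTxOut(&buf, txo); err != nil {
return nil, err
}
var decoded wire.TxOut
if err := readTxOut(&buf, &decoded); err != nil {
return nil, err
}
return &decoded, nil
}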
package input
import (
"fmt"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lntypes"
)
// WitnessGenerator represents a function that is able to generate the final
// witness for a particular public key script. Additionally, if required, this
// function will also return the sigScript for spending nested P2SH witness
// outputs. This function acts as an abstraction layer, hiding the details of
// the underlying script.
type WitnessGenerator func(tx *wire.MsgTx, hc *txscript.TxSigHashes,
inputIndex int) (*Script, error)
// WitnessType determines how an output's witness will be generated. This
// interface can be implemented to be used for custom sweep scripts if the
// pre-defined StandardWitnessType list doesn't provide a suitable one.
type WitnessType interface {
// String returns a human readable version of the WitnessType.
String() string
// WitnessGenerator will return a WitnessGenerator function that an
// output uses to generate the witness and optionally the sigScript for
// a sweep transaction.
WitnessGenerator(signer Signer,
descriptor *SignDescriptor) WitnessGenerator
// SizeUpperBound returns the maximum length of the witness of this
// WitnessType if it would be included in a tx. It also returns if the
// output itself is a nested p2sh output, if so then we need to take
// into account the extra sigScript data size.
SizeUpperBound() (lntypes.WeightUnit, bool, error)
// AddWeightEstimation adds the estimated size of the witness in bytes
// to the given weight estimator.
AddWeightEstimation(e *TxWeightEstimator) error
}
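// estimateWitnessWeight is a usage sketch (a hypothetical helper): a
// sweeper can bound the total witness weight of a batch of inputs from
// their witness types alone by feeding each upper bound into a
// TxWeightEstimator via AddWeightEstimation.
func estimateWitnessWeight(types []WitnessType) (lntypes.WeightUnit, error) {
var estimator TxWeightEstimator
for _, wt := range types {
// AddWeightEstimation also accounts for the extra sigScript
// bytes when SizeUpperBound reports a nested p2sh output.
if err := wt.AddWeightEstimation(&estimator); err != nil {
return 0, err
}
}
return estimator.Weight(), nil
}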
// StandardWitnessType is a numeric representation of standard pre-defined types
// of witness configurations.
type StandardWitnessType uint16
// A compile time check to ensure StandardWitnessType implements the
// WitnessType interface.
var _ WitnessType = (StandardWitnessType)(0)
// NOTE: When adding a new `StandardWitnessType`, also update the `WitnessType`
// protobuf enum and the `allWitnessTypes` map in the `walletrpc` package.
const (
// CommitmentTimeLock is a witness that allows us to spend our output
// on our local commitment transaction after a relative lock-time
// lockout.
CommitmentTimeLock StandardWitnessType = 0
// CommitmentNoDelay is a witness that allows us to spend a settled
// no-delay output immediately on a counterparty's commitment
// transaction.
CommitmentNoDelay StandardWitnessType = 1
// CommitmentRevoke is a witness that allows us to sweep the settled
// output of a malicious counterparty who broadcasts a revoked
// commitment transaction.
CommitmentRevoke StandardWitnessType = 2
// HtlcOfferedRevoke is a witness that allows us to sweep an HTLC which
// we offered to the remote party in the case that they broadcast a
// revoked commitment state.
HtlcOfferedRevoke StandardWitnessType = 3
// HtlcAcceptedRevoke is a witness that allows us to sweep an HTLC
// output sent to us in the case that the remote party broadcasts a
// revoked commitment state.
HtlcAcceptedRevoke StandardWitnessType = 4
// HtlcOfferedTimeoutSecondLevel is a witness that allows us to sweep
// an HTLC output that we extended to a party, but was never fulfilled.
// This HTLC output isn't directly on the commitment transaction, but
// is the result of a confirmed second-level HTLC transaction. As a
// result, we can only spend this after a CSV delay.
HtlcOfferedTimeoutSecondLevel StandardWitnessType = 5
// HtlcOfferedTimeoutSecondLevelInputConfirmed is a witness that allows
// us to sweep an HTLC output that we extended to a party, but was
// never fulfilled. This _is_ the HTLC output directly on our
// commitment transaction, and the input to the second-level HTLC
// timeout transaction. It can only be spent after CLTV expiry, and
// commitment confirmation.
HtlcOfferedTimeoutSecondLevelInputConfirmed StandardWitnessType = 15
// HtlcAcceptedSuccessSecondLevel is a witness that allows us to sweep
// an HTLC output that was offered to us, and for which we have a
// payment preimage. This HTLC output isn't directly on our commitment
// transaction, but is the result of confirmed second-level HTLC
// transaction. As a result, we can only spend this after a CSV delay.
HtlcAcceptedSuccessSecondLevel StandardWitnessType = 6
// HtlcAcceptedSuccessSecondLevelInputConfirmed is a witness that
// allows us to sweep an HTLC output that was offered to us, and for
// which we have a payment preimage. This _is_ the HTLC output directly
// on our commitment transaction, and the input to the second-level
// HTLC success transaction. It can only be spent after the commitment
// has confirmed.
HtlcAcceptedSuccessSecondLevelInputConfirmed StandardWitnessType = 16
// HtlcOfferedRemoteTimeout is a witness that allows us to sweep an
// HTLC that we offered to the remote party which lies in the
// commitment transaction of the remote party. We can spend this output
// after the absolute CLTV timeout of the HTLC has passed.
HtlcOfferedRemoteTimeout StandardWitnessType = 7
// HtlcAcceptedRemoteSuccess is a witness that allows us to sweep an
// HTLC that was offered to us by the remote party. We use this witness
// in the case that the remote party goes to chain, and we know the
// pre-image to the HTLC. We can sweep this without any additional
// timeout.
HtlcAcceptedRemoteSuccess StandardWitnessType = 8
// HtlcSecondLevelRevoke is a witness that allows us to sweep an HTLC
// from the remote party's commitment transaction in the case that they
// broadcast a revoked commitment, but then also immediately attempt to
// go to the second level to claim the HTLC.
HtlcSecondLevelRevoke StandardWitnessType = 9
// WitnessKeyHash is a witness type that allows us to spend a regular
// p2wkh output that's sent to an output which is under complete
// control of the backing wallet.
WitnessKeyHash StandardWitnessType = 10
// NestedWitnessKeyHash is a witness type that allows us to sweep an
// output that sends to a nested P2SH script that pays to a key solely
// under our control. The witness generated needs to include the
// sigScript as well.
NestedWitnessKeyHash StandardWitnessType = 11
// CommitSpendNoDelayTweakless is similar to the CommitSpendNoDelay
// type, but omits the tweak that randomizes our spending key with a
// set of randomness supplied by the channel peer.
CommitSpendNoDelayTweakless StandardWitnessType = 12
// CommitmentToRemoteConfirmed is a witness that allows us to spend our
// output on the counterparty's commitment transaction after a
// confirmation.
CommitmentToRemoteConfirmed StandardWitnessType = 13
// CommitmentAnchor is a witness that allows us to spend our anchor on
// the commitment transaction.
CommitmentAnchor StandardWitnessType = 14
// LeaseCommitmentTimeLock is a witness that allows us to spend our
// output on our local commitment transaction after a relative and
// absolute lock-time lockout as part of the script enforced lease
// commitment type.
LeaseCommitmentTimeLock StandardWitnessType = 17
// LeaseCommitmentToRemoteConfirmed is a witness that allows us to spend
// our output on the counterparty's commitment transaction after a
// confirmation and absolute locktime as part of the script enforced
// lease commitment type.
LeaseCommitmentToRemoteConfirmed StandardWitnessType = 18
// LeaseHtlcOfferedTimeoutSecondLevel is a witness that allows us to
// sweep an HTLC output that we extended to a party, but was never
// fulfilled. This HTLC output isn't directly on the commitment
// transaction, but is the result of a confirmed second-level HTLC
// transaction. As a result, we can only spend this after a CSV delay
// and CLTV locktime as part of the script enforced lease commitment
// type.
LeaseHtlcOfferedTimeoutSecondLevel StandardWitnessType = 19
// LeaseHtlcAcceptedSuccessSecondLevel is a witness that allows us to
// sweep an HTLC output that was offered to us, and for which we have a
// payment preimage. This HTLC output isn't directly on our commitment
// transaction, but is the result of confirmed second-level HTLC
// transaction. As a result, we can only spend this after a CSV delay
// and CLTV locktime as part of the script enforced lease commitment
// type.
LeaseHtlcAcceptedSuccessSecondLevel StandardWitnessType = 20
// TaprootPubKeySpend is a witness type that allows us to spend a
// regular p2tr output that's sent to an output which is under complete
// control of the backing wallet.
TaprootPubKeySpend StandardWitnessType = 21
// TaprootLocalCommitSpend is a witness type that allows us to spend
// our settled local commitment after a CSV delay when we force close
// the channel.
TaprootLocalCommitSpend StandardWitnessType = 22
// TaprootRemoteCommitSpend is a witness type that allows us to spend
// our settled local commitment after a CSV delay when the remote party
// has force closed the channel.
TaprootRemoteCommitSpend StandardWitnessType = 23
// TaprootAnchorSweepSpend is the witness type we'll use for spending
// our own anchor output.
TaprootAnchorSweepSpend StandardWitnessType = 24
// TaprootHtlcOfferedTimeoutSecondLevel is a witness that allows us to
// timeout an HTLC we offered to the remote party on our commitment
// transaction. We use this when we need to go on chain to time out an
// HTLC.
TaprootHtlcOfferedTimeoutSecondLevel StandardWitnessType = 25
// TaprootHtlcAcceptedSuccessSecondLevel is a witness that allows us to
// sweep an HTLC we accepted on our commitment transaction after we go
// to the second level on chain.
TaprootHtlcAcceptedSuccessSecondLevel StandardWitnessType = 26
// TaprootHtlcSecondLevelRevoke is a witness that allows us to sweep an
// HTLC on the revoked transaction of the remote party that goes to the
// second level.
TaprootHtlcSecondLevelRevoke StandardWitnessType = 27
// TaprootHtlcAcceptedRevoke is a witness that allows us to sweep an
// HTLC sent to us by the remote party in the event that they broadcast
// a revoked state.
TaprootHtlcAcceptedRevoke StandardWitnessType = 28
// TaprootHtlcOfferedRevoke is a witness that allows us to sweep an
// HTLC we offered to the remote party if they broadcast a revoked
// commitment.
TaprootHtlcOfferedRevoke StandardWitnessType = 29
// TaprootHtlcOfferedRemoteTimeout is a witness that allows us to sweep
// an HTLC we offered to the remote party that lies on the commitment
// transaction for the remote party. We can spend this output after the
// absolute CLTV timeout of the HTLC as passed.
TaprootHtlcOfferedRemoteTimeout StandardWitnessType = 30
// TaprootHtlcLocalOfferedTimeout is a witness type that allows us to
// sign the second level HTLC timeout transaction when spending from an
// HTLC residing on our local commitment transaction.
//
// This is used by the sweeper to re-sign inputs if it needs to
// aggregate several second level HTLCs.
TaprootHtlcLocalOfferedTimeout StandardWitnessType = 31
// TaprootHtlcAcceptedRemoteSuccess is a witness that allows us to
// sweep an HTLC that was offered to us by the remote party for
// taproot channels. We use this witness in the case that the remote
// party goes to chain, and we know the pre-image to the HTLC. We can
// sweep this without any additional timeout.
TaprootHtlcAcceptedRemoteSuccess StandardWitnessType = 32
// TaprootHtlcAcceptedLocalSuccess is a witness type that allows us to
// sweep the HTLC offered to us on our local commitment transaction.
// We'll use this when we need to go on chain to sweep the HTLC. In
// this case, this is the second level HTLC success transaction.
TaprootHtlcAcceptedLocalSuccess StandardWitnessType = 33
// TaprootCommitmentRevoke is a witness that allows us to sweep the
// settled output of a malicious counterparty who broadcasts a
// revoked taproot commitment transaction.
TaprootCommitmentRevoke StandardWitnessType = 34
)
// String returns a human readable version of the target WitnessType.
//
// NOTE: This is part of the WitnessType interface.
func (wt StandardWitnessType) String() string {
switch wt {
case CommitmentTimeLock:
return "CommitmentTimeLock"
case CommitmentToRemoteConfirmed:
return "CommitmentToRemoteConfirmed"
case CommitmentAnchor:
return "CommitmentAnchor"
case CommitmentNoDelay:
return "CommitmentNoDelay"
case CommitSpendNoDelayTweakless:
return "CommitmentNoDelayTweakless"
case CommitmentRevoke:
return "CommitmentRevoke"
case HtlcOfferedRevoke:
return "HtlcOfferedRevoke"
case HtlcAcceptedRevoke:
return "HtlcAcceptedRevoke"
case HtlcOfferedTimeoutSecondLevel:
return "HtlcOfferedTimeoutSecondLevel"
case HtlcOfferedTimeoutSecondLevelInputConfirmed:
return "HtlcOfferedTimeoutSecondLevelInputConfirmed"
case HtlcAcceptedSuccessSecondLevel:
return "HtlcAcceptedSuccessSecondLevel"
case HtlcAcceptedSuccessSecondLevelInputConfirmed:
return "HtlcAcceptedSuccessSecondLevelInputConfirmed"
case HtlcOfferedRemoteTimeout:
return "HtlcOfferedRemoteTimeout"
case HtlcAcceptedRemoteSuccess:
return "HtlcAcceptedRemoteSuccess"
case HtlcSecondLevelRevoke:
return "HtlcSecondLevelRevoke"
case WitnessKeyHash:
return "WitnessKeyHash"
case NestedWitnessKeyHash:
return "NestedWitnessKeyHash"
case LeaseCommitmentTimeLock:
return "LeaseCommitmentTimeLock"
case LeaseCommitmentToRemoteConfirmed:
return "LeaseCommitmentToRemoteConfirmed"
case LeaseHtlcOfferedTimeoutSecondLevel:
return "LeaseHtlcOfferedTimeoutSecondLevel"
case LeaseHtlcAcceptedSuccessSecondLevel:
return "LeaseHtlcAcceptedSuccessSecondLevel"
case TaprootPubKeySpend:
return "TaprootPubKeySpend"
case TaprootLocalCommitSpend:
return "TaprootLocalCommitSpend"
case TaprootRemoteCommitSpend:
return "TaprootRemoteCommitSpend"
case TaprootAnchorSweepSpend:
return "TaprootAnchorSweepSpend"
case TaprootHtlcOfferedTimeoutSecondLevel:
return "TaprootHtlcOfferedTimeoutSecondLevel"
case TaprootHtlcAcceptedSuccessSecondLevel:
return "TaprootHtlcAcceptedSuccessSecondLevel"
case TaprootHtlcSecondLevelRevoke:
return "TaprootHtlcSecondLevelRevoke"
case TaprootHtlcAcceptedRevoke:
return "TaprootHtlcAcceptedRevoke"
case TaprootHtlcOfferedRevoke:
return "TaprootHtlcOfferedRevoke"
case TaprootHtlcOfferedRemoteTimeout:
return "TaprootHtlcOfferedRemoteTimeout"
case TaprootHtlcLocalOfferedTimeout:
return "TaprootHtlcLocalOfferedTimeout"
case TaprootHtlcAcceptedRemoteSuccess:
return "TaprootHtlcAcceptedRemoteSuccess"
case TaprootHtlcAcceptedLocalSuccess:
return "TaprootHtlcAcceptedLocalSuccess"
case TaprootCommitmentRevoke:
return "TaprootCommitmentRevoke"
default:
return fmt.Sprintf("Unknown WitnessType: %v", uint32(wt))
}
}
// WitnessGenerator will return a WitnessGenerator function that an output uses
// to generate the witness and optionally the sigScript for a sweep
// transaction. The sigScript will be generated if the witness type warrants
// one for spending, such as the NestedWitnessKeyHash witness type.
//
// NOTE: This is part of the WitnessType interface.
func (wt StandardWitnessType) WitnessGenerator(signer Signer,
descriptor *SignDescriptor) WitnessGenerator {
return func(tx *wire.MsgTx, hc *txscript.TxSigHashes,
inputIndex int) (*Script, error) {
// TODO(roasbeef): copy the desc?
desc := descriptor
desc.SigHashes = hc
desc.InputIndex = inputIndex
switch wt {
case CommitmentTimeLock, LeaseCommitmentTimeLock:
witness, err := CommitSpendTimeout(signer, desc, tx)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case CommitmentToRemoteConfirmed, LeaseCommitmentToRemoteConfirmed:
witness, err := CommitSpendToRemoteConfirmed(
signer, desc, tx,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case CommitmentAnchor:
witness, err := CommitSpendAnchor(signer, desc, tx)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case CommitmentNoDelay:
witness, err := CommitSpendNoDelay(signer, desc, tx, false)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case CommitSpendNoDelayTweakless:
witness, err := CommitSpendNoDelay(signer, desc, tx, true)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case CommitmentRevoke:
witness, err := CommitSpendRevoke(signer, desc, tx)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case HtlcOfferedRevoke:
witness, err := ReceiverHtlcSpendRevoke(signer, desc, tx)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case HtlcAcceptedRevoke:
witness, err := SenderHtlcSpendRevoke(signer, desc, tx)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case HtlcOfferedTimeoutSecondLevel,
LeaseHtlcOfferedTimeoutSecondLevel,
HtlcAcceptedSuccessSecondLevel,
LeaseHtlcAcceptedSuccessSecondLevel:
witness, err := HtlcSecondLevelSpend(signer, desc, tx)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case HtlcOfferedRemoteTimeout:
// We pass in a value of -1 for the timeout, as we
// expect the caller to have already set the lock time
// value.
witness, err := ReceiverHtlcSpendTimeout(signer, desc, tx, -1)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case HtlcSecondLevelRevoke:
witness, err := HtlcSpendRevoke(signer, desc, tx)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case WitnessKeyHash:
fallthrough
case TaprootPubKeySpend:
fallthrough
case NestedWitnessKeyHash:
return signer.ComputeInputScript(tx, desc)
case TaprootLocalCommitSpend:
// Ensure that the sign desc has the proper sign method
// set, and a valid prev output fetcher.
desc.SignMethod = TaprootScriptSpendSignMethod
// The control block bytes must be set at this point.
if desc.ControlBlock == nil {
return nil, fmt.Errorf("control block must " +
"be set for taproot spend")
}
witness, err := TaprootCommitSpendSuccess(
signer, desc, tx, nil,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case TaprootRemoteCommitSpend:
// Ensure that the sign desc has the proper sign method
// set, and a valid prev output fetcher.
desc.SignMethod = TaprootScriptSpendSignMethod
// The control block bytes must be set at this point.
if desc.ControlBlock == nil {
return nil, fmt.Errorf("control block must " +
"be set for taproot spend")
}
witness, err := TaprootCommitRemoteSpend(
signer, desc, tx, nil,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case TaprootAnchorSweepSpend:
// Ensure that the sign desc has the proper sign method
// set, and a valid prev output fetcher.
desc.SignMethod = TaprootKeySpendSignMethod
// The tap tweak must be set at this point.
if desc.TapTweak == nil {
return nil, fmt.Errorf("tap tweak must be " +
"set for keyspend")
}
witness, err := TaprootAnchorSpend(
signer, desc, tx,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case TaprootHtlcOfferedTimeoutSecondLevel,
TaprootHtlcAcceptedSuccessSecondLevel:
// Ensure that the sign desc has the proper sign method
// set, and a valid prev output fetcher.
desc.SignMethod = TaprootScriptSpendSignMethod
// The control block bytes must be set at this point.
if desc.ControlBlock == nil {
return nil, fmt.Errorf("control block must " +
"be set for taproot spend")
}
witness, err := TaprootHtlcSpendSuccess(
signer, desc, tx, nil, nil,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case TaprootHtlcSecondLevelRevoke:
// Ensure that the sign desc has the proper sign method
// set, and a valid prev output fetcher.
desc.SignMethod = TaprootKeySpendSignMethod
// The tap tweak must be set at this point.
if desc.TapTweak == nil {
return nil, fmt.Errorf("tap tweak must be " +
"set for keyspend")
}
witness, err := TaprootHtlcSpendRevoke(
signer, desc, tx,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case TaprootHtlcOfferedRevoke:
// Ensure that the sign desc has the proper sign method
// set, and a valid prev output fetcher.
desc.SignMethod = TaprootKeySpendSignMethod
// The tap tweak must be set at this point.
if desc.TapTweak == nil {
return nil, fmt.Errorf("tap tweak must be " +
"set for keyspend")
}
witness, err := SenderHTLCScriptTaprootRevoke(
signer, desc, tx,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case TaprootHtlcAcceptedRevoke:
// Ensure that the sign desc has the proper sign method
// set, and a valid prev output fetcher.
desc.SignMethod = TaprootKeySpendSignMethod
// The tap tweak must be set at this point.
if desc.TapTweak == nil {
return nil, fmt.Errorf("tap tweak must be " +
"set for keyspend")
}
witness, err := ReceiverHTLCScriptTaprootRevoke(
signer, desc, tx,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case TaprootHtlcOfferedRemoteTimeout:
// Ensure that the sign desc has the proper sign method
// set, and a valid prev output fetcher.
desc.SignMethod = TaprootScriptSpendSignMethod
// The control block bytes must be set at this point.
if desc.ControlBlock == nil {
return nil, fmt.Errorf("control block " +
"must be set for taproot spend")
}
witness, err := ReceiverHTLCScriptTaprootTimeout(
signer, desc, tx, -1, nil, nil,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
case TaprootCommitmentRevoke:
// Ensure that the sign desc has the proper sign method
// set, and a valid prev output fetcher.
desc.SignMethod = TaprootScriptSpendSignMethod
// The control block bytes must be set at this point.
if desc.ControlBlock == nil {
return nil, fmt.Errorf("control block " +
"must be set for taproot spend")
}
witness, err := TaprootCommitSpendRevoke(
signer, desc, tx, nil,
)
if err != nil {
return nil, err
}
return &Script{
Witness: witness,
}, nil
default:
return nil, fmt.Errorf("unknown witness type: %v", wt)
}
}
}
// SizeUpperBound returns the maximum length of the witness of this witness
// type if it would be included in a tx. We also return whether the output
// itself is a nested p2sh output; if so, the extra sigScript data size must
// be taken into account.
//
// NOTE: This is part of the WitnessType interface.
func (wt StandardWitnessType) SizeUpperBound() (lntypes.WeightUnit,
bool, error) {
switch wt {
// Outputs on a remote commitment transaction that pay directly to us.
case CommitSpendNoDelayTweakless:
fallthrough
case WitnessKeyHash:
fallthrough
case CommitmentNoDelay:
return P2WKHWitnessSize, false, nil
// Outputs on a past commitment transaction that pay directly
// to us.
case CommitmentTimeLock:
return ToLocalTimeoutWitnessSize, false, nil
case LeaseCommitmentTimeLock:
size := ToLocalTimeoutWitnessSize +
LeaseWitnessScriptSizeOverhead
return lntypes.WeightUnit(size), false, nil
// 1 CSV time locked output to us on remote commitment.
case CommitmentToRemoteConfirmed:
return ToRemoteConfirmedWitnessSize, false, nil
case LeaseCommitmentToRemoteConfirmed:
size := ToRemoteConfirmedWitnessSize +
LeaseWitnessScriptSizeOverhead
return lntypes.WeightUnit(size), false, nil
// Anchor output on the commitment transaction.
case CommitmentAnchor:
return AnchorWitnessSize, false, nil
// Outgoing second layer HTLC's that have confirmed within the
// chain, and the output they produced is now mature enough to
// sweep.
case HtlcOfferedTimeoutSecondLevel:
return ToLocalTimeoutWitnessSize, false, nil
case LeaseHtlcOfferedTimeoutSecondLevel:
size := ToLocalTimeoutWitnessSize +
LeaseWitnessScriptSizeOverhead
return lntypes.WeightUnit(size), false, nil
// Input to the outgoing HTLC second layer timeout transaction.
case HtlcOfferedTimeoutSecondLevelInputConfirmed:
return OfferedHtlcTimeoutWitnessSizeConfirmed, false, nil
// Incoming second layer HTLC's that have confirmed within the
// chain, and the output they produced is now mature enough to
// sweep.
case HtlcAcceptedSuccessSecondLevel:
return ToLocalTimeoutWitnessSize, false, nil
case LeaseHtlcAcceptedSuccessSecondLevel:
size := ToLocalTimeoutWitnessSize +
LeaseWitnessScriptSizeOverhead
return lntypes.WeightUnit(size), false, nil
// Input to the incoming second-layer HTLC success transaction.
case HtlcAcceptedSuccessSecondLevelInputConfirmed:
return AcceptedHtlcSuccessWitnessSizeConfirmed, false, nil
// An HTLC on the commitment transaction of the remote party,
// that has had its absolute timelock expire.
case HtlcOfferedRemoteTimeout:
return AcceptedHtlcTimeoutWitnessSize, false, nil
// An HTLC on the commitment transaction of the remote party,
// that can be swept with the preimage.
case HtlcAcceptedRemoteSuccess:
return OfferedHtlcSuccessWitnessSize, false, nil
// A nested P2SH input that has a p2wkh witness script. We'll mark this
// as nested P2SH so the caller can estimate the weight properly
// including the sigScript.
case NestedWitnessKeyHash:
return P2WKHWitnessSize, true, nil
// The revocation output on a revoked commitment transaction.
case CommitmentRevoke:
return ToLocalPenaltyWitnessSize, false, nil
// The revocation output on a revoked HTLC that we offered to the remote
// party.
case HtlcOfferedRevoke:
return OfferedHtlcPenaltyWitnessSize, false, nil
// The revocation output on a revoked HTLC that was sent to us.
case HtlcAcceptedRevoke:
return AcceptedHtlcPenaltyWitnessSize, false, nil
// The revocation output of a second level output of an HTLC.
case HtlcSecondLevelRevoke:
return ToLocalPenaltyWitnessSize, false, nil
case TaprootPubKeySpend:
return TaprootKeyPathCustomSighashWitnessSize, false, nil
// Sweeping a self output after a delay for taproot channels.
case TaprootLocalCommitSpend:
return TaprootToLocalWitnessSize, false, nil
// Sweeping a self output after the remote party force closes. Must
// wait 1 CSV.
case TaprootRemoteCommitSpend:
return TaprootToRemoteWitnessSize, false, nil
// Sweeping our anchor output with a key spend witness.
case TaprootAnchorSweepSpend:
return TaprootAnchorWitnessSize, false, nil
case TaprootHtlcOfferedTimeoutSecondLevel,
TaprootHtlcAcceptedSuccessSecondLevel:
return TaprootSecondLevelHtlcWitnessSize, false, nil
case TaprootHtlcSecondLevelRevoke:
return TaprootSecondLevelRevokeWitnessSize, false, nil
case TaprootHtlcAcceptedRevoke:
return TaprootAcceptedRevokeWitnessSize, false, nil
case TaprootHtlcOfferedRevoke:
return TaprootOfferedRevokeWitnessSize, false, nil
case TaprootHtlcOfferedRemoteTimeout:
return TaprootHtlcOfferedRemoteTimeoutWitnessSize, false, nil
case TaprootHtlcLocalOfferedTimeout:
return TaprootOfferedLocalTimeoutWitnessSize, false, nil
case TaprootHtlcAcceptedRemoteSuccess:
return TaprootHtlcAcceptedRemoteSuccessWitnessSize, false, nil
case TaprootHtlcAcceptedLocalSuccess:
return TaprootHtlcAcceptedLocalSuccessWitnessSize, false, nil
case TaprootCommitmentRevoke:
return TaprootToLocalRevokeWitnessSize, false, nil
}
return 0, false, fmt.Errorf("unexpected witness type: %v", wt)
}
// AddWeightEstimation adds the estimated size of the witness in bytes to the
// given weight estimator.
//
// NOTE: This is part of the WitnessType interface.
func (wt StandardWitnessType) AddWeightEstimation(e *TxWeightEstimator) error {
// For fee estimation purposes, we'll now attempt to obtain an
// upper bound on the weight this input will add when fully
// populated.
size, isNestedP2SH, err := wt.SizeUpperBound()
if err != nil {
return err
}
// If this is a nested P2SH input, then we'll need to factor in
// the additional data push within the sigScript.
if isNestedP2SH {
e.AddNestedP2WSHInput(size)
} else {
e.AddWitnessInput(size)
}
return nil
}
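// The following is an illustrative sketch, not part of the original file:
// it shows how a caller might combine AddWeightEstimation calls to size
// the inputs of a sweep transaction. The function name is hypothetical,
// and the zero value of TxWeightEstimator is assumed to be ready for use,
// as it is elsewhere in this package.
func exampleEstimateSweepInputs() (*TxWeightEstimator, error) {
	var estimator TxWeightEstimator
	// A nested P2SH input and a taproot to-local sweep, e.g. when
	// sweeping outputs from both a legacy and a taproot channel.
	inputTypes := []StandardWitnessType{
		NestedWitnessKeyHash,
		TaprootLocalCommitSpend,
	}
	for _, wt := range inputTypes {
		// AddWeightEstimation dispatches to AddNestedP2WSHInput or
		// AddWitnessInput based on the SizeUpperBound result.
		if err := wt.AddWeightEstimation(&estimator); err != nil {
			return nil, err
		}
	}
	// The caller can now add its planned outputs and read the
	// accumulated weight off the estimator.
	return &estimator, nil
}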
// Copyright (c) 2013-2022 The btcsuite developers
package musig2v040
import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
)
var (
// ErrSignersNotSpecified is returned when a caller attempts to create
// a context without specifying either the total number of signers, or
// the complete set of signers.
ErrSignersNotSpecified = fmt.Errorf("total number of signers or all " +
"signers must be known")
// ErrSignerNotInKeySet is returned when the private key for a signer
// isn't included in the set of signing public keys.
ErrSignerNotInKeySet = fmt.Errorf("signing key is not found in key" +
" set")
// ErrAlredyHaveAllNonces is returned when RegisterPubNonce is called too
// many times for a given signing session.
ErrAlredyHaveAllNonces = fmt.Errorf("already have all nonces")
// ErrNotEnoughSigners is returned when a caller attempts to create a
// session from a context, but before all the required signers are
// known.
ErrNotEnoughSigners = fmt.Errorf("not enough signers")
// ErrAlreadyHaveAllSigners is returned when a caller attempts to
// register a signer once we already have the total set of known
// signers.
ErrAlreadyHaveAllSigners = fmt.Errorf("all signers registered")
// ErrAlredyHaveAllSigs is returned when CombineSig is called too many
// times for a given signing session.
ErrAlredyHaveAllSigs = fmt.Errorf("already have all sigs")
// ErrSigningContextReuse is returned if a user attempts to sign using
// the same signing context more than once.
ErrSigningContextReuse = fmt.Errorf("nonce already used")
// ErrFinalSigInvalid is returned when the combined signature turns out
// to be invalid.
ErrFinalSigInvalid = fmt.Errorf("final signature is invalid")
// ErrCombinedNonceUnavailable is returned when a caller attempts to
// sign a partial signature, without first having collected all the
// required combined nonces.
ErrCombinedNonceUnavailable = fmt.Errorf("missing combined nonce")
// ErrTaprootInternalKeyUnavailable is returned when a user attempts to
// obtain the taproot internal key, but no taproot tweak was used.
ErrTaprootInternalKeyUnavailable = fmt.Errorf("taproot tweak not used")
// ErrNoEarlyNonce is returned if a caller attempts to obtain an
// early nonce when it wasn't specified.
ErrNoEarlyNonce = fmt.Errorf("no early nonce available")
)
// Context is a managed signing context for musig2. It takes care of things
// like securely generating secret nonces, aggregating keys and nonces, etc.
type Context struct {
// signingKey is the key we'll use for signing.
signingKey *btcec.PrivateKey
// pubKey is our even-y coordinate public key.
pubKey *btcec.PublicKey
// combinedKey is the aggregated public key.
combinedKey *AggregateKey
// uniqueKeyIndex is the index of the second unique key in the keySet.
// This is used to speed up signing and verification computations.
uniqueKeyIndex int
// keysHash is the hash of all the keys as defined in musig2.
keysHash []byte
// opts is the set of options for the context.
opts *contextOptions
// shouldSort keeps track of if the public keys should be sorted before
// any operations.
shouldSort bool
// sessionNonce will be populated if the earlyNonce option is true.
// After the first session is created, this nonce will be blanked out.
sessionNonce *Nonces
}
// ContextOption is a functional option argument that allows callers to
// modify how musig2 signing is done within a context.
type ContextOption func(*contextOptions)
// contextOptions houses the set of functional options that can be used to
// modify the musig2 signing protocol.
type contextOptions struct {
// tweaks is the set of optional tweaks to apply to the combined public
// key.
tweaks []KeyTweakDesc
// taprootTweak specifies the taproot tweak. If specified, then we'll
// use this as the script root for the BIP 341 taproot (x-only) tweak.
// Normally we'd just apply the raw 32 byte tweak, but for taproot, we
// first need to compute the aggregated key before tweaking, and then
// use it as the internal key. This is required as the taproot tweak
// also commits to the public key, which in this case is the aggregated
// key before the tweak.
taprootTweak []byte
// bip86Tweak if true, then the tweak will just be
// h_tapTweak(internalKey) as there is no true script root.
bip86Tweak bool
// keySet is the complete set of signers for this context.
keySet []*btcec.PublicKey
// numSigners is the total number of signers that will eventually be a
// part of the context.
numSigners int
// earlyNonce determines if a nonce should be generated during context
// creation, to be automatically passed to the created session.
earlyNonce bool
}
// defaultContextOptions returns the default context options.
func defaultContextOptions() *contextOptions {
return &contextOptions{}
}
// WithTweakedContext specifies that within the context, the aggregated public
// key should be tweaked with the specified tweaks.
func WithTweakedContext(tweaks ...KeyTweakDesc) ContextOption {
return func(o *contextOptions) {
o.tweaks = tweaks
}
}
// WithTaprootTweakCtx specifies that within this context, the final key should
// use the taproot tweak as defined in BIP 341: outputKey = internalKey +
// h_tapTweak(internalKey || scriptRoot). In this case, the aggregated key
// before the tweak will be used as the internal key.
func WithTaprootTweakCtx(scriptRoot []byte) ContextOption {
return func(o *contextOptions) {
o.taprootTweak = scriptRoot
}
}
// WithBip86TweakCtx specifies that within this context, the final key should
// use the taproot tweak as defined in BIP 341, with the BIP 86 modification:
// outputKey = internalKey + h_tapTweak(internalKey)*G. In this case, the
// aggregated key before the tweak will be used as the internal key.
func WithBip86TweakCtx() ContextOption {
return func(o *contextOptions) {
o.bip86Tweak = true
}
}
// WithKnownSigners is an optional parameter that should be used if a session
// can be created as soon as all the signers are known.
func WithKnownSigners(signers []*btcec.PublicKey) ContextOption {
return func(o *contextOptions) {
o.keySet = signers
o.numSigners = len(signers)
}
}
// WithNumSigners is a functional option used to specify that a context should
// be created without knowing all the signers. Instead the total number of
// signers is specified to ensure that a session can only be created once all
// the signers are known.
//
// NOTE: Either WithKnownSigners or WithNumSigners MUST be specified.
func WithNumSigners(n int) ContextOption {
return func(o *contextOptions) {
o.numSigners = n
}
}
// WithEarlyNonceGen allows a caller to specify that a nonce should be generated
// early, before the session is created. This should be used in protocols that
// require some partial nonce exchange before all the signers are known.
//
// NOTE: This option must only be specified with the WithNumSigners option.
func WithEarlyNonceGen() ContextOption {
return func(o *contextOptions) {
o.earlyNonce = true
}
}
// NewContext creates a new signing context with the passed signing key and set
// of public keys for each of the other signers.
//
// NOTE: This struct should be used over the raw Sign API whenever possible.
func NewContext(signingKey *btcec.PrivateKey, shouldSort bool,
ctxOpts ...ContextOption) (*Context, error) {
// First, parse the set of optional context options.
opts := defaultContextOptions()
for _, option := range ctxOpts {
option(opts)
}
pubKey, err := schnorr.ParsePubKey(
schnorr.SerializePubKey(signingKey.PubKey()),
)
if err != nil {
return nil, err
}
ctx := &Context{
signingKey: signingKey,
pubKey: pubKey,
opts: opts,
shouldSort: shouldSort,
}
switch {
// We know all the signers, so we can compute the aggregated key, along
// with all the other intermediate state we need to do signing and
// verification.
case opts.keySet != nil:
if err := ctx.combineSignerKeys(); err != nil {
return nil, err
}
// The total signers are known, so we add ourselves, and skip key
// aggregation.
case opts.numSigners != 0:
// Otherwise, we'll add ourselves as the only known signer, and
// await further calls to RegisterSigner before a session can
// be created.
opts.keySet = make([]*btcec.PublicKey, 0, opts.numSigners)
opts.keySet = append(opts.keySet, pubKey)
// If early nonce generation is specified, then we'll generate
// the nonce now to pass in to the session once all the callers
// are known.
if opts.earlyNonce {
ctx.sessionNonce, err = GenNonces()
if err != nil {
return nil, err
}
}
default:
return nil, ErrSignersNotSpecified
}
return ctx, nil
}
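// The following is an illustrative sketch, not part of the original file:
// it shows the simplest way to construct a Context when both signer keys
// are known up front. The function name is hypothetical.
func exampleKnownSignersContext() (*Context, error) {
	ourKey, err := btcec.NewPrivateKey()
	if err != nil {
		return nil, err
	}
	theirKey, err := btcec.NewPrivateKey()
	if err != nil {
		return nil, err
	}
	// Sorting the keys makes the aggregated key independent of the
	// order the signers are supplied in.
	return NewContext(ourKey, true, WithKnownSigners(
		[]*btcec.PublicKey{ourKey.PubKey(), theirKey.PubKey()},
	))
}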
// combineSignerKeys is used to compute the aggregated signer key once all the
// signers are known.
func (c *Context) combineSignerKeys() error {
// As a sanity check, make sure the signing key is actually
// amongst the set of signers.
var keyFound bool
for _, key := range c.opts.keySet {
if key.IsEqual(c.pubKey) {
keyFound = true
break
}
}
if !keyFound {
return ErrSignerNotInKeySet
}
// Now that we know that we're actually a signer, we'll
// generate the key hash finger print and second unique key
// index so we can speed up signing later.
c.keysHash = keyHashFingerprint(c.opts.keySet, c.shouldSort)
c.uniqueKeyIndex = secondUniqueKeyIndex(
c.opts.keySet, c.shouldSort,
)
keyAggOpts := []KeyAggOption{
WithKeysHash(c.keysHash),
WithUniqueKeyIndex(c.uniqueKeyIndex),
}
switch {
case c.opts.bip86Tweak:
keyAggOpts = append(
keyAggOpts, WithBIP86KeyTweak(),
)
case c.opts.taprootTweak != nil:
keyAggOpts = append(
keyAggOpts, WithTaprootKeyTweak(c.opts.taprootTweak),
)
case len(c.opts.tweaks) != 0:
keyAggOpts = append(keyAggOpts, WithKeyTweaks(c.opts.tweaks...))
}
// Next, we'll use this information to compute the aggregated
// public key that'll be used for signing in practice.
var err error
c.combinedKey, _, _, err = AggregateKeys(
c.opts.keySet, c.shouldSort, keyAggOpts...,
)
if err != nil {
return err
}
return nil
}
// EarlySessionNonce returns the early session nonce, if available.
func (c *Context) EarlySessionNonce() (*Nonces, error) {
if c.sessionNonce == nil {
return nil, ErrNoEarlyNonce
}
return c.sessionNonce, nil
}
// RegisterSigner allows a caller to register a signer after the context has
// been created. This will be used in scenarios where the total number of
// signers is known, but nonce exchange needs to happen before all the signers
// are known.
//
// A bool is returned which indicates if all the signers have been registered.
//
// NOTE: If the set of keys are not to be sorted during signing, then the
// ordering each key is registered with MUST match the desired ordering.
func (c *Context) RegisterSigner(pub *btcec.PublicKey) (bool, error) {
haveAllSigners := len(c.opts.keySet) == c.opts.numSigners
if haveAllSigners {
return false, ErrAlreadyHaveAllSigners
}
c.opts.keySet = append(c.opts.keySet, pub)
// If we have the expected number of signers at this point, then we can
// generate the aggregated key and other necessary information.
haveAllSigners = len(c.opts.keySet) == c.opts.numSigners
if haveAllSigners {
if err := c.combineSignerKeys(); err != nil {
return false, err
}
}
return haveAllSigners, nil
}
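// Illustrative sketch, not part of the original file: a context created
// knowing only the signer count, with an early nonce so nonce exchange
// can begin before the remote key is registered. The function name is
// hypothetical.
func exampleLateRegistration(ourKey *btcec.PrivateKey,
	remoteKey *btcec.PublicKey) (*Context, error) {
	ctx, err := NewContext(
		ourKey, true, WithNumSigners(2), WithEarlyNonceGen(),
	)
	if err != nil {
		return nil, err
	}
	// The early nonce can be sent to the remote party right away.
	if _, err := ctx.EarlySessionNonce(); err != nil {
		return nil, err
	}
	// Once the remote key arrives, registering it completes the set of
	// signers and triggers key aggregation.
	done, err := ctx.RegisterSigner(remoteKey)
	if err != nil {
		return nil, err
	}
	if !done {
		return nil, ErrNotEnoughSigners
	}
	return ctx, nil
}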
// NumRegisteredSigners returns the total number of registered signers.
func (c *Context) NumRegisteredSigners() int {
return len(c.opts.keySet)
}
// CombinedKey returns the combined public key that will be used to generate
// multi-signatures against.
func (c *Context) CombinedKey() (*btcec.PublicKey, error) {
// If the caller hasn't registered all the signers at this point, then
// the combined key won't be available.
if c.combinedKey == nil {
return nil, ErrNotEnoughSigners
}
return c.combinedKey.FinalKey, nil
}
// PubKey returns the public key of the signer of this session.
func (c *Context) PubKey() btcec.PublicKey {
return *c.pubKey
}
// SigningKeys returns the set of keys used for signing.
func (c *Context) SigningKeys() []*btcec.PublicKey {
keys := make([]*btcec.PublicKey, len(c.opts.keySet))
copy(keys, c.opts.keySet)
return keys
}
// TaprootInternalKey returns the internal taproot key, which is the aggregated
// key _before_ the tweak is applied. If a taproot tweak was specified, then
// CombinedKey() will return the fully tweaked output key, with this method
// returning the internal key. If a taproot tweak wasn't specified, then this
// method will return an error.
func (c *Context) TaprootInternalKey() (*btcec.PublicKey, error) {
// If the caller hasn't registered all the signers at this point, then
// the combined key won't be available.
if c.combinedKey == nil {
return nil, ErrNotEnoughSigners
}
if c.opts.taprootTweak == nil && !c.opts.bip86Tweak {
return nil, ErrTaprootInternalKeyUnavailable
}
return c.combinedKey.PreTweakedKey, nil
}
// SessionOption is a functional option argument that allows callers to
// modify how musig2 signing is done within a session.
type SessionOption func(*sessionOptions)
// sessionOptions houses the set of functional options that can be used to
// modify the musig2 signing protocol.
type sessionOptions struct {
externalNonce *Nonces
}
// defaultSessionOptions returns the default session options.
func defaultSessionOptions() *sessionOptions {
return &sessionOptions{}
}
// WithPreGeneratedNonce allows a caller to start a session using a nonce
// they've generated themselves. This may be useful in protocols where all the
// signer keys may not be known before nonce exchange needs to occur.
func WithPreGeneratedNonce(nonce *Nonces) SessionOption {
return func(o *sessionOptions) {
o.externalNonce = nonce
}
}
// Session represents a musig2 signing session. A new instance should be
// created each time a multi-signature is needed. The session struct handles
// nonce management, incremental partial sig verification, as well as final
// signature combination. Errors are returned when unsafe behavior such as
// nonce re-use is attempted.
//
// NOTE: This struct should be used over the raw Sign API whenever possible.
type Session struct {
opts *sessionOptions
ctx *Context
localNonces *Nonces
pubNonces [][PubNonceSize]byte
combinedNonce *[PubNonceSize]byte
msg [32]byte
ourSig *PartialSignature
sigs []*PartialSignature
finalSig *schnorr.Signature
}
// NewSession creates a new musig2 signing session.
func (c *Context) NewSession(options ...SessionOption) (*Session, error) {
opts := defaultSessionOptions()
for _, opt := range options {
opt(opts)
}
// At this point we verify that we know of all the signers, as
// otherwise we can't proceed with the session. This check is intended
// to catch misuse of the API wherein a caller forgets to register the
// remaining signers if they're doing nonce generation ahead of time.
if len(c.opts.keySet) != c.opts.numSigners {
return nil, ErrNotEnoughSigners
}
// If an early nonce was specified, then we'll automatically add the
// corresponding session option for the caller.
var localNonces *Nonces
if c.sessionNonce != nil {
// Apply the early nonce to the session, and also blank out the
// session nonce on the context to ensure it isn't ever re-used
// for another session.
localNonces = c.sessionNonce
c.sessionNonce = nil
} else if opts.externalNonce != nil {
// Otherwise if there's a custom nonce passed in via the
// session options, then use that instead.
localNonces = opts.externalNonce
}
// Now that we know we have enough signers, we'll either use the caller
// specified nonce, or generate a fresh set.
var err error
if localNonces == nil {
// At this point we need to generate a fresh nonce. We'll pass
// in some auxiliary information to strengthen the nonce
// generated.
localNonces, err = GenNonces(
WithNonceSecretKeyAux(c.signingKey),
WithNonceCombinedKeyAux(c.combinedKey.FinalKey),
)
if err != nil {
return nil, err
}
}
s := &Session{
opts: opts,
ctx: c,
localNonces: localNonces,
pubNonces: make([][PubNonceSize]byte, 0, c.opts.numSigners),
sigs: make([]*PartialSignature, 0, c.opts.numSigners),
}
s.pubNonces = append(s.pubNonces, localNonces.PubNonce)
return s, nil
}
// PublicNonce returns the public nonce for a signer. This should be sent to
// other parties before signing begins, so they can compute the aggregated
// public nonce.
func (s *Session) PublicNonce() [PubNonceSize]byte {
return s.localNonces.PubNonce
}
// NumRegisteredNonces returns the total number of nonces that have been
// registered so far.
func (s *Session) NumRegisteredNonces() int {
return len(s.pubNonces)
}
// RegisterPubNonce should be called for each public nonce from the set of
// signers. This method returns true once all the public nonces have been
// accounted for.
func (s *Session) RegisterPubNonce(nonce [PubNonceSize]byte) (bool, error) {
// If we already have all the nonces, then this method was called too
// many times.
haveAllNonces := len(s.pubNonces) == s.ctx.opts.numSigners
if haveAllNonces {
return false, ErrAlredyHaveAllNonces
}
// Add this nonce and check again if we already have all the nonces we
// need.
s.pubNonces = append(s.pubNonces, nonce)
haveAllNonces = len(s.pubNonces) == s.ctx.opts.numSigners
// If we have all the nonces, then we can go ahead and combine them
// now.
if haveAllNonces {
combinedNonce, err := AggregateNonces(s.pubNonces)
if err != nil {
return false, err
}
s.combinedNonce = &combinedNonce
}
return haveAllNonces, nil
}
// Sign generates a partial signature for the target message, using the target
// context. If this method is called more than once per context, then an error
// is returned, as that means a nonce was re-used.
func (s *Session) Sign(msg [32]byte,
signOpts ...SignOption) (*PartialSignature, error) {
switch {
// If no local nonce is present, then this means we already signed, so
// we'll return an error to prevent nonce re-use.
case s.localNonces == nil:
return nil, ErrSigningContextReuse
// We also need to make sure we have the combined nonce, otherwise this
// function was called too early.
case s.combinedNonce == nil:
return nil, ErrCombinedNonceUnavailable
}
switch {
case s.ctx.opts.bip86Tweak:
signOpts = append(
signOpts, WithBip86SignTweak(),
)
case s.ctx.opts.taprootTweak != nil:
signOpts = append(
signOpts, WithTaprootSignTweak(s.ctx.opts.taprootTweak),
)
case len(s.ctx.opts.tweaks) != 0:
signOpts = append(signOpts, WithTweaks(s.ctx.opts.tweaks...))
}
partialSig, err := Sign(
s.localNonces.SecNonce, s.ctx.signingKey, *s.combinedNonce,
s.ctx.opts.keySet, msg, signOpts...,
)
// Now that we've generated our signature, we'll make sure to blank out
// our signing nonce.
s.localNonces = nil
if err != nil {
return nil, err
}
s.msg = msg
s.ourSig = partialSig
s.sigs = append(s.sigs, partialSig)
return partialSig, nil
}
// CombineSig buffers a partial signature received from a signing party. The
// method returns true once all the signatures are available, and can be
// combined into the final signature.
func (s *Session) CombineSig(sig *PartialSignature) (bool, error) {
// First check if we already have all the signatures we need. We
// already accumulated our own signature when we generated the sig.
haveAllSigs := len(s.sigs) == len(s.ctx.opts.keySet)
if haveAllSigs {
return false, ErrAlredyHaveAllSigs
}
// TODO(roasbeef): incremental check for invalid sig, or just detect at
// the very end?
// Accumulate this sig, and check again if we have all the sigs we
// need.
s.sigs = append(s.sigs, sig)
haveAllSigs = len(s.sigs) == len(s.ctx.opts.keySet)
// If we have all the signatures, then we can combine them all into the
// final signature.
if haveAllSigs {
var combineOpts []CombineOption
switch {
case s.ctx.opts.bip86Tweak:
combineOpts = append(
combineOpts, WithBip86TweakedCombine(
s.msg, s.ctx.opts.keySet,
s.ctx.shouldSort,
),
)
case s.ctx.opts.taprootTweak != nil:
combineOpts = append(
combineOpts, WithTaprootTweakedCombine(
s.msg, s.ctx.opts.keySet,
s.ctx.opts.taprootTweak, s.ctx.shouldSort,
),
)
case len(s.ctx.opts.tweaks) != 0:
combineOpts = append(
combineOpts, WithTweakedCombine(
s.msg, s.ctx.opts.keySet,
s.ctx.opts.tweaks, s.ctx.shouldSort,
),
)
}
finalSig := CombineSigs(s.ourSig.R, s.sigs, combineOpts...)
// We'll also verify the signature at this point to ensure it's
// valid.
//
// TODO(roasbeef): allow skipping?
if !finalSig.Verify(s.msg[:], s.ctx.combinedKey.FinalKey) {
return false, ErrFinalSigInvalid
}
s.finalSig = finalSig
}
return haveAllSigs, nil
}
// FinalSig returns the final combined multi-signature, if present.
func (s *Session) FinalSig() *schnorr.Signature {
return s.finalSig
}
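// Illustrative sketch, not part of the original file: the local half of a
// two-party signing flow, given the remote party's public nonce and
// partial signature obtained over some transport. The function name is
// hypothetical.
func exampleSignFlow(ctx *Context, remoteNonce [PubNonceSize]byte,
	remoteSig *PartialSignature, msg [32]byte) (*schnorr.Signature, error) {
	session, err := ctx.NewSession()
	if err != nil {
		return nil, err
	}
	// In a real protocol, session.PublicNonce() would be sent to the
	// remote party at this point.
	if _, err := session.RegisterPubNonce(remoteNonce); err != nil {
		return nil, err
	}
	// Signing also counts our own partial signature towards the total.
	if _, err := session.Sign(msg); err != nil {
		return nil, err
	}
	haveAllSigs, err := session.CombineSig(remoteSig)
	if err != nil {
		return nil, err
	}
	if !haveAllSigs {
		return nil, fmt.Errorf("missing partial signatures")
	}
	return session.FinalSig(), nil
}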
// Copyright 2013-2022 The btcsuite developers
package musig2v040
import (
"bytes"
"fmt"
"sort"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/chaincfg/chainhash"
secp "github.com/decred/dcrd/dcrec/secp256k1/v4"
)
var (
// KeyAggTagList is the tagged hash tag used to compute the hash of the
// list of sorted public keys.
KeyAggTagList = []byte("KeyAgg list")
// KeyAggTagCoeff is the tagged hash tag used to compute the key
// aggregation coefficient for each key.
KeyAggTagCoeff = []byte("KeyAgg coefficient")
// ErrTweakedKeyIsInfinity is returned if while tweaking a key, we end
// up with the point at infinity.
ErrTweakedKeyIsInfinity = fmt.Errorf("tweaked key is infinity point")
// ErrTweakedKeyOverflows is returned if a tweaking key is larger than
// 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141.
ErrTweakedKeyOverflows = fmt.Errorf("tweaked key is too large")
)
// sortableKeys defines a type of slice of public keys that implements the sort
// interface for BIP 340 keys.
type sortableKeys []*btcec.PublicKey
// Less reports whether the element with index i must sort before the element
// with index j.
func (s sortableKeys) Less(i, j int) bool {
// TODO(roasbeef): more efficient way to compare...
keyIBytes := schnorr.SerializePubKey(s[i])
keyJBytes := schnorr.SerializePubKey(s[j])
return bytes.Compare(keyIBytes, keyJBytes) == -1
}
// Swap swaps the elements with indexes i and j.
func (s sortableKeys) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
// Len is the number of elements in the collection.
func (s sortableKeys) Len() int {
return len(s)
}
// sortKeys takes a set of schnorr public keys and returns a new slice that is
// a copy of the keys sorted in lexicographical order based on the x-only
// pubkey serialization.
func sortKeys(keys []*btcec.PublicKey) []*btcec.PublicKey {
keySet := sortableKeys(keys)
if sort.IsSorted(keySet) {
return keys
}
sort.Sort(keySet)
return keySet
}
// keyHashFingerprint computes the tagged hash of the series of (sorted) public
// keys passed as input. This is used to compute the aggregation coefficient
// for each key. The final computation is:
// - H(tag=KeyAgg list, pk1 || pk2..)
func keyHashFingerprint(keys []*btcec.PublicKey, sort bool) []byte {
if sort {
keys = sortKeys(keys)
}
// We'll create a single buffer and slice into that so the bytes buffer
// doesn't continually need to grow the underlying buffer.
keyAggBuf := make([]byte, 32*len(keys))
keyBytes := bytes.NewBuffer(keyAggBuf[0:0])
for _, key := range keys {
keyBytes.Write(schnorr.SerializePubKey(key))
}
h := chainhash.TaggedHash(KeyAggTagList, keyBytes.Bytes())
return h[:]
}
// keyBytesEqual returns true if two keys are the same from the PoV of BIP
// 340's 32-byte x-only public keys.
func keyBytesEqual(a, b *btcec.PublicKey) bool {
return bytes.Equal(
schnorr.SerializePubKey(a),
schnorr.SerializePubKey(b),
)
}
// aggregationCoefficient computes the key aggregation coefficient for the
// specified target key. The coefficient is computed as:
// - H(tag=KeyAgg coefficient, keyHashFingerprint(pks) || pk)
func aggregationCoefficient(keySet []*btcec.PublicKey,
targetKey *btcec.PublicKey, keysHash []byte,
secondKeyIdx int) *btcec.ModNScalar {
var mu btcec.ModNScalar
// If this is the second key, then this coefficient is just one.
if secondKeyIdx != -1 && keyBytesEqual(keySet[secondKeyIdx], targetKey) {
return mu.SetInt(1)
}
// Otherwise, we'll compute the full finger print hash for this given
// key and then use that to compute the coefficient tagged hash:
// * H(tag=KeyAgg coefficient, keyHashFingerprint(pks, pk) || pk)
var coefficientBytes [64]byte
copy(coefficientBytes[:], keysHash[:])
copy(coefficientBytes[32:], schnorr.SerializePubKey(targetKey))
muHash := chainhash.TaggedHash(KeyAggTagCoeff, coefficientBytes[:])
mu.SetByteSlice(muHash[:])
return &mu
}
// secondUniqueKeyIndex returns the index of the second unique key. If all keys
// are the same, then a value of -1 is returned.
func secondUniqueKeyIndex(keySet []*btcec.PublicKey, sort bool) int {
if sort {
keySet = sortKeys(keySet)
}
// Find the first key that isn't the same as the very first key (second
// unique key).
for i := range keySet {
if !keyBytesEqual(keySet[i], keySet[0]) {
return i
}
}
// A value of negative one is used to indicate that all the keys in the
// signing set are equal, which in practice makes musig2
// useless, but we need a value to distinguish this case.
return -1
}
// KeyTweakDesc describes a tweak to be applied to the aggregated public key
// generation and signing process. The IsXOnly specifies if the target key
// should be converted to an x-only public key before tweaking.
type KeyTweakDesc struct {
// Tweak is the 32-byte value that will modify the public key.
Tweak [32]byte
// IsXOnly if true, then the public key will be mapped to an x-only key
// before the tweaking operation is applied.
IsXOnly bool
}
// KeyAggOption is a functional option argument that allows callers to supply
// pre-computed information to the main key aggregation routine.
type KeyAggOption func(*keyAggOption)
// keyAggOption houses the set of functional options that modify key
// aggregation.
type keyAggOption struct {
// keyHash is the output of keyHashFingerprint for a given set of keys.
keyHash []byte
// uniqueKeyIndex is the pre-computed index of the second unique key.
uniqueKeyIndex *int
// tweaks specifies a series of tweaks to be applied to the aggregated
// public key.
tweaks []KeyTweakDesc
// taprootTweak controls if the tweaks above should be applied in a BIP
// 340 style.
taprootTweak bool
// bip86Tweak specifies that the taproot tweak should be done in a BIP
// 86 style, where we don't expect an actual tweak and instead just
// commit to the public key itself.
bip86Tweak bool
}
// WithKeysHash allows key aggregation to be optimized by allowing the caller
// to specify the hash of all the keys.
func WithKeysHash(keyHash []byte) KeyAggOption {
return func(o *keyAggOption) {
o.keyHash = keyHash
}
}
// WithUniqueKeyIndex allows the caller to specify the index of the second
// unique key.
func WithUniqueKeyIndex(idx int) KeyAggOption {
return func(o *keyAggOption) {
i := idx
o.uniqueKeyIndex = &i
}
}
// WithKeyTweaks allows a caller to specify a series of 32-byte tweaks that
// should be applied to the final aggregated public key.
func WithKeyTweaks(tweaks ...KeyTweakDesc) KeyAggOption {
return func(o *keyAggOption) {
o.tweaks = tweaks
}
}
// WithTaprootKeyTweak specifies that within this context, the final key should
// use the taproot tweak as defined in BIP 341: outputKey = internalKey +
// h_tapTweak(internalKey || scriptRoot). In this case, the aggregated key
// before the tweak will be used as the internal key.
//
// This option should be used instead of WithKeyTweaks when the aggregated key
// is intended to be used as a taproot output key that commits to a script
// root.
func WithTaprootKeyTweak(scriptRoot []byte) KeyAggOption {
return func(o *keyAggOption) {
var tweak [32]byte
copy(tweak[:], scriptRoot[:])
o.tweaks = []KeyTweakDesc{
{
Tweak: tweak,
IsXOnly: true,
},
}
o.taprootTweak = true
}
}
// WithBIP86KeyTweak specifies that during key aggregation, the BIP 86
// tweak, which just commits to the hash of the serialized public key, should be
// used. This option should be used when signing with a key that was derived
// using BIP 86.
func WithBIP86KeyTweak() KeyAggOption {
return func(o *keyAggOption) {
o.tweaks = []KeyTweakDesc{
{
IsXOnly: true,
},
}
o.taprootTweak = true
o.bip86Tweak = true
}
}
// defaultKeyAggOptions returns the set of default arguments for key
// aggregation.
func defaultKeyAggOptions() *keyAggOption {
return &keyAggOption{}
}
// hasEvenY returns true if the affine representation of the passed jacobian
// point has an even y coordinate.
//
// TODO(roasbeef): double check, can just check the y coord even not jacobian?
func hasEvenY(pJ btcec.JacobianPoint) bool {
pJ.ToAffine()
p := btcec.NewPublicKey(&pJ.X, &pJ.Y)
keyBytes := p.SerializeCompressed()
return keyBytes[0] == secp.PubKeyFormatCompressedEven
}
// tweakKey applies a tweak to the passed public key using the specified
// tweak. The parityAcc and tweakAcc are returned (in that order), which
// include the accumulated parity factor and the tweak multiplied by the
// parity factor. The xOnly bool specifies if this is to be an x-only
// tweak or not.
func tweakKey(keyJ btcec.JacobianPoint, parityAcc btcec.ModNScalar, tweak [32]byte,
tweakAcc btcec.ModNScalar,
xOnly bool) (btcec.JacobianPoint, btcec.ModNScalar, btcec.ModNScalar, error) {
// First we'll compute the new parity factor for this key. If the key has
// an odd y coordinate (not even), then we'll need to negate it (multiply
// by -1 mod n, in this case).
var parityFactor btcec.ModNScalar
if xOnly && !hasEvenY(keyJ) {
parityFactor.SetInt(1).Negate()
} else {
parityFactor.SetInt(1)
}
// Next, map the tweak into a mod n integer so we can use it for
// manipulations below.
tweakInt := new(btcec.ModNScalar)
overflows := tweakInt.SetBytes(&tweak)
if overflows == 1 {
return keyJ, parityAcc, tweakAcc, ErrTweakedKeyOverflows
}
// Next, we'll compute: Q_i = g*Q + t*G, where g is our parityFactor and t
// is the tweakInt above. We'll space things out a bit to make it easier to
// follow.
//
// First compute t*G:
var tweakedGenerator btcec.JacobianPoint
btcec.ScalarBaseMultNonConst(tweakInt, &tweakedGenerator)
// Next compute g*Q:
btcec.ScalarMultNonConst(&parityFactor, &keyJ, &keyJ)
// Finally add both of them together to get our final
// tweaked point.
btcec.AddNonConst(&tweakedGenerator, &keyJ, &keyJ)
// As a sanity check, make sure that we didn't just end up with the
// point at infinity.
if keyJ == infinityPoint {
return keyJ, parityAcc, tweakAcc, ErrTweakedKeyIsInfinity
}
// As a final wrap up step, we'll accumulate the parity
// factor and also this tweak into the final set of accumulators.
parityAcc.Mul(&parityFactor)
tweakAcc.Mul(&parityFactor).Add(tweakInt)
return keyJ, parityAcc, tweakAcc, nil
}
// AggregateKey is a final aggregated key along with a possible version of the
// key without any tweaks applied.
type AggregateKey struct {
// FinalKey is the final aggregated key which may include one or more
// tweaks applied to it.
FinalKey *btcec.PublicKey
// PreTweakedKey is the aggregated key *before* any tweaks have been
// applied. This should be used as the internal key in taproot
// contexts.
PreTweakedKey *btcec.PublicKey
}
// AggregateKeys takes a list of possibly unsorted keys and returns a single
// aggregated key as specified by the musig2 key aggregation algorithm. A nil
// value can be passed for keyHash, which causes this function to re-derive it.
// In addition to the combined public key, the parity accumulator and the tweak
// accumulator are returned as well.
func AggregateKeys(keys []*btcec.PublicKey, sort bool,
keyOpts ...KeyAggOption) (
*AggregateKey, *btcec.ModNScalar, *btcec.ModNScalar, error) {
// First, parse the set of optional signing options.
opts := defaultKeyAggOptions()
for _, option := range keyOpts {
option(opts)
}
// Sort the set of public keys so we know we're working with them in
// sorted order for all the routines below.
if sort {
keys = sortKeys(keys)
}
// The caller may provide the hash of all the keys as an optimization
// during signing, as it already needs to be computed.
if opts.keyHash == nil {
opts.keyHash = keyHashFingerprint(keys, sort)
}
// A caller may also specify the unique key index themselves so we
// don't need to re-compute it.
if opts.uniqueKeyIndex == nil {
idx := secondUniqueKeyIndex(keys, sort)
opts.uniqueKeyIndex = &idx
}
// For each key, we'll compute the intermediate blinded key: a_i*P_i,
// where a_i is the aggregation coefficient for that key, and P_i is
// the key itself, then accumulate that (addition) into the main final
// key: P = P_1 + P_2 ... P_N.
var finalKeyJ btcec.JacobianPoint
for _, key := range keys {
// Port the key over to Jacobian coordinates as we need it in
// this format for the routines below.
var keyJ btcec.JacobianPoint
key.AsJacobian(&keyJ)
// Compute the aggregation coefficient for the key, then
// multiply it by the key itself: P_i' = a_i*P_i.
var tweakedKeyJ btcec.JacobianPoint
a := aggregationCoefficient(
keys, key, opts.keyHash, *opts.uniqueKeyIndex,
)
btcec.ScalarMultNonConst(a, &keyJ, &tweakedKeyJ)
// Finally accumulate this into the final key in an incremental
// fashion.
btcec.AddNonConst(&finalKeyJ, &tweakedKeyJ, &finalKeyJ)
}
// We'll copy over the key at this point, since this represents the
// aggregated key before any tweaks have been applied. This'll be used
// as the internal key for script path proofs.
finalKeyJ.ToAffine()
combinedKey := btcec.NewPublicKey(&finalKeyJ.X, &finalKeyJ.Y)
// At this point, if this is a taproot tweak, then we'll modify the
// base tweak value to use the BIP 341 tweak value.
if opts.taprootTweak {
// Emulate the same behavior as txscript.ComputeTaprootOutputKey
// which only operates on the x-only public key.
key, _ := schnorr.ParsePubKey(schnorr.SerializePubKey(
combinedKey,
))
// We only use the actual tweak bytes if we're not committing
// to a BIP-0086 key only spend output. Otherwise, we just
// commit to the internal key and an empty byte slice as the
// root hash.
tweakBytes := []byte{}
if !opts.bip86Tweak {
tweakBytes = opts.tweaks[0].Tweak[:]
}
// Compute the taproot key tagged hash of:
// h_tapTweak(internalKey || scriptRoot). We only do this for
// the first one, as you can only specify a single tweak when
// using the taproot mode with this API.
tapTweakHash := chainhash.TaggedHash(
chainhash.TagTapTweak, schnorr.SerializePubKey(key),
tweakBytes,
)
opts.tweaks[0].Tweak = *tapTweakHash
}
var (
err error
tweakAcc btcec.ModNScalar
parityAcc btcec.ModNScalar
)
parityAcc.SetInt(1)
// In this case we have a set of tweaks, so we'll incrementally apply
// each one, until we have our final tweaked key, and the related
// accumulators.
for i := 1; i <= len(opts.tweaks); i++ {
finalKeyJ, parityAcc, tweakAcc, err = tweakKey(
finalKeyJ, parityAcc, opts.tweaks[i-1].Tweak, tweakAcc,
opts.tweaks[i-1].IsXOnly,
)
if err != nil {
return nil, nil, nil, err
}
}
finalKeyJ.ToAffine()
finalKey := btcec.NewPublicKey(&finalKeyJ.X, &finalKeyJ.Y)
return &AggregateKey{
PreTweakedKey: combinedKey,
FinalKey: finalKey,
}, &parityAcc, &tweakAcc, nil
}
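// Illustrative sketch, not part of the original file: derive a BIP 86
// style taproot output key from a set of signer keys, returning both the
// tweaked output key and the untweaked internal key. The function name is
// hypothetical.
func exampleBip86OutputKey(keys []*btcec.PublicKey) (
	outputKey, internalKey *btcec.PublicKey, err error) {
	aggKey, _, _, err := AggregateKeys(keys, true, WithBIP86KeyTweak())
	if err != nil {
		return nil, nil, err
	}
	// FinalKey commits to the pre-tweaked key via h_tapTweak, while
	// PreTweakedKey is what would serve as the taproot internal key.
	return aggKey.FinalKey, aggKey.PreTweakedKey, nil
}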
// Copyright 2013-2022 The btcsuite developers
package musig2v040
import (
"bytes"
"crypto/rand"
"encoding/binary"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
const (
// PubNonceSize is the size of the public nonces. Each public nonce is
// serialized in the full compressed encoding, which uses 33 bytes for
// each nonce.
PubNonceSize = 66
// SecNonceSize is the size of the secret nonces for musig2. The secret
// nonces are the corresponding private keys to the public nonce points.
SecNonceSize = 64
)
var (
// NonceAuxTag is the tag used to optionally mix in the secret key with
// the set of aux randomness.
NonceAuxTag = []byte("MuSig/aux")
// NonceGenTag is used to generate the value (from a set of required and
// optional fields) that will be used as part of the secret nonce.
NonceGenTag = []byte("MuSig/nonce")
byteOrder = binary.BigEndian
)
// zeroSecNonce is a secret nonce that's all zeroes. This is used to check that
// we're not attempting to re-use a nonce, and also protect callers from it.
var zeroSecNonce [SecNonceSize]byte
// Nonces holds the public and secret nonces required for musig2.
//
// TODO(roasbeef): methods on this to help w/ parsing, etc?
type Nonces struct {
// PubNonce holds the two 33-byte compressed encoded points that serve
// as the public set of nonces.
PubNonce [PubNonceSize]byte
// SecNonce holds the two 32-byte scalar values that are the private
// keys to the two public nonces.
SecNonce [SecNonceSize]byte
}
// secNonceToPubNonce takes our two secret nonces, and produces their two
// corresponding EC points, serialized in compressed format.
func secNonceToPubNonce(secNonce [SecNonceSize]byte) [PubNonceSize]byte {
var k1Mod, k2Mod btcec.ModNScalar
k1Mod.SetByteSlice(secNonce[:btcec.PrivKeyBytesLen])
k2Mod.SetByteSlice(secNonce[btcec.PrivKeyBytesLen:])
var r1, r2 btcec.JacobianPoint
btcec.ScalarBaseMultNonConst(&k1Mod, &r1)
btcec.ScalarBaseMultNonConst(&k2Mod, &r2)
// Next, we'll convert the key in jacobian format to a normal public
// key expressed in affine coordinates.
r1.ToAffine()
r2.ToAffine()
r1Pub := btcec.NewPublicKey(&r1.X, &r1.Y)
r2Pub := btcec.NewPublicKey(&r2.X, &r2.Y)
var pubNonce [PubNonceSize]byte
// The public nonces are serialized as: R1 || R2, where both keys are
// serialized in compressed format.
copy(pubNonce[:], r1Pub.SerializeCompressed())
copy(
pubNonce[btcec.PubKeyBytesLenCompressed:],
r2Pub.SerializeCompressed(),
)
return pubNonce
}
// NonceGenOption is a function option that allows callers to modify how nonce
// generation happens.
type NonceGenOption func(*nonceGenOpts)
// nonceGenOpts is the set of options that control how nonce generation
// happens.
type nonceGenOpts struct {
// randReader is what we'll use to generate a set of random bytes. If
// unspecified, then the normal crypto/rand rand.Read method will be
// used in its place.
randReader io.Reader
// secretKey is an optional argument that's used to further augment the
// generated nonce by xor'ing it with this secret key.
secretKey []byte
// combinedKey is an optional argument that if specified, will be
// combined along with the nonce generation.
combinedKey []byte
// msg is an optional argument that will be mixed into the nonce
// derivation algorithm.
msg []byte
// auxInput is an optional argument that will be mixed into the nonce
// derivation algorithm.
auxInput []byte
}
// cryptoRandAdapter is an adapter struct that allows us to pass in the package
// level Read function from crypto/rand into a context that accepts an
// io.Reader.
type cryptoRandAdapter struct {
}
// Read implements the io.Reader interface for the crypto/rand package. By
// default, we always use the crypto/rand reader, but the caller is able to
// specify their own generation, which can be useful for deterministic tests.
func (c *cryptoRandAdapter) Read(p []byte) (n int, err error) {
return rand.Read(p)
}
// defaultNonceGenOpts returns the default set of nonce generation options.
func defaultNonceGenOpts() *nonceGenOpts {
return &nonceGenOpts{
randReader: &cryptoRandAdapter{},
}
}
// WithCustomRand allows a caller to use a custom random number generator in
// place of crypto/rand. This should only really be used to generate
// deterministic tests.
func WithCustomRand(r io.Reader) NonceGenOption {
return func(o *nonceGenOpts) {
o.randReader = r
}
}
// WithNonceSecretKeyAux allows a caller to optionally specify a secret key
// that should be used to augment the randomness used to generate the nonces.
func WithNonceSecretKeyAux(secKey *btcec.PrivateKey) NonceGenOption {
return func(o *nonceGenOpts) {
o.secretKey = secKey.Serialize()
}
}
// WithNonceCombinedKeyAux allows a caller to optionally specify the combined
// key used in this signing session to further augment the randomness used to
// generate nonces.
func WithNonceCombinedKeyAux(combinedKey *btcec.PublicKey) NonceGenOption {
return func(o *nonceGenOpts) {
o.combinedKey = schnorr.SerializePubKey(combinedKey)
}
}
// WithNonceMessageAux allows a caller to optionally specify a message to be
// mixed into the randomness generated to create the nonce.
func WithNonceMessageAux(msg [32]byte) NonceGenOption {
return func(o *nonceGenOpts) {
o.msg = msg[:]
}
}
// WithNonceAuxInput allows a caller to specify a set of auxiliary
// randomness, similar to BIP 340, that can be used to further augment the
// nonce generation process.
func WithNonceAuxInput(aux []byte) NonceGenOption {
return func(o *nonceGenOpts) {
o.auxInput = aux
}
}
// withCustomOptions allows a caller to pass a complete set of custom
// nonceGenOpts, without needing to create custom and checked structs such as
// *btcec.PrivateKey. This is mainly used to match the testcases provided by
// the MuSig2 BIP.
func withCustomOptions(customOpts nonceGenOpts) NonceGenOption {
return func(o *nonceGenOpts) {
o.randReader = customOpts.randReader
o.secretKey = customOpts.secretKey
o.combinedKey = customOpts.combinedKey
o.msg = customOpts.msg
o.auxInput = customOpts.auxInput
}
}
// lengthWriter is a function closure that allows a caller to control how the
// length prefix of a byte slice is written.
type lengthWriter func(w io.Writer, b []byte) error
// uint8Writer is an implementation of lengthWriter that writes the length of
// the byte slice using 1 byte.
func uint8Writer(w io.Writer, b []byte) error {
return binary.Write(w, byteOrder, uint8(len(b)))
}
// uint32Writer is an implementation of lengthWriter that writes the length of
// the byte slice using 4 bytes.
func uint32Writer(w io.Writer, b []byte) error {
return binary.Write(w, byteOrder, uint32(len(b)))
}
// writeBytesPrefix is used to write out: len(b) || b, to the passed io.Writer.
// The lengthWriter function closure is used to allow the caller to specify the
// precise byte packing of the length.
func writeBytesPrefix(w io.Writer, b []byte, lenWriter lengthWriter) error {
// Write out the length of the bytes first, followed by the set of bytes
// itself.
if err := lenWriter(w, b); err != nil {
return err
}
if _, err := w.Write(b); err != nil {
return err
}
return nil
}
// genNonceAuxBytes writes out the full byte string used to derive a secret
// nonce based on some initial randomness as well as the series of optional
// fields. The byte string used for derivation is:
// - tagged_hash("MuSig/nonce", rand || len(aggpk) || aggpk || len(m)
// || m || len(in) || in || i).
//
// where i is the ith secret nonce being generated.
func genNonceAuxBytes(rand []byte, i int,
opts *nonceGenOpts) (*chainhash.Hash, error) {
var w bytes.Buffer
// First, write out the randomness generated in the prior step.
if _, err := w.Write(rand); err != nil {
return nil, err
}
// Next, we'll write out: len(aggpk) || aggpk.
err := writeBytesPrefix(&w, opts.combinedKey, uint8Writer)
if err != nil {
return nil, err
}
// Next, we'll write out the length prefixed message.
err = writeBytesPrefix(&w, opts.msg, uint8Writer)
if err != nil {
return nil, err
}
// Finally we'll write out the auxiliary input.
err = writeBytesPrefix(&w, opts.auxInput, uint32Writer)
if err != nil {
return nil, err
}
// Next we'll write out the iteration/index number which will
// uniquely generate two nonces given the rest of the possibly static
// parameters.
if err := binary.Write(&w, byteOrder, uint8(i)); err != nil {
return nil, err
}
// With the message buffer complete, we'll now derive the tagged hash
// using our set of params.
return chainhash.TaggedHash(NonceGenTag, w.Bytes()), nil
}
// GenNonces generates the secret nonces, as well as the public nonces which
// correspond to an EC point generated using the secret nonce as a private key.
func GenNonces(options ...NonceGenOption) (*Nonces, error) {
opts := defaultNonceGenOpts()
for _, opt := range options {
opt(opts)
}
// First, we'll start out by generating 32 random bytes drawn from our
// CSPRNG.
var randBytes [32]byte
if _, err := opts.randReader.Read(randBytes[:]); err != nil {
return nil, err
}
// If the options contain a secret key, we XOR it with the tagged
// random bytes.
if len(opts.secretKey) == 32 {
taggedHash := chainhash.TaggedHash(NonceAuxTag, randBytes[:])
for i := 0; i < chainhash.HashSize; i++ {
randBytes[i] = opts.secretKey[i] ^ taggedHash[i]
}
}
// Using our randomness and the set of optional params, generate our
// two secret nonces: k1 and k2.
k1, err := genNonceAuxBytes(randBytes[:], 0, opts)
if err != nil {
return nil, err
}
k2, err := genNonceAuxBytes(randBytes[:], 1, opts)
if err != nil {
return nil, err
}
var k1Mod, k2Mod btcec.ModNScalar
k1Mod.SetBytes((*[32]byte)(k1))
k2Mod.SetBytes((*[32]byte)(k2))
// The secret nonces are serialized as the concatenation of the two 32
// byte secret nonce values.
var nonces Nonces
k1Mod.PutBytesUnchecked(nonces.SecNonce[:])
k2Mod.PutBytesUnchecked(nonces.SecNonce[btcec.PrivKeyBytesLen:])
// Next, we'll generate R_1 = k_1*G and R_2 = k_2*G. Along the way we
// need to map our nonce values into mod n scalars so we can work with
// the btcec API.
nonces.PubNonce = secNonceToPubNonce(nonces.SecNonce)
return &nonces, nil
}
// AggregateNonces aggregates the pair of public nonces from each party
// into a single aggregated nonce to be used for multi-signing.
func AggregateNonces(pubNonces [][PubNonceSize]byte) ([PubNonceSize]byte, error) {
// combineNonces is a helper function that aggregates (adds) up a
// series of nonces encoded in compressed format. It uses a slicing
// function to extract 33 bytes at a time from the packed 2x public
// nonces.
type nonceSlicer func([PubNonceSize]byte) []byte
combineNonces := func(slicer nonceSlicer) (btcec.JacobianPoint, error) {
// Convert the set of nonces into jacobian coordinates we can
// use to accumulate them all into each other.
pubNonceJs := make([]*btcec.JacobianPoint, len(pubNonces))
for i, pubNonceBytes := range pubNonces {
// Using the slicer, extract just the bytes we need to
// decode.
var nonceJ btcec.JacobianPoint
nonceJ, err := btcec.ParseJacobian(slicer(pubNonceBytes))
if err != nil {
return btcec.JacobianPoint{}, err
}
pubNonceJs[i] = &nonceJ
}
// Now that we have the set of complete nonces, we'll aggregate
// them: R = R_i + R_i+1 + ... + R_i+n.
var aggregateNonce btcec.JacobianPoint
for _, pubNonceJ := range pubNonceJs {
btcec.AddNonConst(
&aggregateNonce, pubNonceJ, &aggregateNonce,
)
}
aggregateNonce.ToAffine()
return aggregateNonce, nil
}
// The final public nonce is actually two nonces: one that
// aggregates the first nonce of all the parties, and another that
// aggregates the second nonce of all the parties.
var finalNonce [PubNonceSize]byte
combinedNonce1, err := combineNonces(func(n [PubNonceSize]byte) []byte {
return n[:btcec.PubKeyBytesLenCompressed]
})
if err != nil {
return finalNonce, err
}
combinedNonce2, err := combineNonces(func(n [PubNonceSize]byte) []byte {
return n[btcec.PubKeyBytesLenCompressed:]
})
if err != nil {
return finalNonce, err
}
copy(finalNonce[:], btcec.JacobianToByteSlice(combinedNonce1))
copy(
finalNonce[btcec.PubKeyBytesLenCompressed:],
btcec.JacobianToByteSlice(combinedNonce2),
)
return finalNonce, nil
}
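// Illustrative sketch, not part of the original file: generate a local
// nonce pair and aggregate its public half with a remote party's public
// nonce, as would happen during interactive nonce exchange. The function
// name is hypothetical.
func exampleNonceExchange(remoteNonce [PubNonceSize]byte) (*Nonces,
	[PubNonceSize]byte, error) {
	localNonces, err := GenNonces()
	if err != nil {
		return nil, [PubNonceSize]byte{}, err
	}
	combined, err := AggregateNonces([][PubNonceSize]byte{
		localNonces.PubNonce, remoteNonce,
	})
	if err != nil {
		return nil, [PubNonceSize]byte{}, err
	}
	return localNonces, combined, nil
}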
// Copyright 2013-2022 The btcsuite developers
package musig2v040
import (
"bytes"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/chaincfg/chainhash"
secp "github.com/decred/dcrd/dcrec/secp256k1/v4"
)
var (
// NonceBlindTag is the tag used to construct the value b, which
// blinds the second public nonce of each party.
NonceBlindTag = []byte("MuSig/noncecoef")
// ChallengeHashTag is the tag used to construct the challenge hash.
ChallengeHashTag = []byte("BIP0340/challenge")
// ErrNoncePointAtInfinity is returned if during signing, the fully
// combined public nonce is the point at infinity.
ErrNoncePointAtInfinity = fmt.Errorf("signing nonce is the infinity " +
"point")
// ErrPrivKeyZero is returned when the private key for signing is
// actually zero.
ErrPrivKeyZero = fmt.Errorf("priv key is zero")
// ErrPartialSigInvalid is returned when a partial signature is found to be
// invalid.
ErrPartialSigInvalid = fmt.Errorf("partial signature is invalid")
// ErrSecretNonceZero is returned when the secret nonce passed in is
// all zeroes.
ErrSecretNonceZero = fmt.Errorf("secret nonce is blank")
)
// infinityPoint is the jacobian representation of the point at infinity.
var infinityPoint btcec.JacobianPoint
// PartialSignature represents a partial (s-only) musig2 multi-signature. This
// isn't a valid schnorr signature by itself, as it needs to be aggregated
// along with the other partial signatures to be completed.
type PartialSignature struct {
S *btcec.ModNScalar
R *btcec.PublicKey
}
// NewPartialSignature returns a new instance of the partial sig struct.
func NewPartialSignature(s *btcec.ModNScalar,
r *btcec.PublicKey) PartialSignature {
return PartialSignature{
S: s,
R: r,
}
}
// Encode writes a serialized version of the partial signature to the passed
// io.Writer.
func (p *PartialSignature) Encode(w io.Writer) error {
var sBytes [32]byte
p.S.PutBytes(&sBytes)
if _, err := w.Write(sBytes[:]); err != nil {
return err
}
return nil
}
// Decode attempts to parse a serialized PartialSignature stored in the passed
// io reader.
func (p *PartialSignature) Decode(r io.Reader) error {
p.S = new(btcec.ModNScalar)
var sBytes [32]byte
if _, err := io.ReadFull(r, sBytes[:]); err != nil {
return err
}
overflows := p.S.SetBytes(&sBytes)
if overflows == 1 {
return ErrPartialSigInvalid
}
return nil
}
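// Illustrative sketch, not part of the original file: round-trip a
// partial signature through its wire encoding, as a peer receiving the
// 32-byte s value would. The function name is hypothetical.
func examplePartialSigRoundTrip(sig *PartialSignature) (*PartialSignature,
	error) {
	var buf bytes.Buffer
	if err := sig.Encode(&buf); err != nil {
		return nil, err
	}
	var parsed PartialSignature
	if err := parsed.Decode(&buf); err != nil {
		return nil, err
	}
	return &parsed, nil
}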
// SignOption is a functional option argument that allows callers to modify the
// way we generate musig2 schnorr signatures.
type SignOption func(*signOptions)
// signOptions houses the set of functional options that can be used to modify
// the method used to generate the musig2 partial signature.
type signOptions struct {
// fastSign determines if we'll skip the check at the end of the
// routine where we attempt to verify the produced signature.
fastSign bool
// sortKeys determines if the set of keys should be sorted before doing
// key aggregation.
sortKeys bool
// tweaks specifies a series of tweaks to be applied to the aggregated
// public key, which also partially carries over into the signing
// process.
tweaks []KeyTweakDesc
// taprootTweak specifies a taproot specific tweak, used in place of the
// tweaks specified above. Normally we'd just apply the raw 32 byte tweak, but
// for taproot, we first need to compute the aggregated key before
// tweaking, and then use it as the internal key. This is required as
// the taproot tweak also commits to the public key, which in this case
// is the aggregated key before the tweak.
taprootTweak []byte
// bip86Tweak specifies that the taproot tweak should be done in a BIP
// 86 style, where we don't expect an actual tweak and instead just
// commit to the public key itself.
bip86Tweak bool
}
// defaultSignOptions returns the default set of signing options.
func defaultSignOptions() *signOptions {
return &signOptions{}
}
// WithFastSign forces signing to skip the extra verification step at the end.
// Performance sensitive applications may opt to use this option to speed up
// the signing operation.
func WithFastSign() SignOption {
return func(o *signOptions) {
o.fastSign = true
}
}
// WithSortedKeys determines if the set of signing public keys are to be sorted
// or not before doing key aggregation.
func WithSortedKeys() SignOption {
return func(o *signOptions) {
o.sortKeys = true
}
}
// WithTweaks determines if the aggregated public key used should apply a
// series of tweaks before key aggregation.
func WithTweaks(tweaks ...KeyTweakDesc) SignOption {
return func(o *signOptions) {
o.tweaks = tweaks
}
}
// WithTaprootSignTweak allows a caller to specify a tweak that should be used
// in a bip 340 manner when signing. This differs from WithTweaks as the tweak
// will be assumed to always be x-only and the intermediate aggregate key
// before tweaking will be used to generate part of the tweak (as the taproot
// tweak also commits to the internal key).
//
// This option should be used in the taproot context to create a valid
// signature for the keypath spend for taproot, when the output key is actually
// committing to a script path, or some other data.
func WithTaprootSignTweak(scriptRoot []byte) SignOption {
return func(o *signOptions) {
o.taprootTweak = scriptRoot
}
}
// WithBip86SignTweak allows a caller to specify a tweak that should be used in
// a bip 340 manner when signing, factoring in BIP 86 as well. This differs
// from WithTaprootSignTweak as no true script root will be committed to,
// instead we just commit to the internal key.
//
// This option should be used in the taproot context to create a valid
// signature for the keypath spend for taproot, when the output key was
// generated using BIP 86.
func WithBip86SignTweak() SignOption {
return func(o *signOptions) {
o.bip86Tweak = true
}
}
// Sign generates a musig2 partial signature given the passed key set, secret
// nonce, public nonce, and private keys. This method returns an error if the
// generated nonces are either too large, or end up mapping to the point at
// infinity.
func Sign(secNonce [SecNonceSize]byte, privKey *btcec.PrivateKey,
combinedNonce [PubNonceSize]byte, pubKeys []*btcec.PublicKey,
msg [32]byte, signOpts ...SignOption) (*PartialSignature, error) {
// First, parse the set of optional signing options.
opts := defaultSignOptions()
for _, option := range signOpts {
option(opts)
}
// Compute the hash of all the keys here as we'll need it to aggregate
// the keys and also at the final step of signing.
keysHash := keyHashFingerprint(pubKeys, opts.sortKeys)
uniqueKeyIndex := secondUniqueKeyIndex(pubKeys, opts.sortKeys)
keyAggOpts := []KeyAggOption{
WithKeysHash(keysHash), WithUniqueKeyIndex(uniqueKeyIndex),
}
switch {
case opts.bip86Tweak:
keyAggOpts = append(
keyAggOpts, WithBIP86KeyTweak(),
)
case opts.taprootTweak != nil:
keyAggOpts = append(
keyAggOpts, WithTaprootKeyTweak(opts.taprootTweak),
)
case len(opts.tweaks) != 0:
keyAggOpts = append(keyAggOpts, WithKeyTweaks(opts.tweaks...))
}
// Next we'll construct the aggregated public key based on the set of
// signers.
combinedKey, parityAcc, _, err := AggregateKeys(
pubKeys, opts.sortKeys, keyAggOpts...,
)
if err != nil {
return nil, err
}
// Next we'll compute the value b, that blinds our second public
// nonce:
// * b = h(tag=NonceBlindTag, combinedNonce || combinedKey || m).
var (
nonceMsgBuf bytes.Buffer
nonceBlinder btcec.ModNScalar
)
nonceMsgBuf.Write(combinedNonce[:])
nonceMsgBuf.Write(schnorr.SerializePubKey(combinedKey.FinalKey))
nonceMsgBuf.Write(msg[:])
nonceBlindHash := chainhash.TaggedHash(
NonceBlindTag, nonceMsgBuf.Bytes(),
)
nonceBlinder.SetByteSlice(nonceBlindHash[:])
// Next, we'll parse the public nonces into R1 and R2.
r1J, err := btcec.ParseJacobian(
combinedNonce[:btcec.PubKeyBytesLenCompressed],
)
if err != nil {
return nil, err
}
r2J, err := btcec.ParseJacobian(
combinedNonce[btcec.PubKeyBytesLenCompressed:],
)
if err != nil {
return nil, err
}
// With our nonce blinding value, we'll now combine both the public
// nonces, using the blinding factor to tweak the second nonce:
// * R = R_1 + b*R_2
var nonce btcec.JacobianPoint
btcec.ScalarMultNonConst(&nonceBlinder, &r2J, &r2J)
btcec.AddNonConst(&r1J, &r2J, &nonce)
// If the combined nonce is the point at infinity, then we'll use the
// generator point instead.
if nonce == infinityPoint {
G := btcec.Generator()
G.AsJacobian(&nonce)
}
// Next we'll parse out our two secret nonces, which we'll be using in
// the core signing process below.
var k1, k2 btcec.ModNScalar
k1.SetByteSlice(secNonce[:btcec.PrivKeyBytesLen])
k2.SetByteSlice(secNonce[btcec.PrivKeyBytesLen:])
if k1.IsZero() || k2.IsZero() {
return nil, ErrSecretNonceZero
}
nonce.ToAffine()
nonceKey := btcec.NewPublicKey(&nonce.X, &nonce.Y)
// If the nonce R has an odd y coordinate, then we'll negate both our
// secret nonces.
if nonce.Y.IsOdd() {
k1.Negate()
k2.Negate()
}
privKeyScalar := privKey.Key
if privKeyScalar.IsZero() {
return nil, ErrPrivKeyZero
}
pubKey := privKey.PubKey()
pubKeyYIsOdd := func() bool {
pubKeyBytes := pubKey.SerializeCompressed()
return pubKeyBytes[0] == secp.PubKeyFormatCompressedOdd
}()
combinedKeyYIsOdd := func() bool {
combinedKeyBytes := combinedKey.FinalKey.SerializeCompressed()
return combinedKeyBytes[0] == secp.PubKeyFormatCompressedOdd
}()
// Next we'll compute our two parity factors for Q, the combined public
// key, and P, the public key we're signing with. If the keys are odd,
// then we'll negate them.
parityCombinedKey := new(btcec.ModNScalar).SetInt(1)
paritySignKey := new(btcec.ModNScalar).SetInt(1)
if combinedKeyYIsOdd {
parityCombinedKey.Negate()
}
if pubKeyYIsOdd {
paritySignKey.Negate()
}
// Before we sign below, we'll multiply by our various parity factors
// to ensure that the signing key is properly negated (if necessary):
// * d = g⋅g_acc⋅g_p⋅d'
privKeyScalar.Mul(parityCombinedKey).Mul(paritySignKey).Mul(parityAcc)
// Next we'll create the challenge hash that commits to the combined
// nonce, combined public key and also the message:
// * e = H(tag=ChallengeHashTag, R || Q || m) mod n
var challengeMsg bytes.Buffer
challengeMsg.Write(schnorr.SerializePubKey(nonceKey))
challengeMsg.Write(schnorr.SerializePubKey(combinedKey.FinalKey))
challengeMsg.Write(msg[:])
challengeBytes := chainhash.TaggedHash(
ChallengeHashTag, challengeMsg.Bytes(),
)
var e btcec.ModNScalar
e.SetByteSlice(challengeBytes[:])
// Next, we'll compute a, our aggregation coefficient for the key that
// we're signing with.
a := aggregationCoefficient(pubKeys, pubKey, keysHash, uniqueKeyIndex)
// With the coefficient a constructed, we can finally generate our
// partial signature as: s = (k_1 + b*k_2 + e*a*d) mod n.
s := new(btcec.ModNScalar)
s.Add(&k1).Add(k2.Mul(&nonceBlinder)).Add(e.Mul(a).Mul(&privKeyScalar))
sig := NewPartialSignature(s, nonceKey)
// If we're not in fast sign mode, then we'll also validate our partial
// signature.
if !opts.fastSign {
pubNonce := secNonceToPubNonce(secNonce)
sigValid := sig.Verify(
pubNonce, combinedNonce, pubKeys, pubKey, msg,
signOpts...,
)
if !sigValid {
return nil, fmt.Errorf("sig is invalid!")
}
}
return &sig, nil
}
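// exampleTwoPartySign is an illustrative sketch (not part of the original
// source) of the full two-party signing flow built from the primitives above.
// It assumes the GenNonces and AggregateNonces helpers defined elsewhere in
// this package (nonces.go).
func exampleTwoPartySign(priv1, priv2 *btcec.PrivateKey,
	msg [32]byte) (*schnorr.Signature, error) {

	pubKeys := []*btcec.PublicKey{priv1.PubKey(), priv2.PubKey()}

	// Each signer generates a fresh secret/public nonce pair. Nonces must
	// never be reused across signing sessions.
	nonces1, err := GenNonces()
	if err != nil {
		return nil, err
	}
	nonces2, err := GenNonces()
	if err != nil {
		return nil, err
	}

	// Both public nonces are exchanged and aggregated into the combined
	// nonce that each signer will sign against.
	combinedNonce, err := AggregateNonces([][PubNonceSize]byte{
		nonces1.PubNonce, nonces2.PubNonce,
	})
	if err != nil {
		return nil, err
	}

	// Each party now produces its partial signature over the message,
	// using the same sorted key ordering.
	s1, err := Sign(
		nonces1.SecNonce, priv1, combinedNonce, pubKeys, msg,
		WithSortedKeys(),
	)
	if err != nil {
		return nil, err
	}
	s2, err := Sign(
		nonces2.SecNonce, priv2, combinedNonce, pubKeys, msg,
		WithSortedKeys(),
	)
	if err != nil {
		return nil, err
	}

	// Finally, the partial signatures are summed into a complete schnorr
	// signature that verifies under the aggregated public key. The R field
	// of any partial signature is the final combined nonce point.
	return CombineSigs(s1.R, []*PartialSignature{s1, s2}), nil
}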
// Verify implements partial signature verification given the public nonce for
// the signer, aggregate nonce, signer set and finally the message being
// signed.
func (p *PartialSignature) Verify(pubNonce [PubNonceSize]byte,
combinedNonce [PubNonceSize]byte, keySet []*btcec.PublicKey,
signingKey *btcec.PublicKey, msg [32]byte, signOpts ...SignOption) bool {
pubKey := schnorr.SerializePubKey(signingKey)
return verifyPartialSig(
p, pubNonce, combinedNonce, keySet, pubKey, msg, signOpts...,
) == nil
}
// verifyPartialSig attempts to verify a partial schnorr signature given the
// necessary parameters. This is the internal version of Verify that returns
// detailed errors.
func verifyPartialSig(partialSig *PartialSignature, pubNonce [PubNonceSize]byte,
combinedNonce [PubNonceSize]byte, keySet []*btcec.PublicKey,
pubKey []byte, msg [32]byte, signOpts ...SignOption) error {
opts := defaultSignOptions()
for _, option := range signOpts {
option(opts)
}
// First we'll map the internal partial signature back into something
// we can manipulate.
s := partialSig.S
// Compute the hash of all the keys here as we'll need it to aggregate
// the keys and also at the final step of verification.
keysHash := keyHashFingerprint(keySet, opts.sortKeys)
uniqueKeyIndex := secondUniqueKeyIndex(keySet, opts.sortKeys)
keyAggOpts := []KeyAggOption{
WithKeysHash(keysHash), WithUniqueKeyIndex(uniqueKeyIndex),
}
switch {
case opts.bip86Tweak:
keyAggOpts = append(
keyAggOpts, WithBIP86KeyTweak(),
)
case opts.taprootTweak != nil:
keyAggOpts = append(
keyAggOpts, WithTaprootKeyTweak(opts.taprootTweak),
)
case len(opts.tweaks) != 0:
keyAggOpts = append(keyAggOpts, WithKeyTweaks(opts.tweaks...))
}
// Next we'll construct the aggregated public key based on the set of
// signers.
combinedKey, parityAcc, _, err := AggregateKeys(
keySet, opts.sortKeys, keyAggOpts...,
)
if err != nil {
return err
}
// Next we'll compute the value b, that blinds our second public
// nonce:
// * b = h(tag=NonceBlindTag, combinedNonce || combinedKey || m).
var (
nonceMsgBuf bytes.Buffer
nonceBlinder btcec.ModNScalar
)
nonceMsgBuf.Write(combinedNonce[:])
nonceMsgBuf.Write(schnorr.SerializePubKey(combinedKey.FinalKey))
nonceMsgBuf.Write(msg[:])
nonceBlindHash := chainhash.TaggedHash(NonceBlindTag, nonceMsgBuf.Bytes())
nonceBlinder.SetByteSlice(nonceBlindHash[:])
r1J, err := btcec.ParseJacobian(
combinedNonce[:btcec.PubKeyBytesLenCompressed],
)
if err != nil {
return err
}
r2J, err := btcec.ParseJacobian(
combinedNonce[btcec.PubKeyBytesLenCompressed:],
)
if err != nil {
return err
}
// With our nonce blinding value, we'll now combine both the public
// nonces, using the blinding factor to tweak the second nonce:
// * R = R_1 + b*R_2
var nonce btcec.JacobianPoint
btcec.ScalarMultNonConst(&nonceBlinder, &r2J, &r2J)
btcec.AddNonConst(&r1J, &r2J, &nonce)
// Next, we'll parse out the set of public nonces this signer used to
// generate the signature.
pubNonce1J, err := btcec.ParseJacobian(
pubNonce[:btcec.PubKeyBytesLenCompressed],
)
if err != nil {
return err
}
pubNonce2J, err := btcec.ParseJacobian(
pubNonce[btcec.PubKeyBytesLenCompressed:],
)
if err != nil {
return err
}
// If the nonce is the infinity point we set it to the Generator.
if nonce == infinityPoint {
btcec.GeneratorJacobian(&nonce)
} else {
nonce.ToAffine()
}
// We'll perform a similar aggregation and blinding operation as we did
// above for the combined nonces: R' = R_1' + b*R_2'.
var pubNonceJ btcec.JacobianPoint
btcec.ScalarMultNonConst(&nonceBlinder, &pubNonce2J, &pubNonce2J)
btcec.AddNonConst(&pubNonce1J, &pubNonce2J, &pubNonceJ)
pubNonceJ.ToAffine()
// If the combined nonce used in the challenge hash has an odd y
// coordinate, then we'll negate our final public nonce.
if nonce.Y.IsOdd() {
pubNonceJ.Y.Negate(1)
pubNonceJ.Y.Normalize()
}
// Next we'll create the challenge hash that commits to the combined
// nonce, combined public key and also the message:
// * e = H(tag=ChallengeHashTag, R || Q || m) mod n
var challengeMsg bytes.Buffer
challengeMsg.Write(schnorr.SerializePubKey(btcec.NewPublicKey(
&nonce.X, &nonce.Y,
)))
challengeMsg.Write(schnorr.SerializePubKey(combinedKey.FinalKey))
challengeMsg.Write(msg[:])
challengeBytes := chainhash.TaggedHash(
ChallengeHashTag, challengeMsg.Bytes(),
)
var e btcec.ModNScalar
e.SetByteSlice(challengeBytes[:])
signingKey, err := schnorr.ParsePubKey(pubKey)
if err != nil {
return err
}
// Next, we'll compute a, our aggregation coefficient for the key that
// we're signing with.
a := aggregationCoefficient(keySet, signingKey, keysHash, uniqueKeyIndex)
// If the combined key has an odd y coordinate, then we'll negate
// the parity factor for the signing key.
paritySignKey := new(btcec.ModNScalar).SetInt(1)
combinedKeyBytes := combinedKey.FinalKey.SerializeCompressed()
if combinedKeyBytes[0] == secp.PubKeyFormatCompressedOdd {
paritySignKey.Negate()
}
// Next, we'll construct the final parity factor by multiplying the
// sign key parity factor with the accumulated parity factor for all
// the keys.
finalParityFactor := paritySignKey.Mul(parityAcc)
// Now we'll multiply the parity factor by our signing key, which'll
// take care of the amount of negation needed.
var signKeyJ btcec.JacobianPoint
signingKey.AsJacobian(&signKeyJ)
btcec.ScalarMultNonConst(finalParityFactor, &signKeyJ, &signKeyJ)
// In the final step, we'll check that: s*G == R' + e*a*P.
var sG, rP btcec.JacobianPoint
btcec.ScalarBaseMultNonConst(s, &sG)
btcec.ScalarMultNonConst(e.Mul(a), &signKeyJ, &rP)
btcec.AddNonConst(&rP, &pubNonceJ, &rP)
sG.ToAffine()
rP.ToAffine()
if sG != rP {
return ErrPartialSigInvalid
}
return nil
}
// CombineOption is a functional option argument that allows callers to modify the
// way we combine musig2 schnorr signatures.
type CombineOption func(*combineOptions)
// combineOptions houses the set of functional options that can be used to
// modify the method used to combine the musig2 partial signatures.
type combineOptions struct {
msg [32]byte
combinedKey *btcec.PublicKey
tweakAcc *btcec.ModNScalar
}
// defaultCombineOptions returns the default set of combine options.
func defaultCombineOptions() *combineOptions {
return &combineOptions{}
}
// WithTweakedCombine is a functional option that allows callers to specify
// that the signature was produced using a tweaked aggregated public key. In
// order to properly aggregate the partial signatures, the caller must specify
// enough information to reconstruct the challenge, and also the final
// accumulated tweak value.
func WithTweakedCombine(msg [32]byte, keys []*btcec.PublicKey,
tweaks []KeyTweakDesc, sort bool) CombineOption {
return func(o *combineOptions) {
combinedKey, _, tweakAcc, _ := AggregateKeys(
keys, sort, WithKeyTweaks(tweaks...),
)
o.msg = msg
o.combinedKey = combinedKey.FinalKey
o.tweakAcc = tweakAcc
}
}
// WithTaprootTweakedCombine is similar to the WithTweakedCombine option, but
// assumes a BIP 341 context where the final tweaked key is to be used as the
// output key, where the internal key is the aggregated key pre-tweak.
//
// This option should be used over WithTweakedCombine when attempting to
// aggregate signatures for a top-level taproot keyspend, where the output key
// commits to a script root.
func WithTaprootTweakedCombine(msg [32]byte, keys []*btcec.PublicKey,
scriptRoot []byte, sort bool) CombineOption {
return func(o *combineOptions) {
combinedKey, _, tweakAcc, _ := AggregateKeys(
keys, sort, WithTaprootKeyTweak(scriptRoot),
)
o.msg = msg
o.combinedKey = combinedKey.FinalKey
o.tweakAcc = tweakAcc
}
}
// WithBip86TweakedCombine is similar to the WithTaprootTweakedCombine option,
// but assumes a BIP 341 + BIP 86 context where the final tweaked key is to be
// used as the output key, where the internal key is the aggregated key
// pre-tweak.
//
// This option should be used over WithTaprootTweakedCombine when attempting to
// aggregate signatures for a top-level taproot keyspend, where the output key
// was generated using BIP 86.
func WithBip86TweakedCombine(msg [32]byte, keys []*btcec.PublicKey,
sort bool) CombineOption {
return func(o *combineOptions) {
combinedKey, _, tweakAcc, _ := AggregateKeys(
keys, sort, WithBIP86KeyTweak(),
)
o.msg = msg
o.combinedKey = combinedKey.FinalKey
o.tweakAcc = tweakAcc
}
}
// CombineSigs combines the set of public keys given the final aggregated
// nonce, and the series of partial signatures for each nonce.
func CombineSigs(combinedNonce *btcec.PublicKey,
partialSigs []*PartialSignature,
combineOpts ...CombineOption) *schnorr.Signature {
// First, parse the set of optional combine options.
opts := defaultCombineOptions()
for _, option := range combineOpts {
option(opts)
}
// If signer keys and tweaks are specified, then we need to carry out
// some intermediate steps before we can combine the signature.
var tweakProduct *btcec.ModNScalar
if opts.combinedKey != nil && opts.tweakAcc != nil {
// Next, we'll construct the parity factor of the combined key,
// negating it if the combined key has an odd y coordinate.
parityFactor := new(btcec.ModNScalar).SetInt(1)
combinedKeyBytes := opts.combinedKey.SerializeCompressed()
if combinedKeyBytes[0] == secp.PubKeyFormatCompressedOdd {
parityFactor.Negate()
}
// Next we'll reconstruct e, the challenge hash, based on the
// nonce and combined public key.
// * e = H(tag=ChallengeHashTag, R || Q || m) mod n
var challengeMsg bytes.Buffer
challengeMsg.Write(schnorr.SerializePubKey(combinedNonce))
challengeMsg.Write(schnorr.SerializePubKey(opts.combinedKey))
challengeMsg.Write(opts.msg[:])
challengeBytes := chainhash.TaggedHash(
ChallengeHashTag, challengeMsg.Bytes(),
)
var e btcec.ModNScalar
e.SetByteSlice(challengeBytes[:])
tweakProduct = new(btcec.ModNScalar).Set(&e)
tweakProduct.Mul(opts.tweakAcc).Mul(parityFactor)
}
// Now we'll sum all of the partial signatures to produce the combined
// s value.
var combinedSig btcec.ModNScalar
for _, partialSig := range partialSigs {
combinedSig.Add(partialSig.S)
}
// If the tweak product was set above, then we'll need to add the value
// at the very end in order to produce a valid signature under the
// final tweaked key.
if tweakProduct != nil {
combinedSig.Add(tweakProduct)
}
// TODO(roasbeef): less verbose way to get the x coord...
var nonceJ btcec.JacobianPoint
combinedNonce.AsJacobian(&nonceJ)
nonceJ.ToAffine()
return schnorr.NewSignature(&nonceJ.X, &combinedSig)
}
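// Illustrative note (not part of the original source): if the partial
// signatures were produced with a taproot tweak, the same tweak information
// must be replayed at combine time so the summed s value validates under the
// tweaked output key:
//
//	finalSig := CombineSigs(
//		partialSigs[0].R, partialSigs,
//		WithTaprootTweakedCombine(msg, pubKeys, scriptRoot, true),
//	)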
package invoices
import (
"errors"
"fmt"
)
var (
// ErrInvoiceAlreadySettled is returned when the invoice is already
// settled.
ErrInvoiceAlreadySettled = errors.New("invoice already settled")
// ErrInvoiceAlreadyCanceled is returned when the invoice is already
// canceled.
ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled")
// ErrInvoiceAlreadyAccepted is returned when the invoice is already
// accepted.
ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted")
// ErrInvoiceStillOpen is returned when the invoice is still open.
ErrInvoiceStillOpen = errors.New("invoice still open")
// ErrInvoiceCannotOpen is returned when an attempt is made to move an
// invoice to the open state.
ErrInvoiceCannotOpen = errors.New("cannot move invoice to open")
// ErrInvoiceCannotAccept is returned when an attempt is made to accept
// an invoice while the invoice is not in the open state.
ErrInvoiceCannotAccept = errors.New("cannot accept invoice")
// ErrInvoicePreimageMismatch is returned when the preimage doesn't
// match the invoice hash.
ErrInvoicePreimageMismatch = errors.New("preimage does not match")
// ErrHTLCPreimageMissing is returned when trying to accept/settle an
// AMP HTLC but the HTLC-level preimage has not been set.
ErrHTLCPreimageMissing = errors.New("AMP htlc missing preimage")
// ErrHTLCPreimageMismatch is returned when trying to accept/settle an
// AMP HTLC but the HTLC-level preimage does not satisfy the
// HTLC-level payment hash.
ErrHTLCPreimageMismatch = errors.New("htlc preimage mismatch")
// ErrHTLCAlreadySettled is returned when trying to settle an invoice
// but an HTLC already exists in the settled state.
ErrHTLCAlreadySettled = errors.New("htlc already settled")
// ErrInvoiceHasHtlcs is returned when attempting to insert an invoice
// that already has HTLCs.
ErrInvoiceHasHtlcs = errors.New("cannot add invoice with htlcs")
// ErrEmptyHTLCSet is returned when attempting to accept or settle an
// HTLC set that has no HTLCs.
ErrEmptyHTLCSet = errors.New("cannot settle/accept empty HTLC set")
// ErrUnexpectedInvoicePreimage is returned when an invoice-level
// preimage is provided when trying to settle an invoice that shouldn't
// have one, e.g. an AMP invoice.
ErrUnexpectedInvoicePreimage = errors.New(
"unexpected invoice preimage provided on settle",
)
// ErrHTLCPreimageAlreadyExists is returned when trying to set an
// htlc-level preimage but one is already known.
ErrHTLCPreimageAlreadyExists = errors.New(
"htlc-level preimage already exists",
)
// ErrInvoiceNotFound is returned when a targeted invoice can't be
// found.
ErrInvoiceNotFound = errors.New("unable to locate invoice")
// ErrNoInvoicesCreated is returned when we don't have invoices in
// our database to return.
ErrNoInvoicesCreated = errors.New("there are no existing invoices")
// ErrDuplicateInvoice is returned when an invoice with the target
// payment hash already exists.
ErrDuplicateInvoice = errors.New(
"invoice with payment hash already exists",
)
// ErrDuplicatePayAddr is returned when an invoice with the target
// payment addr already exists.
ErrDuplicatePayAddr = errors.New(
"invoice with payemnt addr already exists",
)
// ErrInvRefEquivocation is returned when an InvoiceRef targets
// multiple, distinct invoices.
ErrInvRefEquivocation = errors.New("inv ref matches multiple invoices")
// ErrNoPaymentsCreated is returned when the payments bucket hasn't
// been created.
ErrNoPaymentsCreated = errors.New("there are no existing payments")
)
// ErrDuplicateSetID is an error returned when attempting to add an AMP HTLC
// to an invoice, but another invoice is already indexed by the same set id.
type ErrDuplicateSetID struct {
SetID [32]byte
}
// Error returns a human-readable description of ErrDuplicateSetID.
func (e ErrDuplicateSetID) Error() string {
return fmt.Sprintf("invoice with set_id=%x already exists", e.SetID)
}
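// Illustrative note (not part of the original source): because
// ErrDuplicateSetID is a struct type rather than a sentinel value, callers
// match it with errors.As instead of errors.Is:
//
//	var dupErr ErrDuplicateSetID
//	if errors.As(err, &dupErr) {
//		fmt.Printf("duplicate set id: %x\n", dupErr.SetID)
//	}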
package invoices
import (
"errors"
"fmt"
"sync"
"time"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/zpay32"
)
// invoiceExpiry is a vanity interface for different invoice expiry types
// which implement the priority queue item interface, used to improve code
// readability.
type invoiceExpiry queue.PriorityQueueItem
// Compile time assertion that invoiceExpiryTs implements invoiceExpiry.
var _ invoiceExpiry = (*invoiceExpiryTs)(nil)
// invoiceExpiryTs holds an invoice's payment hash and its expiry. This
// is used to order invoices by their expiry time for cancellation.
type invoiceExpiryTs struct {
PaymentHash lntypes.Hash
Expiry time.Time
Keysend bool
}
// Less implements PriorityQueueItem.Less such that the top item in the
// priority queue will be the one that expires next.
func (e invoiceExpiryTs) Less(other queue.PriorityQueueItem) bool {
return e.Expiry.Before(other.(*invoiceExpiryTs).Expiry)
}
// Compile time assertion that invoiceExpiryHeight implements invoiceExpiry.
var _ invoiceExpiry = (*invoiceExpiryHeight)(nil)
// invoiceExpiryHeight holds information about an invoice which can be used to
// cancel it based on its expiry height.
type invoiceExpiryHeight struct {
paymentHash lntypes.Hash
expiryHeight uint32
}
// Less implements PriorityQueueItem.Less such that the top item in the
// priority queue is the lowest block height.
func (b invoiceExpiryHeight) Less(other queue.PriorityQueueItem) bool {
return b.expiryHeight < other.(*invoiceExpiryHeight).expiryHeight
}
// expired returns a boolean that indicates whether this entry has expired,
// taking our expiry delta into account.
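// For example, with currentHeight=100 and delta=3, an entry whose
// expiryHeight is 103 or lower is considered expired, since 100+3 >= 103.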
func (b invoiceExpiryHeight) expired(currentHeight, delta uint32) bool {
return currentHeight+delta >= b.expiryHeight
}
// InvoiceExpiryWatcher handles automatic invoice cancellation of expired
// invoices. Upon start, InvoiceExpiryWatcher will add all pending (not yet
// settled or canceled) invoices to its watching queue.
// invoice is added to the InvoiceRegistry, it'll be forwarded to the
// InvoiceExpiryWatcher and will end up in the watching queue as well.
// If any of the watched invoices expire, they'll be removed from the watching
// queue and will be cancelled through InvoiceRegistry.CancelInvoice().
type InvoiceExpiryWatcher struct {
sync.Mutex
started bool
// clock is the clock implementation that InvoiceExpiryWatcher uses.
// It is useful for testing.
clock clock.Clock
// notifier provides us with block height updates.
notifier chainntnfs.ChainNotifier
// blockExpiryDelta is the number of blocks before an htlc's expiry at
// which we expire the invoice based on expiry height. We use a delta
// because we will go to chain at some delta before the htlc's expiry,
// so we want to cancel before that point to prevent force closes.
blockExpiryDelta uint32
// currentHeight is the current block height.
currentHeight uint32
// currentHash is the block hash for our current height.
currentHash *chainhash.Hash
// cancelInvoice is a template method that cancels an expired invoice.
cancelInvoice func(lntypes.Hash, bool) error
// timestampExpiryQueue holds invoiceExpiry items and is used to find
// the next invoice to expire.
timestampExpiryQueue queue.PriorityQueue
// blockExpiryQueue holds blockExpiry items and is used to find the
// next invoice to expire based on block height. Only hodl invoices
// with active htlcs are added to this queue, because they require
// manual cancellation when the htlc is going to time out. Items in
// this queue may already be in the timestampExpiryQueue, which is ok
// because they will not be expired based on timestamp if they have
// active htlcs.
blockExpiryQueue queue.PriorityQueue
// newInvoices channel is used to wake up the main loop when new
// invoices are added.
newInvoices chan []invoiceExpiry
wg sync.WaitGroup
// quit signals InvoiceExpiryWatcher to stop.
quit chan struct{}
}
// NewInvoiceExpiryWatcher creates a new InvoiceExpiryWatcher instance.
func NewInvoiceExpiryWatcher(clock clock.Clock,
expiryDelta, startHeight uint32, startHash *chainhash.Hash,
notifier chainntnfs.ChainNotifier) *InvoiceExpiryWatcher {
return &InvoiceExpiryWatcher{
clock: clock,
notifier: notifier,
blockExpiryDelta: expiryDelta,
currentHeight: startHeight,
currentHash: startHash,
newInvoices: make(chan []invoiceExpiry),
quit: make(chan struct{}),
}
}
// Start starts the subscription handler and the main loop. Start() will
// return with error if InvoiceExpiryWatcher is already started. Start()
// expects a cancellation function that will be used to cancel expired
// invoices by their payment hash.
func (ew *InvoiceExpiryWatcher) Start(
cancelInvoice func(lntypes.Hash, bool) error) error {
ew.Lock()
defer ew.Unlock()
if ew.started {
return fmt.Errorf("InvoiceExpiryWatcher already started")
}
ew.started = true
ew.cancelInvoice = cancelInvoice
ntfn, err := ew.notifier.RegisterBlockEpochNtfn(&chainntnfs.BlockEpoch{
Height: int32(ew.currentHeight),
Hash: ew.currentHash,
})
if err != nil {
return err
}
ew.wg.Add(1)
go ew.mainLoop(ntfn)
return nil
}
// Stop quits the expiry handler loop and waits for InvoiceExpiryWatcher to
// fully stop.
func (ew *InvoiceExpiryWatcher) Stop() {
ew.Lock()
defer ew.Unlock()
if ew.started {
// Signal subscriptionHandler to quit and wait for it to return.
close(ew.quit)
ew.wg.Wait()
ew.started = false
}
}
// makeInvoiceExpiry checks if the passed invoice may be canceled, calculates
// its expiry time, and creates a slimmer invoiceExpiry implementation.
func makeInvoiceExpiry(paymentHash lntypes.Hash,
invoice *Invoice) invoiceExpiry {
switch invoice.State {
// If we have an open invoice with no htlcs, we want to expire the
// invoice based on timestamp.
case ContractOpen:
return makeTimestampExpiry(paymentHash, invoice)
// If an invoice has active htlcs, we want to expire it based on block
// height. We only do this for hodl invoices, since regular invoices
// should resolve themselves automatically.
case ContractAccepted:
if !invoice.HodlInvoice {
log.Debugf("Invoice in accepted state not added to "+
"expiry watcher: %v", paymentHash)
return nil
}
var minHeight uint32
for _, htlc := range invoice.Htlcs {
// We only care about accepted htlcs, since they will
// trigger force-closes.
if htlc.State != HtlcStateAccepted {
continue
}
if minHeight == 0 || htlc.Expiry < minHeight {
minHeight = htlc.Expiry
}
}
return makeHeightExpiry(paymentHash, minHeight)
default:
log.Debugf("Invoice not added to expiry watcher: %v",
paymentHash)
return nil
}
}
// makeTimestampExpiry creates a timestamp-based expiry entry.
func makeTimestampExpiry(paymentHash lntypes.Hash,
invoice *Invoice) *invoiceExpiryTs {
if invoice.State != ContractOpen {
return nil
}
realExpiry := invoice.Terms.Expiry
if realExpiry == 0 {
realExpiry = zpay32.DefaultInvoiceExpiry
}
expiry := invoice.CreationDate.Add(realExpiry)
return &invoiceExpiryTs{
PaymentHash: paymentHash,
Expiry: expiry,
Keysend: len(invoice.PaymentRequest) == 0,
}
}
// makeHeightExpiry creates height-based expiry for an invoice based on its
// lowest htlc expiry height.
func makeHeightExpiry(paymentHash lntypes.Hash,
minHeight uint32) *invoiceExpiryHeight {
if minHeight == 0 {
log.Warnf("make height expiry called with 0 height")
return nil
}
return &invoiceExpiryHeight{
paymentHash: paymentHash,
expiryHeight: minHeight,
}
}
// AddInvoices adds invoices to the InvoiceExpiryWatcher.
func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...invoiceExpiry) {
if len(invoices) == 0 {
return
}
select {
case ew.newInvoices <- invoices:
log.Debugf("Added %d invoices to the expiry watcher",
len(invoices))
// Select on quit too so that callers won't get blocked in case
// of concurrent shutdown.
case <-ew.quit:
}
}
// nextTimestampExpiry returns a Time chan to wait on until the next invoice
// expires. If there are no active invoices, then it'll simply wait
// indefinitely.
func (ew *InvoiceExpiryWatcher) nextTimestampExpiry() <-chan time.Time {
if !ew.timestampExpiryQueue.Empty() {
top := ew.timestampExpiryQueue.Top().(*invoiceExpiryTs)
return ew.clock.TickAfter(top.Expiry.Sub(ew.clock.Now()))
}
return nil
}
// nextHeightExpiry returns a channel that will immediately be read from if
// the top item on our queue has expired.
func (ew *InvoiceExpiryWatcher) nextHeightExpiry() <-chan uint32 {
if ew.blockExpiryQueue.Empty() {
return nil
}
top := ew.blockExpiryQueue.Top().(*invoiceExpiryHeight)
if !top.expired(ew.currentHeight, ew.blockExpiryDelta) {
return nil
}
blockChan := make(chan uint32, 1)
blockChan <- top.expiryHeight
return blockChan
}
// cancelNextExpiredInvoice will cancel the next expired invoice and removes
// it from the expiry queue.
func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() {
if !ew.timestampExpiryQueue.Empty() {
top := ew.timestampExpiryQueue.Top().(*invoiceExpiryTs)
if !top.Expiry.Before(ew.clock.Now()) {
return
}
// Don't force-cancel already accepted invoices. An exception to
// this are auto-generated keysend invoices. Because those move
// to the Accepted state directly after being opened, the expiry
// field would never be used. Enabling cancellation for accepted
// keysend invoices creates a safety mechanism that can prevent
// channel force-closes.
ew.expireInvoice(top.PaymentHash, top.Keysend)
ew.timestampExpiryQueue.Pop()
}
}
// cancelNextHeightExpiredInvoice looks at our height based queue and expires
// the next invoice if we have reached its expiry block.
func (ew *InvoiceExpiryWatcher) cancelNextHeightExpiredInvoice() {
if ew.blockExpiryQueue.Empty() {
return
}
top := ew.blockExpiryQueue.Top().(*invoiceExpiryHeight)
if !top.expired(ew.currentHeight, ew.blockExpiryDelta) {
return
}
// We always force-cancel block-based expiry so that we can
// cancel invoices that have been accepted but not yet resolved.
// This helps us avoid force closes.
ew.expireInvoice(top.paymentHash, true)
ew.blockExpiryQueue.Pop()
}
// expireInvoice attempts to expire an invoice and logs an error if we get an
// unexpected error.
func (ew *InvoiceExpiryWatcher) expireInvoice(hash lntypes.Hash, force bool) {
err := ew.cancelInvoice(hash, force)
switch {
case err == nil:
case errors.Is(err, ErrInvoiceAlreadyCanceled):
case errors.Is(err, ErrInvoiceAlreadySettled):
case errors.Is(err, ErrInvoiceNotFound):
// It's possible that the user has manually canceled the invoice
// which will then be deleted by the garbage collector resulting
// in an ErrInvoiceNotFound error.
default:
log.Errorf("Unable to cancel invoice: %v: %v", hash, err)
}
}
// pushInvoices adds invoices to be expired to their relevant queue.
func (ew *InvoiceExpiryWatcher) pushInvoices(invoices []invoiceExpiry) {
for _, inv := range invoices {
// Switch on the type of entry we have. We need to check nil
// on the implementation of the interface because the interface
// itself is non-nil.
switch expiry := inv.(type) {
case *invoiceExpiryTs:
if expiry != nil {
ew.timestampExpiryQueue.Push(expiry)
}
case *invoiceExpiryHeight:
if expiry != nil {
ew.blockExpiryQueue.Push(expiry)
}
default:
log.Errorf("unexpected queue item: %T", inv)
}
}
}
// mainLoop is a goroutine that receives new invoices and handles cancellation
// of expired invoices.
func (ew *InvoiceExpiryWatcher) mainLoop(blockNtfns *chainntnfs.BlockEpochEvent) {
defer func() {
blockNtfns.Cancel()
ew.wg.Done()
}()
// We have two different queues, so we use a different cancel method
// depending on which expiry condition we have hit. Starting with
// time-based expiry is an arbitrary choice.
cancelNext := ew.cancelNextExpiredInvoice
for {
// Cancel any invoices that may have expired.
cancelNext()
select {
case newInvoices := <-ew.newInvoices:
// Take newly forwarded invoices with higher priority
// in order to not block the newInvoices channel.
ew.pushInvoices(newInvoices)
continue
default:
select {
// Wait until the next invoice expires.
case <-ew.nextTimestampExpiry():
cancelNext = ew.cancelNextExpiredInvoice
continue
case <-ew.nextHeightExpiry():
cancelNext = ew.cancelNextHeightExpiredInvoice
continue
case newInvoices := <-ew.newInvoices:
ew.pushInvoices(newInvoices)
// Consume new blocks.
case block, ok := <-blockNtfns.Epochs:
if !ok {
log.Debugf("block notifications " +
"canceled")
return
}
ew.currentHeight = uint32(block.Height)
ew.currentHash = block.Hash
case <-ew.quit:
return
}
}
}
}
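// Illustrative wiring sketch (not part of the original source), assuming a
// chainntnfs.ChainNotifier and a cancellation callback are available:
//
//	watcher := NewInvoiceExpiryWatcher(
//		clock.NewDefaultClock(), expiryDelta, startHeight, startHash,
//		notifier,
//	)
//	err := watcher.Start(func(hash lntypes.Hash, force bool) error {
//		// Cancel the expired invoice in the invoice database here.
//		return nil
//	})
//	...
//	watcher.Stop()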
package invoices
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/queue"
"github.com/lightningnetwork/lnd/record"
)
var (
// ErrInvoiceExpiryTooSoon is returned when an invoice is attempted to
// be accepted or settled with not enough blocks remaining.
ErrInvoiceExpiryTooSoon = errors.New("invoice expiry too soon")
// ErrInvoiceAmountTooLow is returned when an invoice is attempted to
// be accepted or settled with an amount that is too low.
ErrInvoiceAmountTooLow = errors.New(
"paid amount less than invoice amount",
)
// ErrShuttingDown is returned when an operation failed because the
// invoice registry is shutting down.
ErrShuttingDown = errors.New("invoice registry shutting down")
)
const (
// DefaultHtlcHoldDuration defines the default for how long mpp htlcs
// are held while waiting for the other set members to arrive.
DefaultHtlcHoldDuration = 120 * time.Second
)
// RegistryConfig contains the configuration parameters for invoice registry.
type RegistryConfig struct {
// FinalCltvRejectDelta defines the number of blocks before the expiry
// of the htlc where we no longer settle it as an exit hop and instead
// cancel it back. Normally this value should be lower than the cltv
// expiry of any invoice we create and the code effectuating this should
// not be hit.
FinalCltvRejectDelta int32
// HtlcHoldDuration defines for how long mpp htlcs are held while
// waiting for the other set members to arrive.
HtlcHoldDuration time.Duration
// Clock holds the clock implementation that is used to provide
// Now() and TickAfter() and is useful to stub out the clock functions
// during testing.
Clock clock.Clock
// AcceptKeySend indicates whether we want to accept spontaneous key
// send payments.
AcceptKeySend bool
// AcceptAMP indicates whether we want to accept spontaneous AMP
// payments.
AcceptAMP bool
// GcCanceledInvoicesOnStartup if set, we'll attempt to garbage collect
// all canceled invoices upon start.
GcCanceledInvoicesOnStartup bool
// GcCanceledInvoicesOnTheFly if set, we'll garbage collect all newly
// canceled invoices on the fly.
GcCanceledInvoicesOnTheFly bool
// KeysendHoldTime indicates for how long we want to accept and hold
// spontaneous keysend payments.
KeysendHoldTime time.Duration
// HtlcInterceptor is an interface that allows the invoice registry to
// let clients intercept invoices before they are settled.
HtlcInterceptor HtlcInterceptor
}
// htlcReleaseEvent describes an htlc auto-release event. It is used to release
// mpp htlcs for which the complete set didn't arrive in time.
type htlcReleaseEvent struct {
// invoiceRef identifies the invoice this htlc belongs to.
invoiceRef InvoiceRef
// key is the circuit key of the htlc to release.
key CircuitKey
// releaseTime is the time at which to release the htlc.
releaseTime time.Time
}
// Less is used to order PriorityQueueItems by their release time such that
// the item with the oldest release time is at the top of the queue.
//
// NOTE: Part of the queue.PriorityQueueItem interface.
func (r *htlcReleaseEvent) Less(other queue.PriorityQueueItem) bool {
return r.releaseTime.Before(other.(*htlcReleaseEvent).releaseTime)
}
// InvoiceRegistry is a central registry of all the outstanding invoices
// created by the daemon. The registry is a thin wrapper around a map in order
// to ensure that all updates/reads are thread safe.
type InvoiceRegistry struct {
started atomic.Bool
stopped atomic.Bool
sync.RWMutex
nextClientID uint32 // must be used atomically
idb InvoiceDB
// cfg contains the registry's configuration parameters.
cfg *RegistryConfig
// notificationClientMux locks notificationClients and
// singleNotificationClients. Using a separate mutex for these maps is
// necessary to avoid deadlocks in the registry when processing invoice
// events.
notificationClientMux sync.RWMutex
notificationClients map[uint32]*InvoiceSubscription
// TODO(yy): use map[lntypes.Hash]*SingleInvoiceSubscription for better
// performance.
singleNotificationClients map[uint32]*SingleInvoiceSubscription
// invoiceEvents is a single channel over which invoice updates are
// carried.
invoiceEvents chan *invoiceEvent
// hodlSubscriptionsMux locks the hodlSubscriptions and
// hodlReverseSubscriptions. Using a separate mutex for these maps is
// necessary to avoid deadlocks in the registry when processing invoice
// events.
hodlSubscriptionsMux sync.RWMutex
// hodlSubscriptions is a map from a circuit key to a list of
// subscribers. It is used for efficient notification of links.
hodlSubscriptions map[CircuitKey]map[chan<- interface{}]struct{}
// hodlReverseSubscriptions tracks circuit keys subscribed to per
// subscriber. This is used to unsubscribe from all hashes efficiently.
hodlReverseSubscriptions map[chan<- interface{}]map[CircuitKey]struct{}
// htlcAutoReleaseChan contains the new htlcs that need to be
// auto-released.
htlcAutoReleaseChan chan *htlcReleaseEvent
expiryWatcher *InvoiceExpiryWatcher
wg sync.WaitGroup
quit chan struct{}
}
// NewRegistry creates a new invoice registry. The invoice registry
// wraps the persistent on-disk invoice storage with an additional in-memory
// layer. The in-memory layer is in place such that debug invoices can be added
// which are volatile yet available system wide within the daemon.
func NewRegistry(idb InvoiceDB, expiryWatcher *InvoiceExpiryWatcher,
cfg *RegistryConfig) *InvoiceRegistry {
notificationClients := make(map[uint32]*InvoiceSubscription)
singleNotificationClients := make(map[uint32]*SingleInvoiceSubscription)
return &InvoiceRegistry{
idb: idb,
notificationClients: notificationClients,
singleNotificationClients: singleNotificationClients,
invoiceEvents: make(chan *invoiceEvent, 100),
hodlSubscriptions: make(
map[CircuitKey]map[chan<- interface{}]struct{},
),
hodlReverseSubscriptions: make(
map[chan<- interface{}]map[CircuitKey]struct{},
),
cfg: cfg,
htlcAutoReleaseChan: make(chan *htlcReleaseEvent),
expiryWatcher: expiryWatcher,
quit: make(chan struct{}),
}
}
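// Illustrative construction sketch (not part of the original source),
// assuming an InvoiceDB implementation and an expiry watcher were created
// elsewhere:
//
//	registry := NewRegistry(idb, expiryWatcher, &RegistryConfig{
//		Clock:            clock.NewDefaultClock(),
//		HtlcHoldDuration: DefaultHtlcHoldDuration,
//	})
//	if err := registry.Start(); err != nil {
//		// Handle startup failure.
//	}
//	defer func() { _ = registry.Stop() }()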
// scanInvoicesOnStart will scan all invoices on start and add active invoices
// to the invoice expiry watcher while also attempting to delete all canceled
// invoices.
func (i *InvoiceRegistry) scanInvoicesOnStart(ctx context.Context) error {
pendingInvoices, err := i.idb.FetchPendingInvoices(ctx)
if err != nil {
return err
}
var pending []invoiceExpiry
for paymentHash, invoice := range pendingInvoices {
invoice := invoice
expiryRef := makeInvoiceExpiry(paymentHash, &invoice)
if expiryRef != nil {
pending = append(pending, expiryRef)
}
}
log.Debugf("Adding %d pending invoices to the expiry watcher",
len(pending))
i.expiryWatcher.AddInvoices(pending...)
if i.cfg.GcCanceledInvoicesOnStartup {
log.Infof("Deleting canceled invoices")
err = i.idb.DeleteCanceledInvoices(ctx)
if err != nil {
log.Warnf("Deleting canceled invoices failed: %v", err)
return err
}
}
return nil
}
// Start starts the registry and all goroutines it needs to carry out its task.
func (i *InvoiceRegistry) Start() error {
var err error
log.Info("InvoiceRegistry starting...")
if i.started.Swap(true) {
return fmt.Errorf("InvoiceRegistry started more than once")
}
// Start InvoiceExpiryWatcher and prepopulate it with existing
// active invoices.
err = i.expiryWatcher.Start(
func(hash lntypes.Hash, force bool) error {
return i.cancelInvoiceImpl(
context.Background(), hash, force,
)
})
if err != nil {
return err
}
i.wg.Add(1)
go i.invoiceEventLoop()
// Now scan all invoices, adding pending ones to the expiry watcher
// and deleting those that can be removed.
err = i.scanInvoicesOnStart(context.Background())
if err != nil {
_ = i.Stop()
}
log.Debug("InvoiceRegistry started")
return err
}
// Stop signals the registry for a graceful shutdown.
func (i *InvoiceRegistry) Stop() error {
log.Info("InvoiceRegistry shutting down...")
if i.stopped.Swap(true) {
return fmt.Errorf("InvoiceRegistry stopped more than once")
}
log.Info("InvoiceRegistry shutting down...")
defer log.Debug("InvoiceRegistry shutdown complete")
var err error
if i.expiryWatcher == nil {
err = fmt.Errorf("InvoiceRegistry expiryWatcher is not " +
"initialized")
} else {
i.expiryWatcher.Stop()
}
close(i.quit)
i.wg.Wait()
log.Debug("InvoiceRegistry shutdown complete")
return err
}
// invoiceEvent represents a new event that has modified an invoice on disk.
// Only two event types are currently supported: newly created invoices, and
// instances where invoices are settled.
type invoiceEvent struct {
hash lntypes.Hash
invoice *Invoice
setID *[32]byte
}
// tickAt returns a channel that ticks at the specified time. If the time has
// already passed, it will tick immediately.
func (i *InvoiceRegistry) tickAt(t time.Time) <-chan time.Time {
now := i.cfg.Clock.Now()
return i.cfg.Clock.TickAfter(t.Sub(now))
}
// invoiceEventLoop is the dedicated goroutine responsible for accepting
// new notification subscriptions, cancelling old subscriptions, and
// dispatching new invoice events.
func (i *InvoiceRegistry) invoiceEventLoop() {
defer i.wg.Done()
// Set up a heap for htlc auto-releases.
autoReleaseHeap := &queue.PriorityQueue{}
for {
// If there is something to release, set up a release tick
// channel.
var nextReleaseTick <-chan time.Time
if autoReleaseHeap.Len() > 0 {
head := autoReleaseHeap.Top().(*htlcReleaseEvent)
nextReleaseTick = i.tickAt(head.releaseTime)
}
select {
// A sub-system has just modified the invoice state, so we'll
// dispatch notifications to all registered clients.
case event := <-i.invoiceEvents:
// For backwards compatibility, do not notify all
// invoice subscribers of cancel and accept events.
state := event.invoice.State
if state != ContractCanceled &&
state != ContractAccepted {
i.dispatchToClients(event)
}
i.dispatchToSingleClients(event)
// A new htlc came in for auto-release.
case event := <-i.htlcAutoReleaseChan:
log.Debugf("Scheduling auto-release for htlc: "+
"ref=%v, key=%v at %v",
event.invoiceRef, event.key, event.releaseTime)
// We use an independent timer for every htlc rather
// than a set timer that is reset with every htlc coming
// in. Otherwise the sender could keep resetting the
// timer until the broadcast window is entered and our
// channel is force closed.
autoReleaseHeap.Push(event)
// The htlc at the top of the heap needs to be auto-released.
case <-nextReleaseTick:
event := autoReleaseHeap.Pop().(*htlcReleaseEvent)
err := i.cancelSingleHtlc(
event.invoiceRef, event.key, ResultMppTimeout,
)
if err != nil {
log.Errorf("HTLC timer: %v", err)
}
case <-i.quit:
return
}
}
}
// dispatchToSingleClients passes the supplied event to all notification
// clients that subscribed to the single invoice this event applies to.
func (i *InvoiceRegistry) dispatchToSingleClients(event *invoiceEvent) {
// Dispatch to single invoice subscribers.
clients := i.copySingleClients()
for _, client := range clients {
payHash := client.invoiceRef.PayHash()
if payHash == nil || *payHash != event.hash {
continue
}
select {
case <-client.backlogDelivered:
// We won't deliver any events until the backlog has
// gone through first.
case <-i.quit:
return
}
client.notify(event)
}
}
// dispatchToClients passes the supplied event to all notification clients that
// subscribed to all invoices. Add and settle indices are used to make sure
// that clients don't receive duplicate or unwanted events.
func (i *InvoiceRegistry) dispatchToClients(event *invoiceEvent) {
invoice := event.invoice
clients := i.copyClients()
for clientID, client := range clients {
// Before we dispatch this event, we'll check
// to ensure that this client hasn't already
// received this notification in order to
// ensure we don't duplicate any events.
// TODO(joostjager): Refactor switches.
state := event.invoice.State
switch {
// If we've already sent this settle event to
// the client, then we can skip this.
case state == ContractSettled &&
client.settleIndex >= invoice.SettleIndex:
continue
// Similarly, if we've already sent this add to
// the client then we can skip this one, but only if this isn't
// an AMP invoice. AMP invoices always remain in the settle
// state as a base invoice.
case event.setID == nil && state == ContractOpen &&
client.addIndex >= invoice.AddIndex:
continue
// These two states should never happen, but we
// log them just in case so we can detect this
// instance.
case state == ContractOpen &&
client.addIndex+1 != invoice.AddIndex:
log.Warnf("client=%v for invoice "+
"notifications missed an update, "+
"add_index=%v, new add event index=%v",
clientID, client.addIndex,
invoice.AddIndex)
case state == ContractSettled &&
client.settleIndex+1 != invoice.SettleIndex:
log.Warnf("client=%v for invoice "+
"notifications missed an update, "+
"settle_index=%v, new settle event index=%v",
clientID, client.settleIndex,
invoice.SettleIndex)
}
select {
case <-client.backlogDelivered:
// We won't deliver any events until the backlog has
// been processed.
case <-i.quit:
return
}
err := client.notify(&invoiceEvent{
invoice: invoice,
setID: event.setID,
})
if err != nil {
log.Errorf("Failed dispatching to client: %v", err)
return
}
// Each time we send a notification to a client, we'll record
// the latest add/settle index it has. We'll use this to ensure
// we don't send a notification twice, which can happen if a new
// event is added while we're catching up a new client.
invState := event.invoice.State
switch {
case invState == ContractSettled:
client.settleIndex = invoice.SettleIndex
case invState == ContractOpen && event.setID == nil:
client.addIndex = invoice.AddIndex
// If this is an AMP invoice, then we'll need to use the set ID
// to keep track of the settle index of the client. AMP
// invoices never go to the open state, but if a setID is
// passed, then we know it was just settled and will track the
// highest settle index so far.
case invState == ContractOpen && event.setID != nil:
setID := *event.setID
client.settleIndex = invoice.AMPState[setID].SettleIndex
default:
log.Errorf("unexpected invoice state: %v",
event.invoice.State)
}
}
}
// deliverBacklogEvents will attempt to query the invoice database for any
// notifications that the client has missed since it reconnected last.
func (i *InvoiceRegistry) deliverBacklogEvents(ctx context.Context,
client *InvoiceSubscription) error {
log.Debugf("Collecting added invoices since %v for client %v",
client.addIndex, client.id)
addEvents, err := i.idb.InvoicesAddedSince(ctx, client.addIndex)
if err != nil {
return err
}
log.Debugf("Collecting settled invoices since %v for client %v",
client.settleIndex, client.id)
settleEvents, err := i.idb.InvoicesSettledSince(ctx, client.settleIndex)
if err != nil {
return err
}
log.Debugf("Delivering %d added invoices and %d settled invoices "+
"for client %v", len(addEvents), len(settleEvents), client.id)
// If we have any to deliver, then we'll append them to the end of the
// notification queue in order to catch up the client before delivering
// any new notifications.
for _, addEvent := range addEvents {
// We re-bind the loop variable to ensure we don't hold onto
// the loop reference, causing us to point to the same item.
addEvent := addEvent
select {
case client.ntfnQueue.ChanIn() <- &invoiceEvent{
invoice: &addEvent,
}:
case <-i.quit:
return ErrShuttingDown
}
}
for _, settleEvent := range settleEvents {
// We re-bind the loop variable to ensure we don't hold onto
// the loop reference, causing us to point to the same item.
settleEvent := settleEvent
select {
case client.ntfnQueue.ChanIn() <- &invoiceEvent{
invoice: &settleEvent,
}:
case <-i.quit:
return ErrShuttingDown
}
}
return nil
}
// deliverSingleBacklogEvents will attempt to query the invoice database to
// retrieve the current invoice state and deliver this to the subscriber. Single
// invoice subscribers will always receive the current state right after
// subscribing, unless the invoice does not yet exist, in which case nothing
// is sent yet.
func (i *InvoiceRegistry) deliverSingleBacklogEvents(ctx context.Context,
client *SingleInvoiceSubscription) error {
invoice, err := i.idb.LookupInvoice(ctx, client.invoiceRef)
// It is possible that the invoice does not exist yet, but the client is
// already watching it in anticipation.
isNotFound := errors.Is(err, ErrInvoiceNotFound)
isNotCreated := errors.Is(err, ErrNoInvoicesCreated)
if isNotFound || isNotCreated {
return nil
}
if err != nil {
return err
}
payHash := client.invoiceRef.PayHash()
if payHash == nil {
return nil
}
err = client.notify(&invoiceEvent{
hash: *payHash,
invoice: &invoice,
})
if err != nil {
return err
}
log.Debugf("Client(id=%v) delivered single backlog event: payHash=%v",
client.id, payHash)
return nil
}
// AddInvoice adds a regular invoice for the specified amount, identified by
// the passed preimage. Additionally, any memo or receipt data provided will
// also be stored on-disk. Once this invoice is added, subsystems within the
// daemon that add/forward HTLCs are able to obtain the proper preimage
// required for redemption in the case that we're the final destination. We
// also return the
// addIndex of the newly created invoice which monotonically increases for each
// new invoice added. A side effect of this function is that it also sets
// AddIndex on the invoice argument.
func (i *InvoiceRegistry) AddInvoice(ctx context.Context, invoice *Invoice,
paymentHash lntypes.Hash) (uint64, error) {
i.Lock()
ref := InvoiceRefByHash(paymentHash)
log.Debugf("Invoice%v: added with terms %v", ref, invoice.Terms)
addIndex, err := i.idb.AddInvoice(ctx, invoice, paymentHash)
if err != nil {
i.Unlock()
return 0, err
}
// Now that we've added the invoice, we'll dispatch a message to
// notify the clients of this new invoice.
i.notifyClients(paymentHash, invoice, nil)
i.Unlock()
// InvoiceExpiryWatcher.AddInvoices must not be called while the
// InvoiceRegistry lock is held, to avoid deadlock when a new invoice is
// added while another one is being canceled.
invoiceExpiryRef := makeInvoiceExpiry(paymentHash, invoice)
if invoiceExpiryRef != nil {
i.expiryWatcher.AddInvoices(invoiceExpiryRef)
}
return addIndex, nil
}
// LookupInvoice looks up an invoice by its payment hash (R-Hash), if found
// then we're able to pull the funds pending within an HTLC.
//
// TODO(roasbeef): ignore if settled?
func (i *InvoiceRegistry) LookupInvoice(ctx context.Context,
rHash lntypes.Hash) (Invoice, error) {
// We'll check the database to see if there's an existing matching
// invoice.
ref := InvoiceRefByHash(rHash)
return i.idb.LookupInvoice(ctx, ref)
}
// LookupInvoiceByRef looks up an invoice by the given reference, if found
// then we're able to pull the funds pending within an HTLC.
func (i *InvoiceRegistry) LookupInvoiceByRef(ctx context.Context,
ref InvoiceRef) (Invoice, error) {
return i.idb.LookupInvoice(ctx, ref)
}
// startHtlcTimer starts a new timer via the invoice registry main loop that
// cancels a single htlc on an invoice when the htlc hold duration has passed.
func (i *InvoiceRegistry) startHtlcTimer(invoiceRef InvoiceRef,
key CircuitKey, acceptTime time.Time) error {
releaseTime := acceptTime.Add(i.cfg.HtlcHoldDuration)
event := &htlcReleaseEvent{
invoiceRef: invoiceRef,
key: key,
releaseTime: releaseTime,
}
select {
case i.htlcAutoReleaseChan <- event:
return nil
case <-i.quit:
return ErrShuttingDown
}
}
// cancelSingleHtlc cancels a single accepted htlc on an invoice. It takes
// a resolution result which will be used to notify subscribed links and
// resolvers of the details of the htlc cancellation.
func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef InvoiceRef,
key CircuitKey, result FailResolutionResult) error {
updateInvoice := func(invoice *Invoice, setID *SetID) (
*InvoiceUpdateDesc, error) {
// Only allow individual htlc cancellation on open invoices.
if invoice.State != ContractOpen {
log.Debugf("CancelSingleHtlc: cannot cancel htlc %v "+
"on invoice %v, invoice is no longer open", key,
invoiceRef)
return nil, nil
}
// Also for AMP invoices we fetch the relevant HTLCs, so
// the HTLC should be found, otherwise we return an error.
htlc, ok := invoice.Htlcs[key]
if !ok {
return nil, fmt.Errorf("htlc %v not found on "+
"invoice %v", key, invoiceRef)
}
htlcState := htlc.State
// Cancellation is only possible if the htlc wasn't already
// resolved.
if htlcState != HtlcStateAccepted {
log.Debugf("CancelSingleHtlc: htlc %v on invoice %v "+
"is already resolved", key, invoiceRef)
return nil, nil
}
log.Debugf("CancelSingleHtlc: cancelling htlc %v on invoice %v",
key, invoiceRef)
// Return an update descriptor that cancels htlc and keeps
// invoice open.
canceledHtlcs := map[CircuitKey]struct{}{
key: {},
}
return &InvoiceUpdateDesc{
UpdateType: CancelHTLCsUpdate,
CancelHtlcs: canceledHtlcs,
SetID: setID,
}, nil
}
// Try to mark the specified htlc as canceled in the invoice database.
// Intercept the update descriptor to set the local updated variable. If
// no invoice update is performed, we can return early.
// setID is only set for AMP HTLCs, so it can be nil and it is expected
// to be nil for non-AMP HTLCs.
setID := (*SetID)(invoiceRef.SetID())
var updated bool
invoice, err := i.idb.UpdateInvoice(
context.Background(), invoiceRef, setID,
func(invoice *Invoice) (
*InvoiceUpdateDesc, error) {
updateDesc, err := updateInvoice(invoice, setID)
if err != nil {
return nil, err
}
updated = updateDesc != nil
return updateDesc, err
},
)
if err != nil {
return err
}
if !updated {
return nil
}
// The invoice has been updated. Notify subscribers of the htlc
// resolution.
htlc, ok := invoice.Htlcs[key]
if !ok {
return fmt.Errorf("htlc %v not found", key)
}
if htlc.State == HtlcStateCanceled {
resolution := NewFailResolution(
key, int32(htlc.AcceptHeight), result,
)
log.Debugf("Signaling htlc(%v) cancellation of invoice(%v) "+
"with resolution(%v) to the link subsystem", key,
invoiceRef, result)
i.notifyHodlSubscribers(resolution)
}
return nil
}
// processKeySend just-in-time inserts an invoice if this htlc is a keysend
// htlc.
func (i *InvoiceRegistry) processKeySend(ctx invoiceUpdateCtx) error {
// Retrieve keysend record if present.
preimageSlice, ok := ctx.customRecords[record.KeySendType]
if !ok {
return nil
}
// Cancel the htlc if the preimage is invalid.
preimage, err := lntypes.MakePreimage(preimageSlice)
if err != nil {
return err
}
if preimage.Hash() != ctx.hash {
return fmt.Errorf("invalid keysend preimage %v for hash %v",
preimage, ctx.hash)
}
// Only allow keysend for non-mpp payments.
if ctx.mpp != nil {
return errors.New("no mpp keysend supported")
}
// Create an invoice for the htlc amount.
amt := ctx.amtPaid
// Set the tlv required feature vector on the invoice. Otherwise we
// wouldn't be able to pay it with keysend.
rawFeatures := lnwire.NewRawFeatureVector(
lnwire.TLVOnionPayloadRequired,
)
features := lnwire.NewFeatureVector(rawFeatures, lnwire.Features)
// Use the minimum block delta that we require for settling htlcs.
finalCltvDelta := i.cfg.FinalCltvRejectDelta
// Pre-check expiry here to prevent inserting an invoice that will not
// be settled.
if ctx.expiry < uint32(ctx.currentHeight+finalCltvDelta) {
return errors.New("final expiry too soon")
}
// The invoice database indexes all invoices by payment address, however
// legacy keysend payments do not have one. In order to avoid a new
// on-disk payment type with respect to indexing, we'll continue to
// insert a blank payment address, which is special-cased in the
// insertion logic to not be indexed. In the future, once AMP is merged,
// this should be replaced by generating a random payment address on
// behalf of the sender.
payAddr := BlankPayAddr
// Create placeholder invoice.
invoice := &Invoice{
CreationDate: i.cfg.Clock.Now(),
Terms: ContractTerm{
FinalCltvDelta: finalCltvDelta,
Value: amt,
PaymentPreimage: &preimage,
PaymentAddr: payAddr,
Features: features,
},
}
if i.cfg.KeysendHoldTime != 0 {
invoice.HodlInvoice = true
invoice.Terms.Expiry = i.cfg.KeysendHoldTime
}
// Insert invoice into database. Ignore duplicates, because this
// may be a replay.
_, err = i.AddInvoice(context.Background(), invoice, ctx.hash)
if err != nil && !errors.Is(err, ErrDuplicateInvoice) {
return err
}
return nil
}
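// An illustrative sketch (not part of the original source) of the keysend
// contract enforced above: the sender ships the preimage in a custom TLV
// record keyed by record.KeySendType, and the htlc's payment hash must equal
// the hash of that preimage. The preimage value here is hypothetical.
//
//	preimage := lntypes.Preimage{1}
//	customRecords := record.CustomSet{
//		record.KeySendType: preimage[:],
//	}
//	// A matching htlc must pay to preimage.Hash(), otherwise
//	// processKeySend rejects it.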
// processAMP just-in-time inserts an invoice if this htlc is an AMP
// htlc.
func (i *InvoiceRegistry) processAMP(ctx invoiceUpdateCtx) error {
// AMP payments MUST also include an MPP record.
if ctx.mpp == nil {
return errors.New("no MPP record for AMP")
}
// Create an invoice for the total amount expected, provided in the MPP
// record.
amt := ctx.mpp.TotalMsat()
// Set the TLV required and MPP optional features on the invoice. We'll
// also make the AMP features required so that it can't be paid by
// legacy or MPP htlcs.
rawFeatures := lnwire.NewRawFeatureVector(
lnwire.TLVOnionPayloadRequired,
lnwire.PaymentAddrOptional,
lnwire.AMPRequired,
)
features := lnwire.NewFeatureVector(rawFeatures, lnwire.Features)
// Use the minimum block delta that we require for settling htlcs.
finalCltvDelta := i.cfg.FinalCltvRejectDelta
// Pre-check expiry here to prevent inserting an invoice that will not
// be settled.
if ctx.expiry < uint32(ctx.currentHeight+finalCltvDelta) {
return errors.New("final expiry too soon")
}
// We'll use the sender-generated payment address provided in the HTLC
// to create our AMP invoice.
payAddr := ctx.mpp.PaymentAddr()
// Create placeholder invoice.
invoice := &Invoice{
CreationDate: i.cfg.Clock.Now(),
Terms: ContractTerm{
FinalCltvDelta: finalCltvDelta,
Value: amt,
PaymentPreimage: nil,
PaymentAddr: payAddr,
Features: features,
},
}
// Insert invoice into database. Ignore duplicate payment hashes and
// payment addrs, as this may be a replay or a different HTLC for the
// AMP invoice.
_, err := i.AddInvoice(context.Background(), invoice, ctx.hash)
isDuplicatedInvoice := errors.Is(err, ErrDuplicateInvoice)
isDuplicatedPayAddr := errors.Is(err, ErrDuplicatePayAddr)
switch {
case isDuplicatedInvoice || isDuplicatedPayAddr:
return nil
default:
return err
}
}
// NotifyExitHopHtlc attempts to mark an invoice as settled. The return value
// describes how the htlc should be resolved.
//
// When the preimage of the invoice is not yet known (hodl invoice), this
// function moves the invoice to the accepted state. When SettleHodlInvoice is
// called later, a resolution message will be sent back to the caller via the
// provided hodlChan. Invoice registry sends on this channel what action needs
// to be taken on the htlc (settle or cancel). The caller needs to ensure that
// the channel is either buffered or received on from another goroutine to
// prevent deadlock.
//
// In the case that the htlc is part of a larger set of htlcs that pay to the
// same invoice (multi-path payment), the htlc is held until the set is
// complete. If the set doesn't fully arrive in time, a timer will cancel the
// held htlc.
func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
amtPaid lnwire.MilliSatoshi, expiry uint32, currentHeight int32,
circuitKey CircuitKey, hodlChan chan<- interface{},
wireCustomRecords lnwire.CustomRecords,
payload Payload) (HtlcResolution, error) {
// Create the update context containing the relevant details of the
// incoming htlc.
ctx := invoiceUpdateCtx{
hash: rHash,
circuitKey: circuitKey,
amtPaid: amtPaid,
expiry: expiry,
currentHeight: currentHeight,
finalCltvRejectDelta: i.cfg.FinalCltvRejectDelta,
wireCustomRecords: wireCustomRecords,
customRecords: payload.CustomRecords(),
mpp: payload.MultiPath(),
amp: payload.AMPRecord(),
metadata: payload.Metadata(),
pathID: payload.PathID(),
totalAmtMsat: payload.TotalAmtMsat(),
}
switch {
// If we are accepting spontaneous AMP payments and this payload
// contains an AMP record, create an AMP invoice that will be settled
// below.
case i.cfg.AcceptAMP && ctx.amp != nil:
err := i.processAMP(ctx)
if err != nil {
ctx.log(fmt.Sprintf("amp error: %v", err))
return NewFailResolution(
circuitKey, currentHeight, ResultAmpError,
), nil
}
// If we are accepting spontaneous keysend payments, create a regular
// invoice that will be settled below. We also enforce that this is only
// done when no AMP payload is present since it will only be settle-able
// by regular HTLCs.
case i.cfg.AcceptKeySend && ctx.amp == nil:
err := i.processKeySend(ctx)
if err != nil {
ctx.log(fmt.Sprintf("keysend error: %v", err))
return NewFailResolution(
circuitKey, currentHeight, ResultKeySendError,
), nil
}
}
// Execute locked notify exit hop logic.
i.Lock()
resolution, invoiceToExpire, err := i.notifyExitHopHtlcLocked(
&ctx, hodlChan,
)
i.Unlock()
if err != nil {
return nil, err
}
if invoiceToExpire != nil {
i.expiryWatcher.AddInvoices(invoiceToExpire)
}
switch r := resolution.(type) {
// The htlc is held. Start a timer outside the lock if the htlc should
// be auto-released, because otherwise a deadlock may happen with the
// main event loop.
case *htlcAcceptResolution:
if r.autoRelease {
var invRef InvoiceRef
if ctx.amp != nil {
invRef = InvoiceRefBySetID(*ctx.setID())
} else {
invRef = ctx.invoiceRef()
}
err := i.startHtlcTimer(
invRef, circuitKey, r.acceptTime,
)
if err != nil {
return nil, err
}
}
// We return a nil resolution because htlc acceptances are
// represented as nil resolutions externally.
// TODO(carla) update calling code to handle accept resolutions.
return nil, nil
// A direct resolution was received for this htlc.
case HtlcResolution:
return r, nil
// Fail if an unknown resolution type was received.
default:
return nil, errors.New("invalid resolution type")
}
}
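// An illustrative caller sketch (assumes a registry i and htlc details
// rHash, amt, expiry, height, circuitKey, records and payload in scope). As
// required by the contract above, hodlChan is buffered so the registry can
// never block when delivering a hodl resolution.
//
//	hodlChan := make(chan interface{}, 1)
//	resolution, err := i.NotifyExitHopHtlc(
//		rHash, amt, expiry, height, circuitKey, hodlChan,
//		records, payload,
//	)
//	if err != nil {
//		// Handle the error.
//	}
//	if resolution == nil {
//		// The htlc was accepted; wait on hodlChan for the eventual
//		// settle or cancel resolution.
//	}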
// notifyExitHopHtlcLocked is the internal implementation of NotifyExitHopHtlc
// that should be executed inside the registry lock. The returned invoiceExpiry
// (if not nil) needs to be added to the expiry watcher outside of the lock.
func (i *InvoiceRegistry) notifyExitHopHtlcLocked(
ctx *invoiceUpdateCtx, hodlChan chan<- interface{}) (
HtlcResolution, invoiceExpiry, error) {
invoiceRef := ctx.invoiceRef()
// This setID is only set for AMP HTLCs; it is expected to be nil for
// non-AMP HTLCs.
setID := (*SetID)(ctx.setID())
// We need to look up the current state of the invoice in order to send
// the previously accepted/settled HTLCs to the interceptor.
existingInvoice, err := i.idb.LookupInvoice(
context.Background(), invoiceRef,
)
switch {
case errors.Is(err, ErrInvoiceNotFound) ||
errors.Is(err, ErrNoInvoicesCreated):
// If the invoice was not found, return a failure resolution
// with an invoice not found result.
return NewFailResolution(
ctx.circuitKey, ctx.currentHeight,
ResultInvoiceNotFound,
), nil, nil
case err != nil:
ctx.log(err.Error())
return nil, nil, err
}
var cancelSet bool
// Provide the invoice to the settlement interceptor to allow
// the interceptor's client an opportunity to manipulate the
// settlement process.
err = i.cfg.HtlcInterceptor.Intercept(HtlcModifyRequest{
WireCustomRecords: ctx.wireCustomRecords,
ExitHtlcCircuitKey: ctx.circuitKey,
ExitHtlcAmt: ctx.amtPaid,
ExitHtlcExpiry: ctx.expiry,
CurrentHeight: uint32(ctx.currentHeight),
Invoice: existingInvoice,
}, func(resp HtlcModifyResponse) {
log.Debugf("Received invoice HTLC interceptor response: %v",
resp)
if resp.AmountPaid != 0 {
ctx.amtPaid = resp.AmountPaid
}
cancelSet = resp.CancelSet
})
if err != nil {
err := fmt.Errorf("error during invoice HTLC interception: %w",
err)
ctx.log(err.Error())
return nil, nil, err
}
// We'll attempt to settle an invoice matching this rHash on disk (if
// one exists). The callback will update the invoice state and/or htlcs.
var (
resolution HtlcResolution
updateSubscribers bool
)
callback := func(inv *Invoice) (*InvoiceUpdateDesc, error) {
// First check if this is a replayed htlc and resolve it
// according to its current state. We cannot decide differently
// for an HTLC that has already been processed before.
isReplayed, res, err := resolveReplayedHtlc(ctx, inv)
if err != nil {
return nil, err
}
if isReplayed {
resolution = res
return nil, nil
}
// In case the HTLC interceptor cancels the HTLC set, we do NOT
// cancel the invoice; however, we do cancel the complete HTLC set.
if cancelSet {
// If the invoice is not open, something is wrong, we
// fail just the HTLC with the specific error.
if inv.State != ContractOpen {
log.Errorf("Invoice state (%v) is not OPEN, "+
"cancelling HTLC set not allowed by "+
"external source", inv.State)
resolution = NewFailResolution(
ctx.circuitKey, ctx.currentHeight,
ResultInvoiceNotOpen,
)
return nil, nil
}
// The `ExternalValidationFailed` error information
// will be packed into the `FailIncorrectDetails`
// message when sending it to the peer. Error codes
// are defined by the BOLT 04 specification, but the
// error text can be arbitrary, therefore we return a
// custom error message.
resolution = NewFailResolution(
ctx.circuitKey, ctx.currentHeight,
ExternalValidationFailed,
)
// We cancel all HTLCs which are in the accepted state.
//
// NOTE: The current HTLC is not included because it
// was never accepted in the first place.
htlcs := inv.HTLCSet(ctx.setID(), HtlcStateAccepted)
htlcKeys := fn.KeySet[CircuitKey](htlcs)
// The external source did cancel the htlc set, so we
// cancel all HTLCs in the set. However, we keep the
// invoice in the open state.
//
// NOTE: The invoice event loop will still call the
// `cancelSingleHTLC` method for MPP payments, however
// because the HTLCs are already canceled back it will
// be a NOOP.
update := &InvoiceUpdateDesc{
UpdateType: CancelHTLCsUpdate,
CancelHtlcs: htlcKeys,
SetID: setID,
}
return update, nil
}
updateDesc, res, err := updateInvoice(ctx, inv)
if err != nil {
return nil, err
}
// Set resolution in outer scope only after successful update.
resolution = res
// Only send an update if the invoice state was changed.
updateSubscribers = updateDesc != nil &&
updateDesc.State != nil
return updateDesc, nil
}
invoice, err := i.idb.UpdateInvoice(
context.Background(), invoiceRef, setID, callback,
)
var duplicateSetIDErr ErrDuplicateSetID
if errors.As(err, &duplicateSetIDErr) {
return NewFailResolution(
ctx.circuitKey, ctx.currentHeight,
ResultInvoiceNotFound,
), nil, nil
}
switch {
case errors.Is(err, ErrInvoiceNotFound):
// If the invoice was not found, return a failure resolution
// with an invoice not found result.
return NewFailResolution(
ctx.circuitKey, ctx.currentHeight,
ResultInvoiceNotFound,
), nil, nil
case errors.Is(err, ErrInvRefEquivocation):
return NewFailResolution(
ctx.circuitKey, ctx.currentHeight,
ResultInvoiceNotFound,
), nil, nil
case err == nil:
default:
ctx.log(err.Error())
return nil, nil, err
}
var invoiceToExpire invoiceExpiry
log.Tracef("Settlement resolution: %T %v", resolution, resolution)
switch res := resolution.(type) {
case *HtlcFailResolution:
// Inspect latest htlc state on the invoice. If it is found,
// we will update the accept height as it was recorded in the
// invoice database (which occurs in the case where the htlc
// reached the database in a previous call). If the htlc was
// not found on the invoice, it was immediately failed so we
// send the failure resolution as is, which has the current
// height set as the accept height.
invoiceHtlc, ok := invoice.Htlcs[ctx.circuitKey]
if ok {
res.AcceptHeight = int32(invoiceHtlc.AcceptHeight)
}
ctx.log(fmt.Sprintf("failure resolution result "+
"outcome: %v, at accept height: %v",
res.Outcome, res.AcceptHeight))
// Some failures apply to the entire HTLC set. Break here if
// this isn't one of them.
if !res.Outcome.IsSetFailure() {
break
}
// Also fail back any HTLCs in the HTLC set that are already in
// the canceled state, using the same failure result.
setID := ctx.setID()
canceledHtlcSet := invoice.HTLCSet(setID, HtlcStateCanceled)
for key, htlc := range canceledHtlcSet {
htlcFailResolution := NewFailResolution(
key, int32(htlc.AcceptHeight), res.Outcome,
)
i.notifyHodlSubscribers(htlcFailResolution)
}
// If the htlc was settled, we will settle any previously accepted
// htlcs and notify our peer to settle them.
case *HtlcSettleResolution:
ctx.log(fmt.Sprintf("settle resolution result "+
"outcome: %v, at accept height: %v",
res.Outcome, res.AcceptHeight))
// Also settle any previously accepted htlcs. If an htlc is
// marked as settled, we should follow up now and settle the
// htlc with our peer.
setID := ctx.setID()
settledHtlcSet := invoice.HTLCSet(setID, HtlcStateSettled)
for key, htlc := range settledHtlcSet {
preimage := res.Preimage
if htlc.AMP != nil && htlc.AMP.Preimage != nil {
preimage = *htlc.AMP.Preimage
}
// Notify subscribers that the htlcs should be settled
// with our peer. Note that the outcome of the
// resolution is set based on the outcome of the single
// htlc that we just settled, so may not be accurate
// for all htlcs.
htlcSettleResolution := NewSettleResolution(
preimage, key,
int32(htlc.AcceptHeight), res.Outcome,
)
// Notify subscribers that the htlc should be settled
// with our peer.
i.notifyHodlSubscribers(htlcSettleResolution)
}
// If concurrent payments were attempted to this invoice before
// the current one was ultimately settled, cancel back any of
// the HTLCs immediately. As a result of the settle, the HTLCs
// in other HTLC sets are automatically converted to a canceled
// state when updating the invoice.
//
// TODO(roasbeef): can remove now??
canceledHtlcSet := invoice.HTLCSetCompliment(
setID, HtlcStateCanceled,
)
for key, htlc := range canceledHtlcSet {
htlcFailResolution := NewFailResolution(
key, int32(htlc.AcceptHeight),
ResultInvoiceAlreadySettled,
)
i.notifyHodlSubscribers(htlcFailResolution)
}
// If we accepted the htlc, subscribe to the hodl invoice and return
// an accept resolution with the htlc's accept time on it.
case *htlcAcceptResolution:
invoiceHtlc, ok := invoice.Htlcs[ctx.circuitKey]
if !ok {
return nil, nil, fmt.Errorf("accepted htlc: %v not"+
" present on invoice: %x", ctx.circuitKey,
ctx.hash[:])
}
// Determine accepted height of this htlc. If the htlc reached
// the invoice database (possibly in a previous call to the
// invoice registry), we'll take the original accepted height
// as it was recorded in the database.
acceptHeight := int32(invoiceHtlc.AcceptHeight)
ctx.log(fmt.Sprintf("accept resolution result "+
"outcome: %v, at accept height: %v",
res.outcome, acceptHeight))
// Auto-release the htlc if the invoice is still open. Only for
// mpp payments can there be htlcs in the Accepted state while
// the invoice is still Open.
if invoice.State == ContractOpen {
res.acceptTime = invoiceHtlc.AcceptTime
res.autoRelease = true
}
// If we have fully accepted the set of htlcs for this invoice,
// we can now add it to our invoice expiry watcher. We do not
// add invoices before they are fully accepted, because it is
// possible that we MppTimeout the htlcs, and then our relevant
// expiry height could change.
if res.outcome == resultAccepted {
invoiceToExpire = makeInvoiceExpiry(ctx.hash, invoice)
}
// Subscribe to the resolution if the caller specified a
// notification channel.
if hodlChan != nil {
i.hodlSubscribe(hodlChan, ctx.circuitKey)
}
default:
panic("unknown action")
}
// Now that the links have been notified of any state changes to their
// HTLCs, we'll go ahead and notify any clients waiting on the invoice
// state changes.
if updateSubscribers {
// We'll add a setID onto the notification, but only if this is
// an AMP invoice being settled.
var setID *[32]byte
if _, ok := resolution.(*HtlcSettleResolution); ok {
setID = ctx.setID()
}
i.notifyClients(ctx.hash, invoice, setID)
}
return resolution, invoiceToExpire, nil
}
// SettleHodlInvoice sets the preimage of a hodl invoice.
func (i *InvoiceRegistry) SettleHodlInvoice(ctx context.Context,
preimage lntypes.Preimage) error {
i.Lock()
defer i.Unlock()
updateInvoice := func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
switch invoice.State {
case ContractOpen:
return nil, ErrInvoiceStillOpen
case ContractCanceled:
return nil, ErrInvoiceAlreadyCanceled
case ContractSettled:
return nil, ErrInvoiceAlreadySettled
}
return &InvoiceUpdateDesc{
UpdateType: SettleHodlInvoiceUpdate,
State: &InvoiceStateUpdateDesc{
NewState: ContractSettled,
Preimage: &preimage,
},
}, nil
}
hash := preimage.Hash()
invoiceRef := InvoiceRefByHash(hash)
// AMP hold invoices are not supported so we set the setID to nil.
// For non-AMP invoices this parameter is ignored during the fetching
// of the database state.
setID := (*SetID)(nil)
invoice, err := i.idb.UpdateInvoice(
ctx, invoiceRef, setID, updateInvoice,
)
if err != nil {
log.Errorf("SettleHodlInvoice with preimage %v: %v",
preimage, err)
return err
}
log.Debugf("Invoice%v: settled with preimage %v", invoiceRef,
invoice.Terms.PaymentPreimage)
// In the callback, we marked the invoice as settled. UpdateInvoice will
// have seen this and should have moved all htlcs that were accepted to
// the settled state. In the loop below, we go through all of these and
// notify links and resolvers that are waiting for resolution. Any htlcs
// that were already settled before, will be notified again. This isn't
// necessary but doesn't hurt either.
for key, htlc := range invoice.Htlcs {
if htlc.State != HtlcStateSettled {
continue
}
resolution := NewSettleResolution(
preimage, key, int32(htlc.AcceptHeight), ResultSettled,
)
i.notifyHodlSubscribers(resolution)
}
i.notifyClients(hash, invoice, nil)
return nil
}
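// An illustrative sketch (hypothetical preimage value): settling a hodl
// invoice once the preimage becomes known. The hash of the preimage
// identifies the invoice, so no separate invoice reference is needed.
//
//	err := i.SettleHodlInvoice(ctx, preimage)
//	if errors.Is(err, ErrInvoiceStillOpen) {
//		// No complete htlc set has been accepted yet.
//	}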
// CancelInvoice attempts to cancel the invoice corresponding to the passed
// payment hash.
func (i *InvoiceRegistry) CancelInvoice(ctx context.Context,
payHash lntypes.Hash) error {
return i.cancelInvoiceImpl(ctx, payHash, true)
}
// shouldCancel examines the state of an invoice and whether we want to
// cancel already accepted invoices, taking our force cancel boolean into
// account. This is pulled out into its own function so that tests that mock
// cancelInvoiceImpl can reuse this logic.
func shouldCancel(state ContractState, cancelAccepted bool) bool {
if state != ContractAccepted {
return true
}
// If the invoice is accepted, we should only cancel if we want to
// force cancellation of accepted invoices.
return cancelAccepted
}
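// A worked decision table for shouldCancel (illustrative, derived from the
// logic above): only invoices in the accepted state are spared when
// cancelAccepted is false.
//
//	shouldCancel(ContractOpen, false)     // true
//	shouldCancel(ContractAccepted, false) // false: accepted, not forced
//	shouldCancel(ContractAccepted, true)  // true: forced cancellation
//	shouldCancel(ContractSettled, false)  // true: channeldb rejects later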
// cancelInvoiceImpl attempts to cancel the invoice corresponding to the passed
// payment hash. Accepted invoices will only be canceled if explicitly
// requested to do so. It notifies subscribing links and resolvers that
// the associated htlcs were canceled if they change state.
func (i *InvoiceRegistry) cancelInvoiceImpl(ctx context.Context,
payHash lntypes.Hash, cancelAccepted bool) error {
i.Lock()
defer i.Unlock()
ref := InvoiceRefByHash(payHash)
log.Debugf("Invoice%v: canceling invoice", ref)
updateInvoice := func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
if !shouldCancel(invoice.State, cancelAccepted) {
return nil, nil
}
// Move invoice to the canceled state. Rely on validation in
// channeldb to return an error if the invoice is already
// settled or canceled.
return &InvoiceUpdateDesc{
UpdateType: CancelInvoiceUpdate,
State: &InvoiceStateUpdateDesc{
NewState: ContractCanceled,
},
}, nil
}
// If it's an AMP invoice we need to fetch all AMP HTLCs here so that
// we can cancel all of the HTLCs which are in the accepted state across
// different setIDs.
setID := (*SetID)(nil)
invoiceRef := InvoiceRefByHash(payHash)
invoice, err := i.idb.UpdateInvoice(
ctx, invoiceRef, setID, updateInvoice,
)
// Implement idempotency by returning success if the invoice was already
// canceled.
if errors.Is(err, ErrInvoiceAlreadyCanceled) {
log.Debugf("Invoice%v: already canceled", ref)
return nil
}
if err != nil {
return err
}
// Return without cancellation if the invoice state is ContractAccepted.
if invoice.State == ContractAccepted {
log.Debugf("Invoice%v: remains accepted as cancel wasn't"+
"explicitly requested.", ref)
return nil
}
log.Debugf("Invoice%v: canceled", ref)
// In the callback, some htlcs may have been moved to the canceled
// state. We now go through all of these and notify links and resolvers
// that are waiting for resolution. Any htlcs that were already canceled
// before, will be notified again. This isn't necessary but doesn't hurt
// either.
// For AMP invoices we fetched all AMP HTLCs for all sub AMP invoices
// here so we can clean up all of them.
for key, htlc := range invoice.Htlcs {
if htlc.State != HtlcStateCanceled {
continue
}
i.notifyHodlSubscribers(
NewFailResolution(
key, int32(htlc.AcceptHeight), ResultCanceled,
),
)
}
i.notifyClients(payHash, invoice, nil)
// Attempt to also delete the invoice if requested through the registry
// config.
if i.cfg.GcCanceledInvoicesOnTheFly {
// Assemble the delete reference and attempt to delete the
// invoice from the DB.
deleteRef := InvoiceDeleteRef{
PayHash: payHash,
AddIndex: invoice.AddIndex,
SettleIndex: invoice.SettleIndex,
}
if invoice.Terms.PaymentAddr != BlankPayAddr {
deleteRef.PayAddr = &invoice.Terms.PaymentAddr
}
err = i.idb.DeleteInvoice(ctx, []InvoiceDeleteRef{deleteRef})
// If by any chance deletion failed, then log it instead of
// returning the error, as the invoice itself has already been
// canceled.
if err != nil {
log.Warnf("Invoice %v could not be deleted: %v", ref,
err)
}
}
return nil
}
// notifyClients notifies all currently registered invoice notification clients
// of a newly added/settled invoice.
func (i *InvoiceRegistry) notifyClients(hash lntypes.Hash,
invoice *Invoice, setID *[32]byte) {
event := &invoiceEvent{
invoice: invoice,
hash: hash,
setID: setID,
}
select {
case i.invoiceEvents <- event:
case <-i.quit:
}
}
// invoiceSubscriptionKit defines the fields that are common to both full
// invoice subscribers and single invoice subscribers.
type invoiceSubscriptionKit struct {
id uint32 // nolint:structcheck
// quit is a chan mounted to the InvoiceRegistry that signals a shutdown.
quit chan struct{}
ntfnQueue *queue.ConcurrentQueue
canceled uint32 // To be used atomically.
cancelChan chan struct{}
// backlogDelivered is closed when the backlog events have been
// delivered.
backlogDelivered chan struct{}
}
// InvoiceSubscription represents an intent to receive updates for newly added
// or settled invoices. For each newly added invoice, a copy of the invoice
// will be sent over the NewInvoices channel. Similarly, for each newly settled
// invoice, a copy of the invoice will be sent over the SettledInvoices
// channel.
type InvoiceSubscription struct {
invoiceSubscriptionKit
// NewInvoices is a channel that we'll use to send all newly created
// invoices with an invoice index greater than the specified
// StartingInvoiceIndex field.
NewInvoices chan *Invoice
// SettledInvoices is a channel that we'll use to send all settled
// invoices with an invoice index greater than the specified
// StartingInvoiceIndex field.
SettledInvoices chan *Invoice
// addIndex is the highest add index the caller knows of. We'll use
// this information to send out an event backlog to the notifications
// subscriber. Any new add events with an index greater than this will
// be dispatched before any new notifications are sent out.
addIndex uint64
// settleIndex is the highest settle index the caller knows of. We'll
// use this information to send out an event backlog to the
// notifications subscriber. Any new settle events with an index
// greater than this will be dispatched before any new notifications
// are sent out.
settleIndex uint64
}
// SingleInvoiceSubscription represents an intent to receive updates for a
// specific invoice.
type SingleInvoiceSubscription struct {
invoiceSubscriptionKit
invoiceRef InvoiceRef
// Updates is a channel that we'll use to send all invoice events for
// the invoice that is subscribed to.
Updates chan *Invoice
}
// PayHash returns the optional payment hash of the target invoice.
//
// TODO(positiveblue): This method is only supposed to be used in tests. It will
// be deleted as soon as invoiceregistry_test is in the same module.
func (s *SingleInvoiceSubscription) PayHash() *lntypes.Hash {
return s.invoiceRef.PayHash()
}
// Cancel unregisters the InvoiceSubscription, freeing any previously allocated
// resources.
func (i *invoiceSubscriptionKit) Cancel() {
if !atomic.CompareAndSwapUint32(&i.canceled, 0, 1) {
return
}
i.ntfnQueue.Stop()
close(i.cancelChan)
}
func (i *invoiceSubscriptionKit) notify(event *invoiceEvent) error {
select {
case i.ntfnQueue.ChanIn() <- event:
case <-i.cancelChan:
// This can only be triggered by delivery of non-backlog
// events.
return ErrShuttingDown
case <-i.quit:
return ErrShuttingDown
}
return nil
}
// SubscribeNotifications returns an InvoiceSubscription which allows the
// caller to receive async notifications when any invoices are settled or
// added. The addIndex and settleIndex parameters act as streaming
// "checkpoints": we'll first send out all events with an invoice index
// _greater_ than these values. Afterwards, we'll send out real-time
// notifications.
func (i *InvoiceRegistry) SubscribeNotifications(ctx context.Context,
addIndex, settleIndex uint64) (*InvoiceSubscription, error) {
client := &InvoiceSubscription{
NewInvoices: make(chan *Invoice),
SettledInvoices: make(chan *Invoice),
addIndex: addIndex,
settleIndex: settleIndex,
invoiceSubscriptionKit: invoiceSubscriptionKit{
quit: i.quit,
ntfnQueue: queue.NewConcurrentQueue(20),
cancelChan: make(chan struct{}),
backlogDelivered: make(chan struct{}),
},
}
client.ntfnQueue.Start()
// This notifies other goroutines that the backlog phase is over.
defer close(client.backlogDelivered)
// Always increment by 1 first, so our client IDs start at 1, not 0.
client.id = atomic.AddUint32(&i.nextClientID, 1)
// Before we register this new invoice subscription, we'll launch a new
// goroutine that will proxy all notifications appended to the end of
// the concurrent queue to the two client-side channels the caller will
// feed off of.
i.wg.Add(1)
go func() {
defer i.wg.Done()
defer i.deleteClient(client.id)
for {
select {
// A new invoice event has been sent by the
// invoiceRegistry! We'll figure out if this is an add
// event or a settle event, then dispatch the event to
// the client.
case ntfn := <-client.ntfnQueue.ChanOut():
invoiceEvent := ntfn.(*invoiceEvent)
var targetChan chan *Invoice
state := invoiceEvent.invoice.State
switch {
// AMP invoices never move to settled, but will
// be sent with a set ID if an HTLC set is
// being settled.
case state == ContractOpen &&
invoiceEvent.setID != nil:
fallthrough
case state == ContractSettled:
targetChan = client.SettledInvoices
case state == ContractOpen:
targetChan = client.NewInvoices
default:
log.Errorf("unknown invoice state: %v",
state)
continue
}
select {
case targetChan <- invoiceEvent.invoice:
case <-client.cancelChan:
return
case <-i.quit:
return
}
case <-client.cancelChan:
return
case <-i.quit:
return
}
}
}()
i.notificationClientMux.Lock()
i.notificationClients[client.id] = client
i.notificationClientMux.Unlock()
// Query the database to see if, based on the provided addIndex and
// settleIndex, we need to deliver any backlog notifications.
err := i.deliverBacklogEvents(ctx, client)
if err != nil {
return nil, err
}
log.Infof("New invoice subscription client: id=%v", client.id)
return client, nil
}
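// An illustrative consumer sketch (assumes a registry i in scope): passing
// zero indexes requests the full backlog before real-time notifications.
//
//	sub, err := i.SubscribeNotifications(ctx, 0, 0)
//	if err != nil {
//		// Handle the error.
//	}
//	defer sub.Cancel()
//	for {
//		select {
//		case inv := <-sub.NewInvoices:
//			_ = inv // A new invoice was added.
//		case inv := <-sub.SettledInvoices:
//			_ = inv // An invoice (or AMP set) was settled.
//		}
//	}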
// SubscribeSingleInvoice returns a SingleInvoiceSubscription which allows the
// caller to receive async notifications for a specific invoice.
func (i *InvoiceRegistry) SubscribeSingleInvoice(ctx context.Context,
hash lntypes.Hash) (*SingleInvoiceSubscription, error) {
client := &SingleInvoiceSubscription{
Updates: make(chan *Invoice),
invoiceSubscriptionKit: invoiceSubscriptionKit{
quit: i.quit,
ntfnQueue: queue.NewConcurrentQueue(20),
cancelChan: make(chan struct{}),
backlogDelivered: make(chan struct{}),
},
invoiceRef: InvoiceRefByHash(hash),
}
client.ntfnQueue.Start()
// This notifies other goroutines that the backlog phase is done.
defer close(client.backlogDelivered)
// Always increment by 1 first, so our client IDs start at 1, not 0.
client.id = atomic.AddUint32(&i.nextClientID, 1)
// Before we register this new invoice subscription, we'll launch a new
// goroutine that will proxy all notifications appended to the end of
// the concurrent queue to the client-side Updates channel the caller
// will feed off of.
i.wg.Add(1)
go func() {
defer i.wg.Done()
defer i.deleteClient(client.id)
for {
select {
// A new invoice event has been sent by the
// invoiceRegistry. We will dispatch the event to the
// client.
case ntfn := <-client.ntfnQueue.ChanOut():
invoiceEvent := ntfn.(*invoiceEvent)
select {
case client.Updates <- invoiceEvent.invoice:
case <-client.cancelChan:
return
case <-i.quit:
return
}
case <-client.cancelChan:
return
case <-i.quit:
return
}
}
}()
i.notificationClientMux.Lock()
i.singleNotificationClients[client.id] = client
i.notificationClientMux.Unlock()
err := i.deliverSingleBacklogEvents(ctx, client)
if err != nil {
return nil, err
}
log.Infof("New single invoice subscription client: id=%v, ref=%v",
client.id, client.invoiceRef)
return client, nil
}
// notifyHodlSubscribers sends out the htlc resolution to all current
// subscribers.
func (i *InvoiceRegistry) notifyHodlSubscribers(htlcResolution HtlcResolution) {
i.hodlSubscriptionsMux.Lock()
defer i.hodlSubscriptionsMux.Unlock()
subscribers, ok := i.hodlSubscriptions[htlcResolution.CircuitKey()]
if !ok {
return
}
// Notify all interested subscribers and remove the subscription from
// both maps. The subscription can be removed as there will only ever
// be a single resolution for each htlc.
for subscriber := range subscribers {
select {
case subscriber <- htlcResolution:
case <-i.quit:
return
}
delete(
i.hodlReverseSubscriptions[subscriber],
htlcResolution.CircuitKey(),
)
}
delete(i.hodlSubscriptions, htlcResolution.CircuitKey())
}
// hodlSubscribe adds a new invoice subscription.
func (i *InvoiceRegistry) hodlSubscribe(subscriber chan<- interface{},
circuitKey CircuitKey) {
i.hodlSubscriptionsMux.Lock()
defer i.hodlSubscriptionsMux.Unlock()
log.Debugf("Hodl subscribe for %v", circuitKey)
subscriptions, ok := i.hodlSubscriptions[circuitKey]
if !ok {
subscriptions = make(map[chan<- interface{}]struct{})
i.hodlSubscriptions[circuitKey] = subscriptions
}
subscriptions[subscriber] = struct{}{}
reverseSubscriptions, ok := i.hodlReverseSubscriptions[subscriber]
if !ok {
reverseSubscriptions = make(map[CircuitKey]struct{})
i.hodlReverseSubscriptions[subscriber] = reverseSubscriptions
}
reverseSubscriptions[circuitKey] = struct{}{}
}
// HodlUnsubscribeAll cancels all hodl subscriptions for the given
// subscriber.
func (i *InvoiceRegistry) HodlUnsubscribeAll(subscriber chan<- interface{}) {
i.hodlSubscriptionsMux.Lock()
defer i.hodlSubscriptionsMux.Unlock()
hashes := i.hodlReverseSubscriptions[subscriber]
for hash := range hashes {
delete(i.hodlSubscriptions[hash], subscriber)
}
delete(i.hodlReverseSubscriptions, subscriber)
}
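// An illustrative sketch of the internal hodl subscription lifecycle
// (hypothetical subscriber channel and circuit key): a subscriber registers
// for a circuit key and receives exactly one resolution, after which the
// subscription is removed; HodlUnsubscribeAll covers early teardown.
//
//	subscriber := make(chan interface{}, 1)
//	i.hodlSubscribe(subscriber, circuitKey)
//	// ... later, when the consumer shuts down before resolution:
//	i.HodlUnsubscribeAll(subscriber)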
// copySingleClients copies i.singleNotificationClients inside a lock. This is
// useful when we need to iterate the map to send notifications.
func (i *InvoiceRegistry) copySingleClients() map[uint32]*SingleInvoiceSubscription { //nolint:ll
i.notificationClientMux.RLock()
defer i.notificationClientMux.RUnlock()
clients := make(map[uint32]*SingleInvoiceSubscription)
for k, v := range i.singleNotificationClients {
clients[k] = v
}
return clients
}
// copyClients copies i.notificationClients inside a lock. This is useful when
// we need to iterate the map to send notifications.
func (i *InvoiceRegistry) copyClients() map[uint32]*InvoiceSubscription {
i.notificationClientMux.RLock()
defer i.notificationClientMux.RUnlock()
clients := make(map[uint32]*InvoiceSubscription)
for k, v := range i.notificationClients {
clients[k] = v
}
return clients
}
// deleteClient removes a client by its ID inside a lock. Noop if the client is
// not found.
func (i *InvoiceRegistry) deleteClient(clientID uint32) {
i.notificationClientMux.Lock()
defer i.notificationClientMux.Unlock()
log.Infof("Cancelling invoice subscription for client=%v", clientID)
delete(i.notificationClients, clientID)
delete(i.singleNotificationClients, clientID)
}
package invoices
import (
"errors"
"fmt"
"strings"
"time"
"github.com/lightningnetwork/lnd/feature"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
)
const (
// MaxMemoSize is the maximum size of the memo field within invoices
// stored in the database.
MaxMemoSize = 1024
// MaxPaymentRequestSize is the max size of a payment request for
// this invoice.
// TODO(halseth): determine the max length payment request when field
// lengths are final.
MaxPaymentRequestSize = 4096
)
var (
// UnknownPreimage is an all-zeroes preimage that indicates that the
// preimage for this invoice is not yet known.
UnknownPreimage lntypes.Preimage
// BlankPayAddr is a sentinel payment address for legacy invoices.
// Invoices with this payment address are special-cased in the insertion
// logic to prevent being indexed in the payment address index,
// otherwise they would cause collisions after the first insertion.
BlankPayAddr [32]byte
)
// RefModifier is a modification on top of a base invoice ref. It allows the
// caller to opt to skip out on HTLCs for a given payAddr, or only return the
// set of specified HTLCs for a given setID.
type RefModifier uint8
const (
// DefaultModifier is the base modifier that doesn't change any
// behavior.
DefaultModifier RefModifier = iota
// HtlcSetOnlyModifier can only be used with a setID based invoice ref,
// and specifies that only the set of HTLCs related to that setID are
// to be returned.
HtlcSetOnlyModifier
// HtlcSetBlankModifier can only be used with a payAddr based invoice
// ref, and specifies that the returned invoice shouldn't include any
// HTLCs at all.
HtlcSetBlankModifier
)
// InvoiceRef is a composite identifier for invoices. Invoices can be referenced
// by various combinations of payment hash and payment addr; in certain contexts
// only some of these are known. An InvoiceRef and its constructors thus
// encapsulate the valid combinations of query parameters that can be supplied
// to LookupInvoice and UpdateInvoice.
type InvoiceRef struct {
// payHash is the payment hash of the target invoice. All invoices are
// currently indexed by payment hash. This value will be used as a
// fallback when no payment address is known.
payHash *lntypes.Hash
// payAddr is the payment addr of the target invoice. Newer invoices
// (0.11 and up) are indexed by payment address in addition to payment
// hash, but pre 0.8 invoices do not have one at all. When this value is
// known it will be used as the primary identifier, falling back to
// payHash if no value is known.
payAddr *[32]byte
// setID is the optional set id for an AMP payment. This can be used to
// lookup or update the invoice knowing only this value. Queries by set
// id are only used to facilitate user-facing requests, e.g. lookup,
// settle or cancel an AMP invoice. The regular update flow from the
// invoice registry will always query for the invoice by
// payHash+payAddr.
setID *[32]byte
// refModifier allows an invoice ref to include or exclude specific
// HTLC sets based on the payAddr or setId.
refModifier RefModifier
}
// InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by
// its payment hash.
func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef {
return InvoiceRef{
payHash: &payHash,
}
}
// InvoiceRefByHashAndAddr creates an InvoiceRef that first queries for an
// invoice by the provided payment address, falling back to the payment hash if
// the payment address is unknown.
func InvoiceRefByHashAndAddr(payHash lntypes.Hash,
payAddr [32]byte) InvoiceRef {
return InvoiceRef{
payHash: &payHash,
payAddr: &payAddr,
}
}
// InvoiceRefByAddr creates an InvoiceRef that queries the payment addr index
// for an invoice with the provided payment address.
func InvoiceRefByAddr(addr [32]byte) InvoiceRef {
return InvoiceRef{
payAddr: &addr,
}
}
// InvoiceRefByAddrBlankHtlc creates an InvoiceRef that queries the payment addr
// index for an invoice with the provided payment address, but excludes any of
// the core HTLC information.
func InvoiceRefByAddrBlankHtlc(addr [32]byte) InvoiceRef {
return InvoiceRef{
payAddr: &addr,
refModifier: HtlcSetBlankModifier,
}
}
// InvoiceRefBySetID creates an InvoiceRef that queries the set id index for an
// invoice with the provided setID. If the invoice is not found, the query will
// not fall back to payHash or payAddr.
func InvoiceRefBySetID(setID [32]byte) InvoiceRef {
return InvoiceRef{
setID: &setID,
}
}
// InvoiceRefBySetIDFiltered is similar to the InvoiceRefBySetID identifier,
// but it specifies that the returned set of HTLCs should be filtered to only
// include HTLCs that are part of that set.
func InvoiceRefBySetIDFiltered(setID [32]byte) InvoiceRef {
return InvoiceRef{
setID: &setID,
refModifier: HtlcSetOnlyModifier,
}
}
// PayHash returns the optional payment hash of the target invoice.
//
// NOTE: This value may be nil.
func (r InvoiceRef) PayHash() *lntypes.Hash {
if r.payHash != nil {
hash := *r.payHash
return &hash
}
return nil
}
// PayAddr returns the optional payment address of the target invoice.
//
// NOTE: This value may be nil.
func (r InvoiceRef) PayAddr() *[32]byte {
if r.payAddr != nil {
addr := *r.payAddr
return &addr
}
return nil
}
// SetID returns the optional set id of the target invoice.
//
// NOTE: This value may be nil.
func (r InvoiceRef) SetID() *[32]byte {
if r.setID != nil {
id := *r.setID
return &id
}
return nil
}
// Modifier returns the modification applied on top of the base invoice
// ref lookup.
func (r InvoiceRef) Modifier() RefModifier {
return r.refModifier
}
// IsHashOnly returns true if the invoice ref only contains a payment hash.
func (r InvoiceRef) IsHashOnly() bool {
return r.payHash != nil && r.payAddr == nil && r.setID == nil
}
// String returns a human-readable representation of an InvoiceRef.
func (r InvoiceRef) String() string {
var ids []string
if r.payHash != nil {
ids = append(ids, fmt.Sprintf("pay_hash=%v", *r.payHash))
}
if r.payAddr != nil {
ids = append(ids, fmt.Sprintf("pay_addr=%x", *r.payAddr))
}
if r.setID != nil {
ids = append(ids, fmt.Sprintf("set_id=%x", *r.setID))
}
return fmt.Sprintf("(%s)", strings.Join(ids, ", "))
}
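// An illustrative sketch of constructing and inspecting invoice refs
// (hypothetical hash and addr values):
//
//	ref := InvoiceRefByHashAndAddr(hash, addr)
//	ref.PayHash()    // copy of hash
//	ref.PayAddr()    // copy of addr
//	ref.SetID()      // nil: no AMP set id was supplied
//	ref.IsHashOnly() // false: a payment address is present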
// ContractState describes the state the invoice is in.
type ContractState uint8
const (
// ContractOpen means the invoice has only been created.
ContractOpen ContractState = 0
// ContractSettled means the htlc is settled and the invoice has been
// paid.
ContractSettled ContractState = 1
// ContractCanceled means the invoice has been canceled.
ContractCanceled ContractState = 2
// ContractAccepted means the HTLC has been accepted but not settled
// yet.
ContractAccepted ContractState = 3
)
// String returns a human readable identifier for the ContractState type.
func (c ContractState) String() string {
switch c {
case ContractOpen:
return "Open"
case ContractSettled:
return "Settled"
case ContractCanceled:
return "Canceled"
case ContractAccepted:
return "Accepted"
}
return "Unknown"
}
// IsFinal returns a boolean indicating whether an invoice state is final.
func (c ContractState) IsFinal() bool {
return c == ContractSettled || c == ContractCanceled
}
// ContractTerm is a companion struct to the Invoice struct. This struct houses
// the necessary conditions required before the invoice can be considered fully
// settled by the payee.
type ContractTerm struct {
// FinalCltvDelta is the minimum required number of blocks before htlc
// expiry when the invoice is accepted.
FinalCltvDelta int32
// Expiry defines how long after creation this invoice should expire.
Expiry time.Duration
// PaymentPreimage is the preimage which is to be revealed in the
// event that an HTLC paying to the hash of this preimage is
// extended. Set to nil if the preimage isn't known yet.
PaymentPreimage *lntypes.Preimage
// Value is the expected amount of milli-satoshis to be paid to an HTLC
// which can be satisfied by the above preimage.
Value lnwire.MilliSatoshi
// PaymentAddr is a randomly generated value included in the MPP record
// by the sender to prevent probing of the receiver.
PaymentAddr [32]byte
// Features is the feature vectors advertised on the payment request.
Features *lnwire.FeatureVector
}
// String returns a human-readable description of the prominent contract terms.
func (c ContractTerm) String() string {
return fmt.Sprintf("amt=%v, expiry=%v, final_cltv_delta=%v", c.Value,
c.Expiry, c.FinalCltvDelta)
}
// SetID is the extra unique tuple item for AMP invoices. In addition to
// setting a payment address, each repeated payment to an AMP invoice will also
// contain a set ID.
type SetID [32]byte
// InvoiceStateAMP is a struct that associates the current state of an AMP
// invoice identified by its set ID along with the set of invoices identified
// by the circuit key. This allows callers to easily look up the latest state
// of an AMP "sub-invoice" and also look up the invoice HLTCs themselves in the
// greater HTLC map index.
type InvoiceStateAMP struct {
// State is the state of this sub-AMP invoice.
State HtlcState
// SettleIndex indicates the location in the settle index that
// references this instance of InvoiceStateAMP, but only if
// this value is set (non-zero), and State is HtlcStateSettled.
SettleIndex uint64
// SettleDate is the date that the setID was settled.
SettleDate time.Time
// InvoiceKeys is the set of circuit keys that can be used to locate
// the invoices for a given set ID.
InvoiceKeys map[CircuitKey]struct{}
// AmtPaid is the total amount that was paid in the AMP sub-invoice.
// Fetching the full HTLC/invoice state allows one to extract the
// custom records as well as the breakdown of the payment splits used
// when paying.
AmtPaid lnwire.MilliSatoshi
}
// copy makes a deep copy of the underlying InvoiceStateAMP.
func (i *InvoiceStateAMP) copy() (InvoiceStateAMP, error) {
result := *i
// Make a copy of the InvoiceKeys map.
result.InvoiceKeys = make(map[CircuitKey]struct{})
for k := range i.InvoiceKeys {
result.InvoiceKeys[k] = struct{}{}
}
// As a safety measure, copy SettleDate. time.Time is concurrency safe
// except when using any of the (un)marshalling methods.
settleDateBytes, err := i.SettleDate.MarshalBinary()
if err != nil {
return InvoiceStateAMP{}, err
}
err = result.SettleDate.UnmarshalBinary(settleDateBytes)
if err != nil {
return InvoiceStateAMP{}, err
}
return result, nil
}
// AMPInvoiceState represents a type that stores metadata related to the set of
// settled AMP "sub-invoices".
type AMPInvoiceState map[SetID]InvoiceStateAMP
// Invoice is a payment invoice generated by a payee in order to request
// payment for some good or service. The inclusion of invoices within Lightning
// creates a payment workflow for merchants very similar to that of the
// existing financial system within PayPal, etc. Invoices are added to the
// database when a payment is requested, then can be settled manually once the
// payment is received at the upper layer. For record keeping purposes,
// invoices are never deleted from the database, instead a bit is toggled
// denoting the invoice has been fully settled. Within the database, all
// invoices must have a unique payment hash which is generated by taking the
// sha256 of the payment preimage.
type Invoice struct {
// Memo is an optional memo to be stored along side an invoice. The
// memo may contain further details pertaining to the invoice itself,
// or any other message which fits within the size constraints.
Memo []byte
// PaymentRequest is the encoded payment request for this invoice. For
// spontaneous (keysend) payments, this field will be empty.
PaymentRequest []byte
// CreationDate is the exact time the invoice was created.
CreationDate time.Time
// SettleDate is the exact time the invoice was settled.
SettleDate time.Time
// Terms are the contractual payment terms of the invoice. Once all the
// terms have been satisfied by the payer, then the invoice can be
// considered fully fulfilled.
//
// TODO(roasbeef): later allow for multiple terms to fulfill the final
// invoice: payment fragmentation, etc.
Terms ContractTerm
// AddIndex is an auto-incrementing integer that acts as a
// monotonically increasing sequence number for all invoices created.
// Clients can then use this field as a "checkpoint" of sorts when
// implementing a streaming RPC to notify consumers of instances where
// an invoice has been added before they re-connected.
//
// NOTE: This index starts at 1.
AddIndex uint64
// SettleIndex is an auto-incrementing integer that acts as a
// monotonically increasing sequence number for all settled invoices.
// Clients can then use this field as a "checkpoint" of sorts when
// implementing a streaming RPC to notify consumers of instances where
// an invoice has been settled before they re-connected.
//
// NOTE: This index starts at 1.
SettleIndex uint64
// State describes the state the invoice is in. This is the global
// state of the invoice which may remain open even when a series of
// sub-invoices for this invoice has been settled.
State ContractState
// AmtPaid is the final amount that we ultimately accepted as payment for
// this invoice. We specify this value independently as it's possible
// that the invoice originally didn't specify an amount, or the sender
// overpaid.
AmtPaid lnwire.MilliSatoshi
// Htlcs records all htlcs that paid to this invoice. Some of these
// htlcs may have been marked as canceled.
Htlcs map[CircuitKey]*InvoiceHTLC
// AMPState describes the state of any AMP sub-invoices related to this
// greater invoice. A sub-invoice is defined by a set of HTLCs with the
// same set ID that attempt to make one-time or recurring payments to
// this greater invoice. It's possible for a sub-invoice to be canceled
// or settled while the greater invoice is still open.
AMPState AMPInvoiceState
// HodlInvoice indicates whether the invoice should be held in the
// Accepted state or be settled right away.
HodlInvoice bool
}
// HTLCSet returns the set of HTLCs belonging to setID and in the provided
// state. Passing a nil setID will return all HTLCs in the provided state in the
// case of legacy or MPP, and no HTLCs in the case of AMP. Otherwise, the
// returned set will be filtered by the populated setID which is used to
// retrieve AMP HTLC sets.
func (i *Invoice) HTLCSet(setID *[32]byte,
state HtlcState) map[CircuitKey]*InvoiceHTLC {
htlcSet := make(map[CircuitKey]*InvoiceHTLC)
for key, htlc := range i.Htlcs {
// Only add HTLCs that are in the requested HtlcState.
if htlc.State != state {
continue
}
if !htlc.IsInHTLCSet(setID) {
continue
}
htlcSet[key] = htlc
}
return htlcSet
}
// HTLCSetCompliment returns the set of all HTLCs not belonging to setID that
// are in the target state. Passing a nil setID will return no HTLCs, since
// all MPP HTLCs are part of the same HTLC set.
func (i *Invoice) HTLCSetCompliment(setID *[32]byte,
state HtlcState) map[CircuitKey]*InvoiceHTLC {
htlcSet := make(map[CircuitKey]*InvoiceHTLC)
for key, htlc := range i.Htlcs {
// Only add HTLCs that are in the requested HtlcState.
if htlc.State != state {
continue
}
// We are constructing the compliment, so filter anything that
// matches this set id.
if htlc.IsInHTLCSet(setID) {
continue
}
htlcSet[key] = htlc
}
return htlcSet
}
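// An illustrative contrast of the two accessors above (assumes an invoice
// inv and an AMP set id setID in scope): for a fixed state, HTLCSet and
// HTLCSetCompliment partition the htlcs in that state.
//
//	// All MPP/legacy htlcs in the accepted state (nil set id):
//	mppSet := inv.HTLCSet(nil, HtlcStateAccepted)
//	// All AMP htlcs of one set in the accepted state:
//	ampSet := inv.HTLCSet(&setID, HtlcStateAccepted)
//	// All accepted htlcs NOT belonging to that set:
//	others := inv.HTLCSetCompliment(&setID, HtlcStateAccepted)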
// IsKeysend returns true if the invoice is a Keysend invoice.
func (i *Invoice) IsKeysend() bool {
// TODO(positiveblue): look for a more reliable way to test if
// an invoice is keysend.
return len(i.PaymentRequest) == 0 && !i.IsAMP()
}
// IsAMP returns true if the invoice is an AMP invoice.
func (i *Invoice) IsAMP() bool {
if i.Terms.Features == nil {
return false
}
return i.Terms.Features.HasFeature(
lnwire.AMPRequired,
)
}
// IsBlinded returns true if the invoice contains blinded paths.
func (i *Invoice) IsBlinded() bool {
if i.Terms.Features == nil {
return false
}
return i.Terms.Features.IsSet(lnwire.Bolt11BlindedPathsRequired)
}
// HtlcState defines the states an htlc paying to an invoice can be in.
type HtlcState uint8
const (
// HtlcStateAccepted indicates the htlc is locked-in, but not resolved.
HtlcStateAccepted HtlcState = iota
// HtlcStateCanceled indicates the htlc is canceled back to the
// sender.
HtlcStateCanceled
// HtlcStateSettled indicates the htlc is settled.
HtlcStateSettled
)
// InvoiceHTLC contains details about an htlc paying to this invoice.
type InvoiceHTLC struct {
// Amt is the amount that is carried by this htlc.
Amt lnwire.MilliSatoshi
// MppTotalAmt is a field for mpp that indicates the expected total
// amount.
MppTotalAmt lnwire.MilliSatoshi
// AcceptHeight is the block height at which the invoice registry
// decided to accept this htlc as a payment to the invoice. At this
// height, the invoice cltv delay must have been met.
AcceptHeight uint32
// AcceptTime is the wall clock time at which the invoice registry
// decided to accept the htlc.
AcceptTime time.Time
// ResolveTime is the wall clock time at which the invoice registry
// decided to settle the htlc.
ResolveTime time.Time
// Expiry is the expiry height of this htlc.
Expiry uint32
// State indicates the state the invoice htlc is currently in. A
// canceled htlc isn't just removed from the invoice htlcs map, because
// we need AcceptHeight to properly cancel the htlc back.
State HtlcState
// CustomRecords contains the custom key/value pairs that accompanied
// the htlc.
CustomRecords record.CustomSet
// WireCustomRecords contains the custom key/value pairs that were only
// included in p2p wire message of the HTLC and not in the onion
// payload.
WireCustomRecords lnwire.CustomRecords
// AMP encapsulates additional data relevant to AMP HTLCs. This includes
// the AMP onion record, in addition to the HTLC's payment hash and
// preimage since these are unique to each AMP HTLC, and not the invoice
// as a whole.
//
// NOTE: This value will only be set for AMP HTLCs.
AMP *InvoiceHtlcAMPData
}
// Copy makes a deep copy of the target InvoiceHTLC.
func (h *InvoiceHTLC) Copy() *InvoiceHTLC {
result := *h
// Make a copy of the CustomSet map.
result.CustomRecords = make(record.CustomSet)
for k, v := range h.CustomRecords {
result.CustomRecords[k] = v
}
result.WireCustomRecords = h.WireCustomRecords.Copy()
result.AMP = h.AMP.Copy()
return &result
}
// IsInHTLCSet returns true if this HTLC is part of an HTLC set. If nil is
// passed,
// this method returns true if this is an MPP HTLC. Otherwise, it only returns
// true if the AMP HTLC's set id matches the populated setID.
func (h *InvoiceHTLC) IsInHTLCSet(setID *[32]byte) bool {
wantAMPSet := setID != nil
isAMPHtlc := h.AMP != nil
// Non-AMP HTLCs cannot be part of AMP HTLC sets, and vice versa.
if wantAMPSet != isAMPHtlc {
return false
}
// Skip AMP HTLCs that have differing set ids.
if isAMPHtlc && *setID != h.AMP.Record.SetID() {
return false
}
return true
}
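// A worked truth table for IsInHTLCSet (illustrative; mppHtlc and ampHtlc
// are hypothetical htlcs without and with an AMP record):
//
//	mppHtlc.IsInHTLCSet(nil)    // true: MPP htlcs form one implicit set
//	mppHtlc.IsInHTLCSet(&setID) // false: non-AMP htlc, AMP query
//	ampHtlc.IsInHTLCSet(nil)    // false: AMP htlc, non-AMP query
//	ampHtlc.IsInHTLCSet(&setID) // true only if the set ids match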
// InvoiceHtlcAMPData is a struct holding the additional metadata stored for
// each received AMP HTLC. This includes the AMP onion record, in addition to
// the HTLC's payment hash and preimage.
type InvoiceHtlcAMPData struct {
// AMP is a copy of the AMP record presented in the onion payload
// containing the information necessary to correlate and settle a
// spontaneous HTLC set. Newly accepted legacy keysend payments will
// also have this field set as we automatically promote them into an AMP
// payment for internal processing.
Record record.AMP
// Hash is an HTLC-level payment hash that is stored only for AMP
// payments. This is done because an AMP HTLC will carry a different
// payment hash from the invoice it might be satisfying, so we track the
// payment hashes individually to be able to compute whether or not the
// reconstructed preimage correctly matches the HTLC's hash.
Hash lntypes.Hash
// Preimage is an HTLC-level preimage that satisfies the AMP HTLC's
// Hash. The preimage will be derived from secret share reconstruction
// of the shares in the AMP payload.
//
// NOTE: Preimage will only be present once the HTLC is in
// HtlcStateSettled.
Preimage *lntypes.Preimage
}
// Copy returns a deep copy of the InvoiceHtlcAMPData.
func (d *InvoiceHtlcAMPData) Copy() *InvoiceHtlcAMPData {
if d == nil {
return nil
}
var preimage *lntypes.Preimage
if d.Preimage != nil {
pimg := *d.Preimage
preimage = &pimg
}
return &InvoiceHtlcAMPData{
Record: d.Record,
Hash: d.Hash,
Preimage: preimage,
}
}
// HtlcAcceptDesc describes the details of a newly accepted htlc.
type HtlcAcceptDesc struct {
// AcceptHeight is the block height at which this htlc was accepted.
AcceptHeight int32
// Amt is the amount that is carried by this htlc.
Amt lnwire.MilliSatoshi
// MppTotalAmt is a field for mpp that indicates the expected total
// amount.
MppTotalAmt lnwire.MilliSatoshi
// Expiry is the expiry height of this htlc.
Expiry uint32
// CustomRecords contains the custom key/value pairs that accompanied
// the htlc.
CustomRecords record.CustomSet
// AMP encapsulates additional data relevant to AMP HTLCs. This includes
// the AMP onion record, in addition to the HTLC's payment hash and
// preimage since these are unique to each AMP HTLC, and not the invoice
// as a whole.
//
// NOTE: This value will only be set for AMP HTLCs.
AMP *InvoiceHtlcAMPData
}
// UpdateType is an enum that describes the type of update that was applied to
// an invoice.
type UpdateType uint8
const (
// UnknownUpdate indicates that the UpdateType has not been set for a
// given update. This kind of update is not allowed.
UnknownUpdate UpdateType = iota
// CancelHTLCsUpdate indicates that this update cancels one or more
// HTLCs.
CancelHTLCsUpdate
// AddHTLCsUpdate indicates that this update adds one or more HTLCs.
AddHTLCsUpdate
// SettleHodlInvoiceUpdate indicates that this update settles one or
// more HTLCs from a hodl invoice.
SettleHodlInvoiceUpdate
// CancelInvoiceUpdate indicates that this update is trying to cancel
// an invoice.
CancelInvoiceUpdate
)
// String returns a human readable string for the UpdateType.
func (u UpdateType) String() string {
switch u {
case CancelHTLCsUpdate:
return "CancelHTLCsUpdate"
case AddHTLCsUpdate:
return "AddHTLCsUpdate"
case SettleHodlInvoiceUpdate:
return "SettleHodlInvoiceUpdate"
case CancelInvoiceUpdate:
return "CancelInvoiceUpdate"
default:
return fmt.Sprintf("unknown invoice update type: %d", u)
}
}
// InvoiceUpdateDesc describes the changes that should be applied to the
// invoice.
type InvoiceUpdateDesc struct {
// State is the new state that this invoice should progress to. If nil,
// the state is left unchanged.
State *InvoiceStateUpdateDesc
// CancelHtlcs describes the htlcs that need to be canceled.
CancelHtlcs map[CircuitKey]struct{}
// AddHtlcs describes the newly accepted htlcs that need to be added to
// the invoice.
AddHtlcs map[CircuitKey]*HtlcAcceptDesc
// SetID is an optional set ID for AMP invoices that allows operations
// to be more efficient by ensuring we don't need to read out the
// entire HTLC set each time an HTLC is to be cancelled.
SetID *SetID
// UpdateType indicates what type of update is being applied.
UpdateType UpdateType
}
// InvoiceStateUpdateDesc describes an invoice-level state transition.
type InvoiceStateUpdateDesc struct {
// NewState is the new state that this invoice should progress to.
NewState ContractState
// Preimage must be set to the preimage when NewState is settled.
Preimage *lntypes.Preimage
// HTLCPreimages sets the HTLC-level preimages stored for AMP HTLCs.
// These are only learned when settling the invoice as a whole. Must be
// set when settling an invoice with non-nil SetID.
HTLCPreimages map[CircuitKey]lntypes.Preimage
// SetID identifies a specific set of HTLCs destined for the same
// invoice as part of a larger AMP payment. This value will be nil for
// legacy or MPP payments.
SetID *[32]byte
}
// InvoiceUpdateCallback is a callback used in the db transaction to update the
// invoice.
// TODO(ziggie): Add the option of additional return values to the
// callback, for example the resolution, which is currently assigned via
// an outer scope variable.
type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error)
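// An illustrative callback sketch (key is a hypothetical CircuitKey in
// scope): a minimal InvoiceUpdateCallback that cancels a single htlc while
// leaving the invoice state untouched, mirroring the descriptors built by
// the registry.
//
//	var cancelOne InvoiceUpdateCallback = func(inv *Invoice) (
//		*InvoiceUpdateDesc, error) {
//
//		return &InvoiceUpdateDesc{
//			UpdateType: CancelHTLCsUpdate,
//			CancelHtlcs: map[CircuitKey]struct{}{
//				key: {},
//			},
//		}, nil
//	}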
// ValidateInvoice ensures that the invoice passes the checks for all the
// relevant constraints.
func ValidateInvoice(i *Invoice, paymentHash lntypes.Hash) error {
// Avoid conflicts with all-zeroes magic value in the database.
if paymentHash == UnknownPreimage.Hash() {
return fmt.Errorf("cannot use hash of all-zeroes preimage")
}
if len(i.Memo) > MaxMemoSize {
return fmt.Errorf("max length a memo is %v, and invoice "+
"of length %v was provided", MaxMemoSize, len(i.Memo))
}
if len(i.PaymentRequest) > MaxPaymentRequestSize {
return fmt.Errorf("max length of payment request is %v, "+
"length provided was %v", MaxPaymentRequestSize,
len(i.PaymentRequest))
}
if i.Terms.Features == nil {
return errors.New("invoice must have a feature vector")
}
err := feature.ValidateDeps(i.Terms.Features)
if err != nil {
return err
}
if i.requiresPreimage() && i.Terms.PaymentPreimage == nil {
return errors.New("this invoice must have a preimage")
}
if len(i.Htlcs) > 0 {
return ErrInvoiceHasHtlcs
}
return nil
}
// requiresPreimage returns true if the invoice requires a preimage to be valid.
func (i *Invoice) requiresPreimage() bool {
// AMP invoices and hodl invoices are allowed to have no preimage
// specified.
if i.HodlInvoice || i.IsAMP() {
return false
}
return true
}
// IsPending returns true if the invoice is in ContractOpen state.
func (i *Invoice) IsPending() bool {
return i.State == ContractOpen || i.State == ContractAccepted
}
// copySlice allocates a new slice and copies the source into it.
func copySlice(src []byte) []byte {
dest := make([]byte, len(src))
copy(dest, src)
return dest
}
// CopyInvoice makes a deep copy of the supplied invoice.
func CopyInvoice(src *Invoice) (*Invoice, error) {
dest := Invoice{
Memo: copySlice(src.Memo),
PaymentRequest: copySlice(src.PaymentRequest),
CreationDate: src.CreationDate,
SettleDate: src.SettleDate,
Terms: src.Terms,
AddIndex: src.AddIndex,
SettleIndex: src.SettleIndex,
State: src.State,
AmtPaid: src.AmtPaid,
Htlcs: make(
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
),
AMPState: make(map[SetID]InvoiceStateAMP),
HodlInvoice: src.HodlInvoice,
}
dest.Terms.Features = src.Terms.Features.Clone()
if src.Terms.PaymentPreimage != nil {
preimage := *src.Terms.PaymentPreimage
dest.Terms.PaymentPreimage = &preimage
}
for k, v := range src.Htlcs {
dest.Htlcs[k] = v.Copy()
}
// Lastly, copy the amp invoice state.
for k, v := range src.AMPState {
ampInvState, err := v.copy()
if err != nil {
return nil, err
}
dest.AMPState[k] = ampInvState
}
return &dest, nil
}
// InvoiceDeleteRef holds a reference to an invoice to be deleted.
type InvoiceDeleteRef struct {
// PayHash is the payment hash of the target invoice. All invoices are
// currently indexed by payment hash.
PayHash lntypes.Hash
// PayAddr is the payment addr of the target invoice. Newer invoices
// (0.11 and up) are indexed by payment address in addition to payment
// hash, but pre 0.8 invoices do not have one at all.
PayAddr *[32]byte
// AddIndex is the add index of the invoice.
AddIndex uint64
// SettleIndex is the settle index of the invoice.
SettleIndex uint64
}
package invoices
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("INVC", nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
package invoices
import (
"context"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/mock"
)
// MockInvoiceDB is a mock implementation of the InvoiceDB interface.
type MockInvoiceDB struct {
mock.Mock
}
// NewInvoicesDBMock returns a new instance of MockInvoiceDB.
func NewInvoicesDBMock() *MockInvoiceDB {
return &MockInvoiceDB{}
}
func (m *MockInvoiceDB) AddInvoice(invoice *Invoice,
paymentHash lntypes.Hash) (uint64, error) {
args := m.Called(invoice, paymentHash)
addIndex, _ := args.Get(0).(uint64)
// NOTE: this is a side effect of the AddInvoice method.
invoice.AddIndex = addIndex
return addIndex, args.Error(1)
}
func (m *MockInvoiceDB) InvoicesAddedSince(idx uint64) ([]Invoice, error) {
args := m.Called(idx)
invoices, _ := args.Get(0).([]Invoice)
return invoices, args.Error(1)
}
func (m *MockInvoiceDB) InvoicesSettledSince(idx uint64) ([]Invoice, error) {
args := m.Called(idx)
invoices, _ := args.Get(0).([]Invoice)
return invoices, args.Error(1)
}
func (m *MockInvoiceDB) LookupInvoice(ref InvoiceRef) (Invoice, error) {
args := m.Called(ref)
invoice, _ := args.Get(0).(Invoice)
return invoice, args.Error(1)
}
func (m *MockInvoiceDB) FetchPendingInvoices(ctx context.Context) (
map[lntypes.Hash]Invoice, error) {
args := m.Called(ctx)
invoices, _ := args.Get(0).(map[lntypes.Hash]Invoice)
return invoices, args.Error(1)
}
func (m *MockInvoiceDB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
args := m.Called(q)
invoiceSlice, _ := args.Get(0).(InvoiceSlice)
return invoiceSlice, args.Error(1)
}
func (m *MockInvoiceDB) UpdateInvoice(ref InvoiceRef, setIDHint *SetID,
callback InvoiceUpdateCallback) (*Invoice, error) {
args := m.Called(ref, setIDHint, callback)
invoice, _ := args.Get(0).(*Invoice)
return invoice, args.Error(1)
}
func (m *MockInvoiceDB) DeleteInvoice(invoices []InvoiceDeleteRef) error {
args := m.Called(invoices)
return args.Error(0)
}
func (m *MockInvoiceDB) DeleteCanceledInvoices(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
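// The function below is an illustrative sketch, not part of the original
// file: it shows how a test might program MockInvoiceDB with testify
// expectations. The returned add index of 1 is an arbitrary example value.
func exampleMockAddInvoice() (uint64, error) {
	mockDB := NewInvoicesDBMock()

	// Any invoice and payment hash will match, and the mock will report
	// add index 1 with no error.
	mockDB.On("AddInvoice", mock.Anything, mock.Anything).
		Return(uint64(1), nil)

	// As a side effect, the mock also sets invoice.AddIndex to 1.
	return mockDB.AddInvoice(&Invoice{}, lntypes.Hash{})
}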
// MockHtlcModifier is a mock implementation of the HtlcModifier interface.
type MockHtlcModifier struct {
mock.Mock
}
// Intercept generates a new intercept session for the given invoice.
// The call blocks until the client has responded to the request or an
// error occurs. The response callback is only called if a session was
// created in the first place, which is only the case if a client is
// registered.
func (m *MockHtlcModifier) Intercept(
req HtlcModifyRequest, callback func(HtlcModifyResponse)) error {
// If no expectations are set, return nil by default.
if len(m.ExpectedCalls) == 0 {
return nil
}
args := m.Called(req, callback)
// If a response was provided to the mock, execute the callback with it.
if response, ok := args.Get(1).(HtlcModifyResponse); ok &&
callback != nil {
callback(response)
}
return args.Error(0)
}
// RegisterInterceptor sets the client callback function that will be
// called when an invoice is intercepted. If a callback is already set,
// an error is returned. The returned function must be used to reset the
// callback to nil once the client is done or disconnects. The read-only channel
// closes when the server stops.
func (m *MockHtlcModifier) RegisterInterceptor(HtlcModifyCallback) (func(),
<-chan struct{}, error) {
return func() {}, make(chan struct{}), nil
}
// Ensure that MockHtlcModifier implements the HtlcInterceptor and HtlcModifier
// interfaces.
var _ HtlcInterceptor = (*MockHtlcModifier)(nil)
var _ HtlcModifier = (*MockHtlcModifier)(nil)
package invoices
import (
"errors"
"fmt"
"sync/atomic"
"github.com/lightningnetwork/lnd/fn/v2"
)
var (
// ErrInterceptorClientAlreadyConnected is an error that is returned
// when a client tries to connect to the interceptor service while
// another client is already connected.
ErrInterceptorClientAlreadyConnected = errors.New(
"interceptor client already connected",
)
// ErrInterceptorClientDisconnected is an error that is returned when
// the client disconnects during an interceptor session.
ErrInterceptorClientDisconnected = errors.New(
"interceptor client disconnected",
)
)
// safeCallback is a wrapper around a callback function that is safe for
// concurrent access.
type safeCallback struct {
// callback is the actual callback function that is called when an
// invoice is intercepted. This might be nil if no client is currently
// connected.
callback atomic.Pointer[HtlcModifyCallback]
}
// Set atomically sets the callback function. If a callback is already set, an
// error is returned. The returned function can be used to reset the callback to
// nil once the client is done.
func (s *safeCallback) Set(callback HtlcModifyCallback) (func(), error) {
if !s.callback.CompareAndSwap(nil, &callback) {
return nil, ErrInterceptorClientAlreadyConnected
}
return func() {
s.callback.Store(nil)
}, nil
}
// IsConnected returns true if a client is currently connected.
func (s *safeCallback) IsConnected() bool {
return s.callback.Load() != nil
}
// Exec executes the callback function if it is set. If the callback is not set,
// an error is returned.
func (s *safeCallback) Exec(req HtlcModifyRequest) (*HtlcModifyResponse,
error) {
callback := s.callback.Load()
if callback == nil {
return nil, ErrInterceptorClientDisconnected
}
return (*callback)(req)
}
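// A minimal sketch (not part of the original source) of the safeCallback
// lifecycle: a single client registers its callback, requests are executed
// against it, and the returned reset function detaches the client again so
// another one can connect. The callback body is a placeholder.
func exampleSafeCallbackLifecycle() error {
	var cb safeCallback

	reset, err := cb.Set(func(req HtlcModifyRequest) (*HtlcModifyResponse,
		error) {

		// Accept the HTLC unmodified.
		return &HtlcModifyResponse{}, nil
	})
	if err != nil {
		return err
	}
	defer reset()

	// While the callback is set, Exec forwards requests to it. A second
	// Set call in this window would fail with
	// ErrInterceptorClientAlreadyConnected.
	_, err = cb.Exec(HtlcModifyRequest{})
	return err
}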
// HtlcModificationInterceptor is a service that intercepts HTLCs that aim to
// settle an invoice, enabling a subscribed client to modify certain aspects of
// those HTLCs.
type HtlcModificationInterceptor struct {
started atomic.Bool
stopped atomic.Bool
// callback is the wrapped client callback function that is called when
// an invoice is intercepted. This function gives the client the ability
// to determine how the invoice should be settled.
callback *safeCallback
// quit is a channel that is closed when the interceptor is stopped.
quit chan struct{}
}
// NewHtlcModificationInterceptor creates a new HtlcModificationInterceptor.
func NewHtlcModificationInterceptor() *HtlcModificationInterceptor {
return &HtlcModificationInterceptor{
callback: &safeCallback{},
quit: make(chan struct{}),
}
}
// Intercept generates a new intercept session for the given invoice. The call
// blocks until the client has responded to the request or an error occurs. The
// response callback is only called if a session was created in the first place,
// which is only the case if a client is registered.
func (s *HtlcModificationInterceptor) Intercept(clientRequest HtlcModifyRequest,
responseCallback func(HtlcModifyResponse)) error {
// If there is no client callback set we will not handle the invoice
// further.
if !s.callback.IsConnected() {
log.Debugf("Not intercepting invoice with circuit key %v, no "+
"intercept client connected",
clientRequest.ExitHtlcCircuitKey)
return nil
}
// We'll block until the client has responded to the request or an error
// occurs.
var (
responseChan = make(chan *HtlcModifyResponse, 1)
errChan = make(chan error, 1)
)
// The callback function will block at the client's discretion. We will
// therefore execute it in a separate goroutine. We don't need a wait
// group because we wait for the response directly below. The caller
// needs to make sure they don't block indefinitely, by selecting on the
// quit channel they receive when registering the callback.
go func() {
log.Debugf("Waiting for client response from invoice HTLC "+
"interceptor session with circuit key %v",
clientRequest.ExitHtlcCircuitKey)
// By this point, we've already checked that the client callback
// is set. However, if the client disconnected since that check
// then Exec will return an error.
result, err := s.callback.Exec(clientRequest)
if err != nil {
_ = fn.SendOrQuit(errChan, err, s.quit)
return
}
_ = fn.SendOrQuit(responseChan, result, s.quit)
}()
// Wait for the client to respond or an error to occur.
select {
case response := <-responseChan:
responseCallback(*response)
return nil
case err := <-errChan:
log.Errorf("Error from invoice HTLC interceptor session: %v",
err)
return err
case <-s.quit:
return ErrInterceptorClientDisconnected
}
}
// RegisterInterceptor sets the client callback function that will be called
// when an invoice is intercepted. If a callback is already set, an error is
// returned. The returned function must be used to reset the callback to nil
// once the client is done or disconnects.
func (s *HtlcModificationInterceptor) RegisterInterceptor(
callback HtlcModifyCallback) (func(), <-chan struct{}, error) {
done, err := s.callback.Set(callback)
return done, s.quit, err
}
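// An illustrative sketch, not part of the original file: wiring a client
// callback to the interceptor and driving one Intercept round trip. The
// empty request and the no-op response handling are placeholder values.
func exampleInterceptorRoundTrip(s *HtlcModificationInterceptor) error {
	reset, quit, err := s.RegisterInterceptor(
		func(req HtlcModifyRequest) (*HtlcModifyResponse, error) {
			// Leave the HTLC untouched.
			return &HtlcModifyResponse{}, nil
		},
	)
	if err != nil {
		return err
	}
	defer reset()

	// The quit channel closes when the interceptor stops; a real client
	// would select on it to avoid blocking forever.
	_ = quit

	// Intercept blocks until the callback above responds.
	return s.Intercept(HtlcModifyRequest{}, func(HtlcModifyResponse) {
		// A real caller would apply the modification here.
	})
}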
// Start starts the service.
func (s *HtlcModificationInterceptor) Start() error {
log.Info("HtlcModificationInterceptor starting...")
if !s.started.CompareAndSwap(false, true) {
return fmt.Errorf("HtlcModificationInterceptor started more" +
"than once")
}
log.Debugf("HtlcModificationInterceptor started")
return nil
}
// Stop stops the service.
func (s *HtlcModificationInterceptor) Stop() error {
log.Info("HtlcModificationInterceptor stopping...")
if !s.stopped.CompareAndSwap(false, true) {
return fmt.Errorf("HtlcModificationInterceptor stopped more" +
"than once")
}
close(s.quit)
log.Debug("HtlcModificationInterceptor stopped")
return nil
}
// Ensure that HtlcModificationInterceptor implements the HtlcInterceptor and
// HtlcModifier interfaces.
var _ HtlcInterceptor = (*HtlcModificationInterceptor)(nil)
var _ HtlcModifier = (*HtlcModificationInterceptor)(nil)
package invoices
import (
"time"
"github.com/lightningnetwork/lnd/lntypes"
)
// HtlcResolution describes how an htlc should be resolved.
type HtlcResolution interface {
// CircuitKey returns the circuit key for the htlc that we have a
// resolution for.
CircuitKey() CircuitKey
}
// HtlcFailResolution is an implementation of the HtlcResolution interface
// which is returned when a htlc is failed.
type HtlcFailResolution struct {
// circuitKey is the key of the htlc for which we have a resolution.
circuitKey CircuitKey
// AcceptHeight is the original height at which the htlc was accepted.
AcceptHeight int32
// Outcome indicates the outcome of the invoice registry update.
Outcome FailResolutionResult
}
// NewFailResolution returns a htlc failure resolution.
func NewFailResolution(key CircuitKey, acceptHeight int32,
outcome FailResolutionResult) *HtlcFailResolution {
return &HtlcFailResolution{
circuitKey: key,
AcceptHeight: acceptHeight,
Outcome: outcome,
}
}
// CircuitKey returns the circuit key for the htlc that we have a
// resolution for.
//
// Note: it is part of the HtlcResolution interface.
func (f *HtlcFailResolution) CircuitKey() CircuitKey {
return f.circuitKey
}
// HtlcSettleResolution is an implementation of the HtlcResolution interface
// which is returned when a htlc is settled.
type HtlcSettleResolution struct {
// Preimage is the htlc preimage that was used to settle the htlc.
Preimage lntypes.Preimage
// circuitKey is the key of the htlc for which we have a resolution.
circuitKey CircuitKey
// AcceptHeight is the original height at which the htlc was accepted.
AcceptHeight int32
// Outcome indicates the outcome of the invoice registry update.
Outcome SettleResolutionResult
}
// NewSettleResolution returns a htlc resolution which is associated with a
// settle.
func NewSettleResolution(preimage lntypes.Preimage, key CircuitKey,
acceptHeight int32,
outcome SettleResolutionResult) *HtlcSettleResolution {
return &HtlcSettleResolution{
Preimage: preimage,
circuitKey: key,
AcceptHeight: acceptHeight,
Outcome: outcome,
}
}
// CircuitKey returns the circuit key for the htlc that we have a
// resolution for.
//
// Note: it is part of the HtlcResolution interface.
func (s *HtlcSettleResolution) CircuitKey() CircuitKey {
return s.circuitKey
}
// htlcAcceptResolution is an implementation of the HtlcResolution interface
// which is returned when a htlc is accepted. This struct is not exported
// because the codebase uses a nil resolution to indicate that a htlc was
// accepted. This struct is used internally in the invoice registry to
// surface accept resolution results. When an invoice update returns an
// acceptResolution, a nil resolution should be surfaced.
type htlcAcceptResolution struct {
// circuitKey is the key of the htlc for which we have a resolution.
circuitKey CircuitKey
// autoRelease signals that the htlc should be automatically released
// after a timeout.
autoRelease bool
// acceptTime is the time at which this htlc was accepted.
acceptTime time.Time
// outcome indicates the outcome of the invoice registry update.
outcome acceptResolutionResult
}
// newAcceptResolution returns a htlc resolution which is associated with a
// htlc accept.
func newAcceptResolution(key CircuitKey,
outcome acceptResolutionResult) *htlcAcceptResolution {
return &htlcAcceptResolution{
circuitKey: key,
outcome: outcome,
}
}
// CircuitKey returns the circuit key for the htlc that we have a
// resolution for.
//
// Note: it is part of the HtlcResolution interface.
func (a *htlcAcceptResolution) CircuitKey() CircuitKey {
return a.circuitKey
}
package invoices
// acceptResolutionResult provides metadata about a htlc that was accepted
// by the registry.
type acceptResolutionResult uint8
const (
resultInvalidAccept acceptResolutionResult = iota
// resultReplayToAccepted is returned when we replay an accepted
// invoice.
resultReplayToAccepted
// resultDuplicateToAccepted is returned when we accept a duplicate
// htlc.
resultDuplicateToAccepted
// resultAccepted is returned when we accept a hodl invoice.
resultAccepted
// resultPartialAccepted is returned when we have partially received
// payment.
resultPartialAccepted
)
// String returns a string representation of the result.
func (a acceptResolutionResult) String() string {
switch a {
case resultInvalidAccept:
return "invalid accept result"
case resultReplayToAccepted:
return "replayed htlc to accepted invoice"
case resultDuplicateToAccepted:
return "accepting duplicate payment to accepted invoice"
case resultAccepted:
return "accepted"
case resultPartialAccepted:
return "partial payment accepted"
default:
return "unknown accept resolution result"
}
}
// FailResolutionResult provides metadata about a htlc that was failed by
// the registry. It can be used to take custom actions on resolution of the
// htlc.
type FailResolutionResult uint8
const (
resultInvalidFailure FailResolutionResult = iota
// ResultReplayToCanceled is returned when we replay a canceled invoice.
ResultReplayToCanceled
// ResultInvoiceAlreadyCanceled is returned when trying to pay an
// invoice that is already canceled.
ResultInvoiceAlreadyCanceled
// ResultInvoiceAlreadySettled is returned when trying to pay an invoice
// that is already settled.
ResultInvoiceAlreadySettled
// ResultAmountTooLow is returned when an invoice is underpaid.
ResultAmountTooLow
// ResultExpiryTooSoon is returned when we do not accept an invoice
// payment because it expires too soon.
ResultExpiryTooSoon
// ResultCanceled is returned when we cancel an invoice and its
// associated htlcs.
ResultCanceled
// ResultInvoiceNotOpen is returned when a mpp invoice is not open.
ResultInvoiceNotOpen
// ResultMppTimeout is returned when an invoice paid with multiple
// partial payments times out before it is fully paid.
ResultMppTimeout
// ResultAddressMismatch is returned when the payment address for a mpp
// invoice does not match.
ResultAddressMismatch
// ResultHtlcSetTotalMismatch is returned when the amount paid by a
// htlc does not match its set total.
ResultHtlcSetTotalMismatch
// ResultHtlcSetTotalTooLow is returned when a mpp set total is too low
// for an invoice.
ResultHtlcSetTotalTooLow
// ResultHtlcSetOverpayment is returned when a mpp set is overpaid.
ResultHtlcSetOverpayment
// ResultInvoiceNotFound is returned when an attempt is made to pay an
// invoice that is unknown to us.
ResultInvoiceNotFound
// ResultKeySendError is returned when we receive invalid keysend
// parameters.
ResultKeySendError
// ResultMppInProgress is returned when we are busy receiving a mpp
// payment.
ResultMppInProgress
// ResultHtlcInvoiceTypeMismatch is returned when an AMP HTLC targets a
// non-AMP invoice and vice versa.
ResultHtlcInvoiceTypeMismatch
// ResultAmpError is returned when we receive invalid AMP parameters.
ResultAmpError
// ResultAmpReconstruction is returned when the derived child
// hash/preimage pairs were invalid for at least one HTLC in the set.
ResultAmpReconstruction
// ExternalValidationFailed is returned when the external validation
// failed.
ExternalValidationFailed
)
// String returns a string representation of the result.
func (f FailResolutionResult) String() string {
return f.FailureString()
}
// FailureString returns a string representation of the result.
//
// Note: it is part of the FailureDetail interface.
func (f FailResolutionResult) FailureString() string {
switch f {
case resultInvalidFailure:
return "invalid failure result"
case ResultReplayToCanceled:
return "replayed htlc to canceled invoice"
case ResultInvoiceAlreadyCanceled:
return "invoice already canceled"
case ResultInvoiceAlreadySettled:
return "invoice already settled"
case ResultAmountTooLow:
return "amount too low"
case ResultExpiryTooSoon:
return "expiry too soon"
case ResultCanceled:
return "canceled"
case ResultInvoiceNotOpen:
return "invoice no longer open"
case ResultMppTimeout:
return "mpp timeout"
case ResultAddressMismatch:
return "payment address mismatch"
case ResultHtlcSetTotalMismatch:
return "htlc total amt doesn't match set total"
case ResultHtlcSetTotalTooLow:
return "set total too low for invoice"
case ResultHtlcSetOverpayment:
return "mpp is overpaying set total"
case ResultInvoiceNotFound:
return "invoice not found"
case ResultKeySendError:
return "invalid keysend parameters"
case ResultMppInProgress:
return "mpp reception in progress"
case ResultHtlcInvoiceTypeMismatch:
return "htlc invoice type mismatch"
case ResultAmpError:
return "invalid amp parameters"
case ResultAmpReconstruction:
return "amp reconstruction failed"
case ExternalValidationFailed:
return "external validation failed"
default:
return "unknown failure resolution result"
}
}
// IsSetFailure returns true if this failure should result in the entire HTLC
// set being failed with the same result.
func (f FailResolutionResult) IsSetFailure() bool {
switch f {
case
ResultAmpReconstruction,
ResultHtlcSetTotalTooLow,
ResultHtlcSetTotalMismatch,
ResultHtlcSetOverpayment,
ExternalValidationFailed:
return true
default:
return false
}
}
// SettleResolutionResult provides metadata about a htlc that was settled by
// the registry. It can be used to take custom actions on resolution of the
// htlc.
type SettleResolutionResult uint8
const (
resultInvalidSettle SettleResolutionResult = iota
// ResultSettled is returned when we settle an invoice.
ResultSettled
// ResultReplayToSettled is returned when we replay a settled invoice.
ResultReplayToSettled
// ResultDuplicateToSettled is returned when we settle an invoice which
// has already been settled at least once.
ResultDuplicateToSettled
)
// String returns a string representation of the result.
func (s SettleResolutionResult) String() string {
switch s {
case resultInvalidSettle:
return "invalid settle result"
case ResultSettled:
return "settled"
case ResultReplayToSettled:
return "replayed htlc to settled invoice"
case ResultDuplicateToSettled:
return "accepting duplicate payment to settled invoice"
default:
return "unknown settle resolution result"
}
}
package invoices
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"reflect"
"strconv"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
"github.com/pmezard/go-difflib/difflib"
"golang.org/x/time/rate"
)
var (
// invoiceBucket is the name of the bucket within the database that
// stores all data related to invoices no matter their final state.
// Within the invoice bucket, each invoice is keyed by its invoice ID
// which is a monotonically increasing uint32.
invoiceBucket = []byte("invoices")
// invoiceIndexBucket is the name of the sub-bucket within the
// invoiceBucket which indexes all invoices by their payment hash. The
// payment hash is the sha256 of the invoice's payment preimage. This
// index is used to detect duplicates, and also to provide a fast path
// for looking up incoming HTLCs to determine if we're able to settle
// them fully.
//
// maps: payHash => invoiceKey
invoiceIndexBucket = []byte("paymenthashes")
// numInvoicesKey is the name of the key which houses the auto-incrementing
// invoice ID which is essentially used as a primary key. With each
// invoice inserted, the primary key is incremented by one. This key is
// stored within the invoiceIndexBucket. Within the invoiceBucket
// invoices are uniquely identified by the invoice ID.
numInvoicesKey = []byte("nik")
// addIndexBucket is an index bucket that we'll use to create a
// monotonically increasing set of add indexes. Each time we add a new
// invoice, this sequence number will be incremented and then populated
// within the new invoice.
//
// In addition to this sequence number, we map:
//
// addIndexNo => invoiceKey
addIndexBucket = []byte("invoice-add-index")
// ErrMigrationMismatch is returned when the migrated invoice does not
// match the original invoice.
ErrMigrationMismatch = fmt.Errorf("migrated invoice does not match " +
"original invoice")
)
// createInvoiceHashIndex generates a hash index that contains payment hashes
// for each invoice in the database. Retrieving the payment hash for certain
// invoices, such as those created for spontaneous AMP payments, can be
// challenging because the hash is not directly derivable from the invoice's
// parameters and is stored separately in the `paymenthashes` bucket. This
// bucket maps payment hashes to invoice keys, but for migration purposes, we
// need the ability to query in the reverse direction. This function establishes
// a new index in the SQL database that maps each invoice key to its
// corresponding payment hash.
func createInvoiceHashIndex(ctx context.Context, db kvdb.Backend,
tx *sqlc.Queries) error {
return db.View(func(kvTx kvdb.RTx) error {
invoices := kvTx.ReadBucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
}
invoiceIndex := invoices.NestedReadBucket(
invoiceIndexBucket,
)
if invoiceIndex == nil {
return ErrNoInvoicesCreated
}
addIndex := invoices.NestedReadBucket(addIndexBucket)
if addIndex == nil {
return ErrNoInvoicesCreated
}
// First, iterate over all elements in the add index bucket and
// insert the add index value for the corresponding invoice key
// in the payment_hashes table.
err := addIndex.ForEach(func(k, v []byte) error {
// The key is the add index, and the value is
// the invoice key.
addIndexNo := binary.BigEndian.Uint64(k)
invoiceKey := binary.BigEndian.Uint32(v)
return tx.InsertKVInvoiceKeyAndAddIndex(ctx,
sqlc.InsertKVInvoiceKeyAndAddIndexParams{
ID: int64(invoiceKey),
AddIndex: int64(addIndexNo),
},
)
})
if err != nil {
return err
}
// Next, iterate over all hashes in the invoice index bucket and
// set the hash for the corresponding invoice key in the
// payment_hashes table.
return invoiceIndex.ForEach(func(k, v []byte) error {
// Skip the special numInvoicesKey as that does
// not point to a valid invoice.
if bytes.Equal(k, numInvoicesKey) {
return nil
}
// The key is the payment hash, and the value
// is the invoice key.
if len(k) != lntypes.HashSize {
return fmt.Errorf("invalid payment "+
"hash length: expected %v, "+
"got %v", lntypes.HashSize,
len(k))
}
invoiceKey := binary.BigEndian.Uint32(v)
return tx.SetKVInvoicePaymentHash(ctx,
sqlc.SetKVInvoicePaymentHashParams{
ID: int64(invoiceKey),
Hash: k,
},
)
})
}, func() {})
}
// toInsertMigratedInvoiceParams creates the parameters for inserting a migrated
// invoice into the SQL database. The parameters are derived from the original
// invoice insert parameters.
func toInsertMigratedInvoiceParams(
params sqlc.InsertInvoiceParams) sqlc.InsertMigratedInvoiceParams {
return sqlc.InsertMigratedInvoiceParams{
Hash: params.Hash,
Preimage: params.Preimage,
Memo: params.Memo,
AmountMsat: params.AmountMsat,
CltvDelta: params.CltvDelta,
Expiry: params.Expiry,
PaymentAddr: params.PaymentAddr,
PaymentRequest: params.PaymentRequest,
PaymentRequestHash: params.PaymentRequestHash,
State: params.State,
AmountPaidMsat: params.AmountPaidMsat,
IsAmp: params.IsAmp,
IsHodl: params.IsHodl,
IsKeysend: params.IsKeysend,
CreatedAt: params.CreatedAt,
}
}
// MigrateSingleInvoice migrates a single invoice to the new SQL schema. Note
// that perfect equality between the old and new schemas is not achievable, as
// the invoice's add index cannot be mapped directly to its ID due to SQL’s
// auto-incrementing primary key. The ID returned from the insert will instead
// serve as the add index in the new schema.
func MigrateSingleInvoice(ctx context.Context, tx SQLInvoiceQueries,
invoice *Invoice, paymentHash lntypes.Hash) error {
insertInvoiceParams, err := makeInsertInvoiceParams(
invoice, paymentHash,
)
if err != nil {
return err
}
// Convert the insert invoice parameters to the migrated invoice insert
// parameters.
insertMigratedInvoiceParams := toInsertMigratedInvoiceParams(
insertInvoiceParams,
)
// If the invoice is settled, we'll also set the timestamp and the index
// at which it was settled.
if invoice.State == ContractSettled {
if invoice.SettleIndex == 0 {
return fmt.Errorf("settled invoice %s missing settle "+
"index", paymentHash)
}
if invoice.SettleDate.IsZero() {
return fmt.Errorf("settled invoice %s missing settle "+
"date", paymentHash)
}
insertMigratedInvoiceParams.SettleIndex = sqldb.SQLInt64(
invoice.SettleIndex,
)
insertMigratedInvoiceParams.SettledAt = sqldb.SQLTime(
invoice.SettleDate.UTC(),
)
}
// First we need to insert the invoice itself so we can use the "add
// index" which in this case is the auto incrementing primary key that
// is returned from the insert.
invoiceID, err := tx.InsertMigratedInvoice(
ctx, insertMigratedInvoiceParams,
)
if err != nil {
return fmt.Errorf("unable to insert invoice: %w", err)
}
// Insert the invoice's features.
for feature := range invoice.Terms.Features.Features() {
params := sqlc.InsertInvoiceFeatureParams{
InvoiceID: invoiceID,
Feature: int32(feature),
}
err := tx.InsertInvoiceFeature(ctx, params)
if err != nil {
return fmt.Errorf("unable to insert invoice "+
"feature(%v): %w", feature, err)
}
}
sqlHtlcIDs := make(map[models.CircuitKey]int64)
// Now insert the HTLCs of the invoice. We'll also keep track of the SQL
// ID of each HTLC so we can use it when inserting the AMP sub invoices.
for circuitKey, htlc := range invoice.Htlcs {
htlcParams := sqlc.InsertInvoiceHTLCParams{
HtlcID: int64(circuitKey.HtlcID),
ChanID: strconv.FormatUint(
circuitKey.ChanID.ToUint64(), 10,
),
AmountMsat: int64(htlc.Amt),
AcceptHeight: int32(htlc.AcceptHeight),
AcceptTime: htlc.AcceptTime.UTC(),
ExpiryHeight: int32(htlc.Expiry),
State: int16(htlc.State),
InvoiceID: invoiceID,
}
// Leave the MPP amount as NULL if the MPP total amount is zero.
if htlc.MppTotalAmt != 0 {
htlcParams.TotalMppMsat = sqldb.SQLInt64(
int64(htlc.MppTotalAmt),
)
}
// Leave the resolve time as NULL if the HTLC is not resolved.
if !htlc.ResolveTime.IsZero() {
htlcParams.ResolveTime = sqldb.SQLTime(
htlc.ResolveTime.UTC(),
)
}
sqlID, err := tx.InsertInvoiceHTLC(ctx, htlcParams)
if err != nil {
return fmt.Errorf("unable to insert invoice htlc: %w",
err)
}
sqlHtlcIDs[circuitKey] = sqlID
// Store custom records.
for key, value := range htlc.CustomRecords {
err = tx.InsertInvoiceHTLCCustomRecord(
ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
Key: int64(key),
Value: value,
HtlcID: sqlID,
},
)
if err != nil {
return err
}
}
}
if !invoice.IsAMP() {
return nil
}
for setID, ampState := range invoice.AMPState {
// Find the earliest accepted HTLC of the AMP sub invoice; its
// accept time will be used as the creation date of this sub
// invoice.
var createdAt time.Time
for circuitKey := range ampState.InvoiceKeys {
htlc := invoice.Htlcs[circuitKey]
if createdAt.IsZero() {
createdAt = htlc.AcceptTime.UTC()
continue
}
if createdAt.After(htlc.AcceptTime) {
createdAt = htlc.AcceptTime.UTC()
}
}
params := sqlc.InsertAMPSubInvoiceParams{
SetID: setID[:],
State: int16(ampState.State),
CreatedAt: createdAt,
InvoiceID: invoiceID,
}
if ampState.SettleIndex != 0 {
if ampState.SettleDate.IsZero() {
return fmt.Errorf("settled AMP sub invoice %x "+
"missing settle date", setID)
}
params.SettledAt = sqldb.SQLTime(
ampState.SettleDate.UTC(),
)
params.SettleIndex = sqldb.SQLInt64(
ampState.SettleIndex,
)
}
err := tx.InsertAMPSubInvoice(ctx, params)
if err != nil {
return fmt.Errorf("unable to insert AMP sub invoice: "+
"%w", err)
}
// Now we can add the AMP HTLCs to the database.
for circuitKey := range ampState.InvoiceKeys {
htlc := invoice.Htlcs[circuitKey]
rootShare := htlc.AMP.Record.RootShare()
sqlHtlcID, ok := sqlHtlcIDs[circuitKey]
if !ok {
return fmt.Errorf("missing htlc for AMP htlc: "+
"%v", circuitKey)
}
params := sqlc.InsertAMPSubInvoiceHTLCParams{
InvoiceID: invoiceID,
SetID: setID[:],
HtlcID: sqlHtlcID,
RootShare: rootShare[:],
ChildIndex: int64(htlc.AMP.Record.ChildIndex()),
Hash: htlc.AMP.Hash[:],
}
if htlc.AMP.Preimage != nil {
params.Preimage = htlc.AMP.Preimage[:]
}
err = tx.InsertAMPSubInvoiceHTLC(ctx, params)
if err != nil {
return fmt.Errorf("unable to insert AMP sub "+
"invoice: %w", err)
}
}
}
return nil
}
// OverrideInvoiceTimeZone overrides the time zone of the invoice to the local
// time zone and chops off the nanosecond part for comparison. This is needed
// because the KV database stores times as-is, which as an unwanted side
// effect would fail the migration check: the time comparison expects both the
// original and the migrated invoice to be in the same local time zone and at
// microsecond precision. Note that PostgreSQL stores times with microsecond
// precision, while SQLite can store times with nanosecond precision if using
// the TEXT storage class.
func OverrideInvoiceTimeZone(invoice *Invoice) {
fixTime := func(t time.Time) time.Time {
return t.In(time.Local).Truncate(time.Microsecond)
}
invoice.CreationDate = fixTime(invoice.CreationDate)
if !invoice.SettleDate.IsZero() {
invoice.SettleDate = fixTime(invoice.SettleDate)
}
if invoice.IsAMP() {
for setID, ampState := range invoice.AMPState {
if ampState.SettleDate.IsZero() {
continue
}
ampState.SettleDate = fixTime(ampState.SettleDate)
invoice.AMPState[setID] = ampState
}
}
for _, htlc := range invoice.Htlcs {
if !htlc.AcceptTime.IsZero() {
htlc.AcceptTime = fixTime(htlc.AcceptTime)
}
if !htlc.ResolveTime.IsZero() {
htlc.ResolveTime = fixTime(htlc.ResolveTime)
}
}
}
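// For illustration (not part of the original source), the normalization
// above applied to a single timestamp:
//
//	t := time.Date(2024, 1, 2, 3, 4, 5, 123456789, time.UTC)
//	t = t.In(time.Local).Truncate(time.Microsecond)
//	// The nanosecond component is now 123456000, i.e. microsecond
//	// precision, matching what PostgreSQL is able to store.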
// MigrateInvoicesToSQL runs the migration of all invoices from the KV database
// to the SQL database. The migration is done in a single transaction to ensure
// that all invoices are migrated or none at all. This function can be run
// multiple times without causing any issues as it will check if the migration
// has already been performed.
func MigrateInvoicesToSQL(ctx context.Context, db kvdb.Backend,
kvStore InvoiceDB, tx *sqlc.Queries, batchSize int) error {
log.Infof("Starting migration of invoices from KV to SQL")
offset := uint64(0)
t0 := time.Now()
// Create the hash index which we will use to look up invoice
// payment hashes by their add index during migration.
err := createInvoiceHashIndex(ctx, db, tx)
if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
log.Errorf("Unable to create invoice hash index: %v",
err)
return err
}
log.Debugf("Created SQL invoice hash index in %v", time.Since(t0))
s := rate.Sometimes{
Interval: 30 * time.Second,
}
t0 = time.Now()
chunk := 0
total := 0
// Now we can start migrating the invoices. We'll do this in
// batches to reduce memory usage.
for {
query := InvoiceQuery{
IndexOffset: offset,
NumMaxInvoices: uint64(batchSize),
}
queryResult, err := kvStore.QueryInvoices(ctx, query)
if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
return fmt.Errorf("unable to query invoices: %w", err)
}
if len(queryResult.Invoices) == 0 {
log.Infof("All invoices migrated. Total: %d", total)
break
}
err = migrateInvoices(ctx, tx, queryResult.Invoices)
if err != nil {
return err
}
offset = queryResult.LastIndexOffset
resultCnt := len(queryResult.Invoices)
total += resultCnt
chunk += resultCnt
s.Do(func() {
elapsed := time.Since(t0).Seconds()
ratePerSec := float64(chunk) / elapsed
log.Debugf("Migrated %d invoices (%.2f invoices/sec)",
total, ratePerSec)
t0 = time.Now()
chunk = 0
})
}
// Clean up the hash index as it's no longer needed.
err = tx.ClearKVInvoiceHashIndex(ctx)
if err != nil {
return fmt.Errorf("unable to clear invoice hash "+
"index: %w", err)
}
log.Infof("Migration of %d invoices from KV to SQL completed", total)
return nil
}
func migrateInvoices(ctx context.Context, tx *sqlc.Queries,
invoices []Invoice) error {
for i, invoice := range invoices {
var paymentHash lntypes.Hash
if invoice.Terms.PaymentPreimage != nil {
paymentHash = invoice.Terms.PaymentPreimage.Hash()
} else {
paymentHashBytes, err :=
tx.GetKVInvoicePaymentHashByAddIndex(
ctx, int64(invoice.AddIndex),
)
if err != nil {
// This would be an unexpected inconsistency
// in the kv database. We can't do much here
// so we'll notify the user and continue.
log.Warnf("Cannot migrate invoice, unable to "+
"fetch payment hash (add_index=%v): %v",
invoice.AddIndex, err)
continue
}
copy(paymentHash[:], paymentHashBytes)
}
err := MigrateSingleInvoice(ctx, tx, &invoices[i], paymentHash)
if err != nil {
return fmt.Errorf("unable to migrate invoice(%v): %w",
paymentHash, err)
}
migratedInvoice, err := fetchInvoice(
ctx, tx, InvoiceRefByHash(paymentHash),
)
if err != nil {
return fmt.Errorf("unable to fetch migrated "+
"invoice(%v): %w", paymentHash, err)
}
// Override the time zone for comparison. Note that we need to
// override both invoices: the original invoice comes from the
// KV database, where it was stored as a binary serialized Go
// time.Time value with nanosecond precision, possibly created
// in a different time zone. The migrated invoice is stored in
// SQL in UTC and selected in the local time zone; in PostgreSQL
// it has microsecond precision, while in SQLite it has
// nanosecond precision if using the TEXT storage class.
OverrideInvoiceTimeZone(&invoice)
OverrideInvoiceTimeZone(migratedInvoice)
// Override the add index before checking for equality.
migratedInvoice.AddIndex = invoice.AddIndex
if !reflect.DeepEqual(invoice, *migratedInvoice) {
diff := difflib.UnifiedDiff{
A: difflib.SplitLines(
spew.Sdump(invoice),
),
B: difflib.SplitLines(
spew.Sdump(migratedInvoice),
),
FromFile: "Expected",
FromDate: "",
ToFile: "Actual",
ToDate: "",
Context: 3,
}
diffText, _ := difflib.GetUnifiedDiffString(diff)
return fmt.Errorf("%w: %v.\n%v", ErrMigrationMismatch,
paymentHash, diffText)
}
}
return nil
}
package invoices
import (
"context"
"crypto/sha256"
"database/sql"
"errors"
"fmt"
"math"
"strconv"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
)
const (
// defaultQueryPaginationLimit is used in the LIMIT clause of the SQL
// queries to limit the number of rows returned.
defaultQueryPaginationLimit = 100
// invoiceProgressLogInterval is the interval we use to rate limit the
// logging output of invoice processing.
invoiceProgressLogInterval = 30 * time.Second
)
// SQLInvoiceQueries is an interface that defines the set of operations that can
// be executed against the invoice SQL database.
type SQLInvoiceQueries interface { //nolint:interfacebloat
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
error)
// TODO(bhandras): remove this once migrations have been separated out.
InsertMigratedInvoice(ctx context.Context,
arg sqlc.InsertMigratedInvoiceParams) (int64, error)
InsertInvoiceFeature(ctx context.Context,
arg sqlc.InsertInvoiceFeatureParams) error
InsertInvoiceHTLC(ctx context.Context,
arg sqlc.InsertInvoiceHTLCParams) (int64, error)
InsertInvoiceHTLCCustomRecord(ctx context.Context,
arg sqlc.InsertInvoiceHTLCCustomRecordParams) error
FilterInvoices(ctx context.Context,
arg sqlc.FilterInvoicesParams) ([]sqlc.Invoice, error)
GetInvoice(ctx context.Context,
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
error)
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
error)
GetInvoiceFeatures(ctx context.Context,
invoiceID int64) ([]sqlc.InvoiceFeature, error)
GetInvoiceHTLCCustomRecords(ctx context.Context,
invoiceID int64) ([]sqlc.GetInvoiceHTLCCustomRecordsRow, error)
GetInvoiceHTLCs(ctx context.Context,
invoiceID int64) ([]sqlc.InvoiceHtlc, error)
UpdateInvoiceState(ctx context.Context,
arg sqlc.UpdateInvoiceStateParams) (sql.Result, error)
UpdateInvoiceAmountPaid(ctx context.Context,
arg sqlc.UpdateInvoiceAmountPaidParams) (sql.Result, error)
NextInvoiceSettleIndex(ctx context.Context) (int64, error)
UpdateInvoiceHTLC(ctx context.Context,
arg sqlc.UpdateInvoiceHTLCParams) error
DeleteInvoice(ctx context.Context, arg sqlc.DeleteInvoiceParams) (
sql.Result, error)
DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
// AMP sub invoice specific methods.
UpsertAMPSubInvoice(ctx context.Context,
arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
// TODO(bhandras): remove this once migrations have been separated out.
InsertAMPSubInvoice(ctx context.Context,
arg sqlc.InsertAMPSubInvoiceParams) error
UpdateAMPSubInvoiceState(ctx context.Context,
arg sqlc.UpdateAMPSubInvoiceStateParams) error
InsertAMPSubInvoiceHTLC(ctx context.Context,
arg sqlc.InsertAMPSubInvoiceHTLCParams) error
FetchAMPSubInvoices(ctx context.Context,
arg sqlc.FetchAMPSubInvoicesParams) ([]sqlc.AmpSubInvoice,
error)
FetchAMPSubInvoiceHTLCs(ctx context.Context,
arg sqlc.FetchAMPSubInvoiceHTLCsParams) (
[]sqlc.FetchAMPSubInvoiceHTLCsRow, error)
FetchSettledAMPSubInvoices(ctx context.Context,
arg sqlc.FetchSettledAMPSubInvoicesParams) (
[]sqlc.FetchSettledAMPSubInvoicesRow, error)
UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context,
arg sqlc.UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result,
error)
// Invoice events specific methods.
OnInvoiceCreated(ctx context.Context,
arg sqlc.OnInvoiceCreatedParams) error
OnInvoiceCanceled(ctx context.Context,
arg sqlc.OnInvoiceCanceledParams) error
OnInvoiceSettled(ctx context.Context,
arg sqlc.OnInvoiceSettledParams) error
OnAMPSubInvoiceCreated(ctx context.Context,
arg sqlc.OnAMPSubInvoiceCreatedParams) error
OnAMPSubInvoiceCanceled(ctx context.Context,
arg sqlc.OnAMPSubInvoiceCanceledParams) error
OnAMPSubInvoiceSettled(ctx context.Context,
arg sqlc.OnAMPSubInvoiceSettledParams) error
// Migration specific methods.
// TODO(bhandras): remove this once migrations have been separated out.
InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
SetKVInvoicePaymentHash(ctx context.Context,
arg sqlc.SetKVInvoicePaymentHashParams) error
GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
[]byte, error)
ClearKVInvoiceHashIndex(ctx context.Context) error
}
var _ InvoiceDB = (*SQLStore)(nil)
// SQLInvoiceQueriesTxOptions defines the set of db txn options the
// SQLInvoiceQueries understands.
type SQLInvoiceQueriesTxOptions struct {
// readOnly governs if a read only transaction is needed or not.
readOnly bool
}
// ReadOnly returns true if the transaction should be read only.
//
// NOTE: This implements the TxOptions interface.
func (a *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
return a.readOnly
}
// NewSQLInvoiceQueryReadTx creates a new read transaction option set.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
return SQLInvoiceQueriesTxOptions{
readOnly: true,
}
}
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
// of batched database operations.
type BatchedSQLInvoiceQueries interface {
SQLInvoiceQueries
sqldb.BatchedTx[SQLInvoiceQueries]
}
// SQLStore represents a storage backend.
type SQLStore struct {
db BatchedSQLInvoiceQueries
clock clock.Clock
opts SQLStoreOptions
}
// SQLStoreOptions holds the options for the SQL store.
type SQLStoreOptions struct {
paginationLimit int
}
// defaultSQLStoreOptions returns the default options for the SQL store.
func defaultSQLStoreOptions() SQLStoreOptions {
return SQLStoreOptions{
paginationLimit: defaultQueryPaginationLimit,
}
}
// SQLStoreOption is a functional option that can be used to optionally modify
// the behavior of the SQL store.
type SQLStoreOption func(*SQLStoreOptions)
// WithPaginationLimit sets the pagination limit for the SQL store queries that
// paginate results.
func WithPaginationLimit(limit int) SQLStoreOption {
return func(o *SQLStoreOptions) {
o.paginationLimit = limit
}
}
// NewSQLStore creates a new SQLStore instance given an open
// BatchedSQLInvoiceQueries storage backend.
func NewSQLStore(db BatchedSQLInvoiceQueries,
clock clock.Clock, options ...SQLStoreOption) *SQLStore {
opts := defaultSQLStoreOptions()
for _, applyOption := range options {
applyOption(&opts)
}
return &SQLStore{
db: db,
clock: clock,
opts: opts,
}
}
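// A hedged usage sketch, not part of the original file: constructing a
// SQLStore with a non-default pagination limit via the functional option.
// The db argument is assumed to be an existing BatchedSQLInvoiceQueries
// implementation, and 500 is an arbitrary example limit.
func exampleNewSQLStore(db BatchedSQLInvoiceQueries) *SQLStore {
	return NewSQLStore(
		db, clock.NewDefaultClock(), WithPaginationLimit(500),
	)
}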
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
sqlc.InsertInvoiceParams, error) {
// Precompute the payment request hash so we can use it in the query.
var paymentRequestHash []byte
if len(invoice.PaymentRequest) > 0 {
h := sha256.New()
h.Write(invoice.PaymentRequest)
paymentRequestHash = h.Sum(nil)
}
params := sqlc.InsertInvoiceParams{
Hash: paymentHash[:],
AmountMsat: int64(invoice.Terms.Value),
CltvDelta: sqldb.SQLInt32(
invoice.Terms.FinalCltvDelta,
),
Expiry: int32(invoice.Terms.Expiry.Seconds()),
// Note: keysend invoices don't have a payment request.
PaymentRequest: sqldb.SQLStr(string(
invoice.PaymentRequest),
),
PaymentRequestHash: paymentRequestHash,
State: int16(invoice.State),
AmountPaidMsat: int64(invoice.AmtPaid),
IsAmp: invoice.IsAMP(),
IsHodl: invoice.HodlInvoice,
IsKeysend: invoice.IsKeysend(),
CreatedAt: invoice.CreationDate.UTC(),
}
if invoice.Memo != nil {
// Store the memo as a nullable string in the database. Note
// that for compatibility reasons, we store the value as a valid
// string even if it's empty.
params.Memo = sql.NullString{
String: string(invoice.Memo),
Valid: true,
}
}
// Some invoices may not have a preimage, like in the case of HODL
// invoices.
if invoice.Terms.PaymentPreimage != nil {
preimage := *invoice.Terms.PaymentPreimage
if preimage == UnknownPreimage {
return sqlc.InsertInvoiceParams{},
errors.New("cannot use all-zeroes preimage")
}
params.Preimage = preimage[:]
}
// Some non MPP payments may have the default (invalid) value.
if invoice.Terms.PaymentAddr != BlankPayAddr {
params.PaymentAddr = invoice.Terms.PaymentAddr[:]
}
return params, nil
}
// AddInvoice inserts the targeted invoice into the database. If the invoice has
// *any* payment hashes which already exist within the database, then the
// insertion will be aborted and rejected due to the strict policy banning any
// duplicate payment hashes.
//
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
func (i *SQLStore) AddInvoice(ctx context.Context,
newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
// Make sure this is a valid invoice before trying to store it in our
// DB.
if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
return 0, err
}
var (
writeTxOpts SQLInvoiceQueriesTxOptions
invoiceID int64
)
insertInvoiceParams, err := makeInsertInvoiceParams(
newInvoice, paymentHash,
)
if err != nil {
return 0, err
}
err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
var err error
invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
if err != nil {
return fmt.Errorf("unable to insert invoice: %w", err)
}
// TODO(positiveblue): if invoices do not have custom features
// maybe just store the "invoice type" and populate the features
// based on that.
for feature := range newInvoice.Terms.Features.Features() {
params := sqlc.InsertInvoiceFeatureParams{
InvoiceID: invoiceID,
Feature: int32(feature),
}
err := db.InsertInvoiceFeature(ctx, params)
if err != nil {
return fmt.Errorf("unable to insert invoice "+
"feature(%v): %w", feature, err)
}
}
// Finally add a new event for this invoice.
return db.OnInvoiceCreated(ctx, sqlc.OnInvoiceCreatedParams{
AddedAt: newInvoice.CreationDate.UTC(),
InvoiceID: invoiceID,
})
}, func() {})
if err != nil {
mappedSQLErr := sqldb.MapSQLError(err)
var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
if errors.As(mappedSQLErr, &uniqueConstraintErr) {
// Add context to unique constraint errors.
return 0, ErrDuplicateInvoice
}
return 0, fmt.Errorf("unable to add invoice(%v): %w",
paymentHash, err)
}
newInvoice.AddIndex = uint64(invoiceID)
return newInvoice.AddIndex, nil
}
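// Illustrative usage (not part of the original source): adding an invoice
// keyed by the hash of its preimage. All values here are placeholders.
//
//	preimage := lntypes.Preimage{1}
//	hash := preimage.Hash()
//	addIndex, err := store.AddInvoice(ctx, newInvoice, hash)
//	// On success, newInvoice.AddIndex == addIndex.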
// getInvoiceByRef fetches the invoice with the given reference. The reference
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
func getInvoiceByRef(ctx context.Context,
db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
// If the reference is empty, we can't look up the invoice.
if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
return sqlc.Invoice{}, ErrInvoiceNotFound
}
// If the reference is a hash only, we can look up the invoice directly
// by the payment hash which is faster.
if ref.IsHashOnly() {
invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
if errors.Is(err, sql.ErrNoRows) {
return sqlc.Invoice{}, ErrInvoiceNotFound
}
return invoice, err
}
// Otherwise the reference may include more fields, so we'll need to
// assemble the query parameters based on the fields that are set.
var params sqlc.GetInvoiceParams
if ref.PayHash() != nil {
params.Hash = ref.PayHash()[:]
}
// Newer invoices (0.11 and up) are indexed by payment address in
// addition to payment hash, but pre 0.8 invoices do not have one at
// all. Only allow lookups for payment address if it is not a blank
// payment address, which is a special-cased value for legacy keysend
// invoices.
if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
params.PaymentAddr = ref.PayAddr()[:]
}
// If the reference has a set ID we'll fetch the invoice which has the
// corresponding AMP sub invoice.
if ref.SetID() != nil {
params.SetID = ref.SetID()[:]
}
var (
rows []sqlc.Invoice
err error
)
// We need to split the query based on how we intend to look up the
// invoice. If only the set ID is given then we want to have an exact
// match on the set ID. If other fields are given, we want to match on
// those fields and the set ID but with a less strict join condition.
if params.Hash == nil && params.PaymentAddr == nil &&
params.SetID != nil {
rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
} else {
rows, err = db.GetInvoice(ctx, params)
}
switch {
// Check the error first so that a genuine database failure is not
// masked as a missing invoice by the empty-rows case below.
case err != nil:
return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
err)
case len(rows) == 0:
return sqlc.Invoice{}, ErrInvoiceNotFound
case len(rows) > 1:
// In case the reference is ambiguous, meaning it matches more
// than one invoice, we'll return an error.
return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
"%s: %s", ref.String(), spew.Sdump(rows))
}
return rows[0], nil
}
// fetchInvoice fetches the common invoice data and the AMP state for the
// invoice with the given reference.
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
*Invoice, error) {
// Fetch the invoice from the database.
sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
if err != nil {
return nil, err
}
var (
setID *[32]byte
fetchAmpHtlcs bool
)
// Now that we have the invoice itself, fetch the HTLCs as requested by
// the modifier.
switch ref.Modifier() {
case DefaultModifier:
// By default we'll fetch all AMP HTLCs.
setID = nil
fetchAmpHtlcs = true
case HtlcSetOnlyModifier:
// In this case we'll fetch all AMP HTLCs for the specified set
// id.
if ref.SetID() == nil {
return nil, fmt.Errorf("set ID is required to use " +
"the HTLC set only modifier")
}
setID = ref.SetID()
fetchAmpHtlcs = true
case HtlcSetBlankModifier:
// No need to fetch any HTLCs.
setID = nil
fetchAmpHtlcs = false
default:
return nil, fmt.Errorf("unknown invoice ref modifier: %v",
ref.Modifier())
}
// Fetch the rest of the invoice data and fill the invoice struct.
_, invoice, err := fetchInvoiceData(
ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
)
if err != nil {
return nil, err
}
return invoice, nil
}
// fetchAmpState fetches the AMP state for the invoice with the given ID.
// Optional setID can be provided to fetch the state for a specific AMP HTLC
// set. If setID is nil then we'll fetch the state for all AMP sub invoices. If
// fetchHtlcs is set to true, the HTLCs for the given set will be fetched as
// well.
//
//nolint:funlen
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
HTLCSet, error) {
var paramSetID []byte
if setID != nil {
paramSetID = setID[:]
}
// First fetch all the AMP sub invoices for this invoice or the one
// matching the provided set ID.
ampInvoiceRows, err := db.FetchAMPSubInvoices(
ctx, sqlc.FetchAMPSubInvoicesParams{
InvoiceID: invoiceID,
SetID: paramSetID,
},
)
if err != nil {
return nil, nil, err
}
ampState := make(map[SetID]InvoiceStateAMP)
for _, row := range ampInvoiceRows {
var rowSetID [32]byte
if len(row.SetID) != 32 {
return nil, nil, fmt.Errorf("invalid set id length: %d",
len(row.SetID))
}
var settleDate time.Time
if row.SettledAt.Valid {
settleDate = row.SettledAt.Time.Local()
}
copy(rowSetID[:], row.SetID)
ampState[rowSetID] = InvoiceStateAMP{
State: HtlcState(row.State),
SettleIndex: uint64(row.SettleIndex.Int64),
SettleDate: settleDate,
InvoiceKeys: make(map[models.CircuitKey]struct{}),
}
}
if !fetchHtlcs {
return ampState, nil, nil
}
customRecordRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
if err != nil {
return nil, nil, fmt.Errorf("unable to get custom records for "+
"invoice HTLCs: %w", err)
}
customRecords := make(map[int64]record.CustomSet, len(customRecordRows))
for _, row := range customRecordRows {
if _, ok := customRecords[row.HtlcID]; !ok {
customRecords[row.HtlcID] = make(record.CustomSet)
}
value := row.Value
if value == nil {
value = []byte{}
}
customRecords[row.HtlcID][uint64(row.Key)] = value
}
// Now fetch all the AMP HTLCs for this invoice or the one matching the
// provided set ID.
ampHtlcRows, err := db.FetchAMPSubInvoiceHTLCs(
ctx, sqlc.FetchAMPSubInvoiceHTLCsParams{
InvoiceID: invoiceID,
SetID: paramSetID,
},
)
if err != nil {
return nil, nil, err
}
ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
for _, row := range ampHtlcRows {
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
if err != nil {
return nil, nil, err
}
chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
if row.HtlcID < 0 {
return nil, nil, fmt.Errorf("invalid HTLC ID "+
"value: %v", row.HtlcID)
}
htlcID := uint64(row.HtlcID)
circuitKey := CircuitKey{
ChanID: chanID,
HtlcID: htlcID,
}
htlc := &InvoiceHTLC{
Amt: lnwire.MilliSatoshi(row.AmountMsat),
AcceptHeight: uint32(row.AcceptHeight),
AcceptTime: row.AcceptTime.Local(),
Expiry: uint32(row.ExpiryHeight),
State: HtlcState(row.State),
}
if row.TotalMppMsat.Valid {
htlc.MppTotalAmt = lnwire.MilliSatoshi(
row.TotalMppMsat.Int64,
)
}
if row.ResolveTime.Valid {
htlc.ResolveTime = row.ResolveTime.Time.Local()
}
var (
rootShare [32]byte
setID [32]byte
)
if len(row.RootShare) != 32 {
return nil, nil, fmt.Errorf("invalid root share "+
"length: %d", len(row.RootShare))
}
copy(rootShare[:], row.RootShare)
if len(row.SetID) != 32 {
return nil, nil, fmt.Errorf("invalid set ID length: %d",
len(row.SetID))
}
copy(setID[:], row.SetID)
if row.ChildIndex < 0 || row.ChildIndex > math.MaxUint32 {
return nil, nil, fmt.Errorf("invalid child index "+
"value: %v", row.ChildIndex)
}
ampRecord := record.NewAMP(
rootShare, setID, uint32(row.ChildIndex),
)
htlc.AMP = &InvoiceHtlcAMPData{
Record: *ampRecord,
}
if len(row.Hash) != 32 {
return nil, nil, fmt.Errorf("invalid hash length: %d",
len(row.Hash))
}
copy(htlc.AMP.Hash[:], row.Hash)
if row.Preimage != nil {
preimage, err := lntypes.MakePreimage(row.Preimage)
if err != nil {
return nil, nil, err
}
htlc.AMP.Preimage = &preimage
}
if _, ok := customRecords[row.ID]; ok {
htlc.CustomRecords = customRecords[row.ID]
} else {
htlc.CustomRecords = make(record.CustomSet)
}
ampHtlcs[circuitKey] = htlc
}
if len(ampHtlcs) > 0 {
for setID := range ampState {
var amtPaid lnwire.MilliSatoshi
invoiceKeys := make(
map[models.CircuitKey]struct{},
)
for key, htlc := range ampHtlcs {
if htlc.AMP.Record.SetID() != setID {
continue
}
invoiceKeys[key] = struct{}{}
if htlc.State != HtlcStateCanceled {
amtPaid += htlc.Amt
}
}
setState := ampState[setID]
setState.InvoiceKeys = invoiceKeys
setState.AmtPaid = amtPaid
ampState[setID] = setState
}
}
return ampState, ampHtlcs, nil
}
// LookupInvoice attempts to look up an invoice corresponding to the passed-in
// reference. The reference may be a payment hash, a payment address, or a set
// ID for an AMP sub invoice. If the invoice is found, we'll return the
// complete invoice. If the invoice is not found, then we'll return an
// ErrInvoiceNotFound error.
func (i *SQLStore) LookupInvoice(ctx context.Context,
ref InvoiceRef) (Invoice, error) {
var (
invoice *Invoice
err error
)
readTxOpt := NewSQLInvoiceQueryReadTx()
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoice, err = fetchInvoice(ctx, db, ref)
return err
}, func() {})
if txErr != nil {
return Invoice{}, txErr
}
return *invoice, nil
}
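// For illustration, a caller holding only the payment hash could look up the
// invoice as follows. This is a minimal sketch: the store construction and
// the payHash variable are assumed to exist elsewhere.
//
//	ref := InvoiceRefByHash(payHash)
//	invoice, err := store.LookupInvoice(ctx, ref)
//	if errors.Is(err, ErrInvoiceNotFound) {
//		// No invoice matches the given hash.
//	}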
// FetchPendingInvoices returns all the invoices that are currently in a
// "pending" state. An invoice is pending if it has been created but not yet
// settled or canceled.
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
map[lntypes.Hash]Invoice, error) {
var invoices map[lntypes.Hash]Invoice
readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
return queryWithLimit(func(offset int) (int, error) {
params := sqlc.FilterInvoicesParams{
PendingOnly: true,
NumOffset: int32(offset),
NumLimit: int32(i.opts.paginationLimit),
Reverse: false,
}
rows, err := db.FilterInvoices(ctx, params)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return 0, fmt.Errorf("unable to get invoices "+
"from db: %w", err)
}
// Load all the information for the invoices.
for _, row := range rows {
hash, invoice, err := fetchInvoiceData(
ctx, db, row, nil, true,
)
if err != nil {
return 0, err
}
invoices[*hash] = *invoice
}
return len(rows), nil
}, i.opts.paginationLimit)
}, func() {
invoices = make(map[lntypes.Hash]Invoice)
})
if err != nil {
return nil, fmt.Errorf("unable to fetch pending invoices: %w",
err)
}
return invoices, nil
}
// InvoicesSettledSince can be used by callers to catch up any settled invoices
// they missed within the settled invoice time series. We'll return all known
// settled invoices that have a settle index higher than the passed idx.
//
// NOTE: The index starts from 1. As a result, specifying a value below the
// starting index value is a noop.
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
[]Invoice, error) {
var (
invoices []Invoice
start = time.Now()
lastLogTime = time.Now()
processedCount int
)
if idx == 0 {
return invoices, nil
}
readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
err := queryWithLimit(func(offset int) (int, error) {
params := sqlc.FilterInvoicesParams{
SettleIndexGet: sqldb.SQLInt64(idx + 1),
NumOffset: int32(offset),
NumLimit: int32(i.opts.paginationLimit),
Reverse: false,
}
rows, err := db.FilterInvoices(ctx, params)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return 0, fmt.Errorf("unable to get invoices "+
"from db: %w", err)
}
// Load all the information for the invoices.
for _, row := range rows {
_, invoice, err := fetchInvoiceData(
ctx, db, row, nil, true,
)
if err != nil {
return 0, fmt.Errorf("unable to fetch "+
"invoice(id=%d) from db: %w",
row.ID, err)
}
invoices = append(invoices, *invoice)
processedCount++
if time.Since(lastLogTime) >=
invoiceProgressLogInterval {
log.Debugf("Processed %d settled "+
"invoices which have a settle "+
"index greater than %v",
processedCount, idx)
lastLogTime = time.Now()
}
}
return len(rows), nil
}, i.opts.paginationLimit)
if err != nil {
return err
}
// Now fetch all the AMP sub invoices that were settled since
// the provided index.
ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
ctx, sqlc.FetchSettledAMPSubInvoicesParams{
SettleIndexGet: sqldb.SQLInt64(idx + 1),
},
)
if err != nil {
return err
}
for _, ampInvoice := range ampInvoices {
// Convert the row to a sqlc.Invoice so we can use the
// existing fetchInvoiceData function.
sqlInvoice := sqlc.Invoice{
ID: ampInvoice.ID,
Hash: ampInvoice.Hash,
Preimage: ampInvoice.Preimage,
SettleIndex: ampInvoice.AmpSettleIndex,
SettledAt: ampInvoice.AmpSettledAt,
Memo: ampInvoice.Memo,
AmountMsat: ampInvoice.AmountMsat,
CltvDelta: ampInvoice.CltvDelta,
Expiry: ampInvoice.Expiry,
PaymentAddr: ampInvoice.PaymentAddr,
PaymentRequest: ampInvoice.PaymentRequest,
State: ampInvoice.State,
AmountPaidMsat: ampInvoice.AmountPaidMsat,
IsAmp: ampInvoice.IsAmp,
IsHodl: ampInvoice.IsHodl,
IsKeysend: ampInvoice.IsKeysend,
CreatedAt: ampInvoice.CreatedAt.UTC(),
}
// Fetch the state and HTLCs for this AMP sub invoice.
_, invoice, err := fetchInvoiceData(
ctx, db, sqlInvoice,
(*[32]byte)(ampInvoice.SetID), true,
)
if err != nil {
return fmt.Errorf("unable to fetch "+
"AMP invoice(id=%d) from db: %w",
ampInvoice.ID, err)
}
invoices = append(invoices, *invoice)
processedCount++
if time.Since(lastLogTime) >=
invoiceProgressLogInterval {
log.Debugf("Processed %d settled invoices "+
"including AMP sub invoices which "+
"have a settle index greater than %v",
processedCount, idx)
lastLogTime = time.Now()
}
}
return nil
}, func() {
invoices = nil
})
if err != nil {
return nil, fmt.Errorf("unable to get invoices settled since "+
"index (excluding) %d: %w", idx, err)
}
elapsed := time.Since(start)
log.Debugf("Completed scanning for settled invoices starting at "+
"index %v: total_processed=%d, found_invoices=%d, elapsed=%v",
idx, processedCount, len(invoices),
elapsed.Round(time.Millisecond))
return invoices, nil
}
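// A subscriber that persists the highest settle index it has processed can
// use this method to replay missed settle events after a restart. A minimal
// sketch, where lastSettleIndex is an assumed caller-persisted value:
//
//	settled, err := store.InvoicesSettledSince(ctx, lastSettleIndex)
//	if err != nil {
//		return err
//	}
//	for _, inv := range settled {
//		// Re-deliver the settle notification for inv and advance
//		// lastSettleIndex to inv.SettleIndex.
//	}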
// InvoicesAddedSince can be used by callers to seek into the event time series
// of all the invoices added in the database. This method will return all
// invoices with an add index greater than the specified idx.
//
// NOTE: The index starts from 1. As a result, specifying a value below the
// starting index value is a noop.
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
[]Invoice, error) {
var (
result []Invoice
start = time.Now()
lastLogTime = time.Now()
processedCount int
)
if idx == 0 {
return result, nil
}
readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
return queryWithLimit(func(offset int) (int, error) {
params := sqlc.FilterInvoicesParams{
AddIndexGet: sqldb.SQLInt64(idx + 1),
NumOffset: int32(offset),
NumLimit: int32(i.opts.paginationLimit),
Reverse: false,
}
rows, err := db.FilterInvoices(ctx, params)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return 0, fmt.Errorf("unable to get invoices "+
"from db: %w", err)
}
// Load all the information for the invoices.
for _, row := range rows {
_, invoice, err := fetchInvoiceData(
ctx, db, row, nil, true,
)
if err != nil {
return 0, err
}
result = append(result, *invoice)
processedCount++
if time.Since(lastLogTime) >=
invoiceProgressLogInterval {
log.Debugf("Processed %d invoices "+
"which were added since add "+
"index %v", processedCount, idx)
lastLogTime = time.Now()
}
}
return len(rows), nil
}, i.opts.paginationLimit)
}, func() {
result = nil
})
if err != nil {
return nil, fmt.Errorf("unable to get invoices added since "+
"index %d: %w", idx, err)
}
elapsed := time.Since(start)
log.Debugf("Completed scanning for invoices added since index %v: "+
"total_processed=%d, found_invoices=%d, elapsed=%v",
idx, processedCount, len(result),
elapsed.Round(time.Millisecond))
return result, nil
}
// QueryInvoices allows a caller to query the invoice database for invoices
// within the specified add index range.
func (i *SQLStore) QueryInvoices(ctx context.Context,
q InvoiceQuery) (InvoiceSlice, error) {
var invoices []Invoice
if q.NumMaxInvoices == 0 {
return InvoiceSlice{}, fmt.Errorf("max invoices must " +
"be non-zero")
}
readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
return queryWithLimit(func(offset int) (int, error) {
params := sqlc.FilterInvoicesParams{
NumOffset: int32(offset),
NumLimit: int32(i.opts.paginationLimit),
PendingOnly: q.PendingOnly,
Reverse: q.Reversed,
}
if q.Reversed {
// If the index offset was not set, we want to
// fetch from the latest invoice.
if q.IndexOffset == 0 {
params.AddIndexLet = sqldb.SQLInt64(
int64(math.MaxInt64),
)
} else {
// The invoice with index offset id must
// not be included in the results.
params.AddIndexLet = sqldb.SQLInt64(
q.IndexOffset - 1,
)
}
} else {
// The invoice with index offset id must not be
// included in the results.
params.AddIndexGet = sqldb.SQLInt64(
q.IndexOffset + 1,
)
}
if q.CreationDateStart != 0 {
params.CreatedAfter = sqldb.SQLTime(
time.Unix(q.CreationDateStart, 0).UTC(),
)
}
if q.CreationDateEnd != 0 {
// We need to add 1 to the end date as we're
// checking less than the end date in SQL.
params.CreatedBefore = sqldb.SQLTime(
time.Unix(q.CreationDateEnd+1, 0).UTC(),
)
}
rows, err := db.FilterInvoices(ctx, params)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return 0, fmt.Errorf("unable to get invoices "+
"from db: %w", err)
}
// Load all the information for the invoices.
for _, row := range rows {
_, invoice, err := fetchInvoiceData(
ctx, db, row, nil, true,
)
if err != nil {
return 0, err
}
invoices = append(invoices, *invoice)
if len(invoices) == int(q.NumMaxInvoices) {
return 0, nil
}
}
return len(rows), nil
}, i.opts.paginationLimit)
}, func() {
invoices = nil
})
if err != nil {
return InvoiceSlice{}, fmt.Errorf("unable to query "+
"invoices: %w", err)
}
if len(invoices) == 0 {
return InvoiceSlice{
InvoiceQuery: q,
}, nil
}
// If we iterated through the add index in reverse order, then
// we'll need to reverse the slice of invoices to return them in
// forward order.
if q.Reversed {
numInvoices := len(invoices)
for i := 0; i < numInvoices/2; i++ {
reverse := numInvoices - i - 1
invoices[i], invoices[reverse] =
invoices[reverse], invoices[i]
}
}
res := InvoiceSlice{
InvoiceQuery: q,
Invoices: invoices,
FirstIndexOffset: invoices[0].AddIndex,
LastIndexOffset: invoices[len(invoices)-1].AddIndex,
}
return res, nil
}
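// For illustration, paging backwards from the most recently added invoice
// could look as follows. This is a minimal sketch with arbitrary field
// values:
//
//	slice, err := store.QueryInvoices(ctx, InvoiceQuery{
//		IndexOffset:    0, // In reverse mode, 0 starts at the latest.
//		NumMaxInvoices: 100,
//		Reversed:       true,
//	})
//
// The returned slice.FirstIndexOffset can then be supplied as the
// IndexOffset of the next query to fetch the preceding page.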
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
// a SQL database as the backend.
type sqlInvoiceUpdater struct {
db SQLInvoiceQueries
ctx context.Context //nolint:containedctx
invoice *Invoice
updateTime time.Time
}
// AddHtlc adds a new htlc to the invoice.
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
newHtlc *InvoiceHTLC) error {
htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
s.ctx, sqlc.InsertInvoiceHTLCParams{
HtlcID: int64(circuitKey.HtlcID),
ChanID: strconv.FormatUint(
circuitKey.ChanID.ToUint64(), 10,
),
AmountMsat: int64(newHtlc.Amt),
TotalMppMsat: sql.NullInt64{
Int64: int64(newHtlc.MppTotalAmt),
Valid: newHtlc.MppTotalAmt != 0,
},
AcceptHeight: int32(newHtlc.AcceptHeight),
AcceptTime: newHtlc.AcceptTime.UTC(),
ExpiryHeight: int32(newHtlc.Expiry),
State: int16(newHtlc.State),
InvoiceID: int64(s.invoice.AddIndex),
},
)
if err != nil {
return err
}
for key, value := range newHtlc.CustomRecords {
err = s.db.InsertInvoiceHTLCCustomRecord(
s.ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
// TODO(bhandras): schema might be wrong here
// as the custom record key is an uint64.
Key: int64(key),
Value: value,
HtlcID: htlcPrimaryKeyID,
},
)
if err != nil {
return err
}
}
if newHtlc.AMP != nil {
setID := newHtlc.AMP.Record.SetID()
upsertResult, err := s.db.UpsertAMPSubInvoice(
s.ctx, sqlc.UpsertAMPSubInvoiceParams{
SetID: setID[:],
CreatedAt: s.updateTime.UTC(),
InvoiceID: int64(s.invoice.AddIndex),
},
)
if err != nil {
mappedSQLErr := sqldb.MapSQLError(err)
var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:ll
if errors.As(mappedSQLErr, &uniqueConstraintErr) {
return ErrDuplicateSetID{
SetID: setID,
}
}
return err
}
// If we're just inserting the AMP invoice, we'll get a non-zero
// rows affected count.
rowsAffected, err := upsertResult.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 0 {
// If we're inserting a new AMP invoice, we'll also
// insert a new invoice event.
err = s.db.OnAMPSubInvoiceCreated(
s.ctx, sqlc.OnAMPSubInvoiceCreatedParams{
AddedAt: s.updateTime.UTC(),
InvoiceID: int64(s.invoice.AddIndex),
SetID: setID[:],
},
)
if err != nil {
return err
}
}
rootShare := newHtlc.AMP.Record.RootShare()
ampHtlcParams := sqlc.InsertAMPSubInvoiceHTLCParams{
InvoiceID: int64(s.invoice.AddIndex),
SetID: setID[:],
HtlcID: htlcPrimaryKeyID,
RootShare: rootShare[:],
ChildIndex: int64(
newHtlc.AMP.Record.ChildIndex(),
),
Hash: newHtlc.AMP.Hash[:],
}
if newHtlc.AMP.Preimage != nil {
ampHtlcParams.Preimage = newHtlc.AMP.Preimage[:]
}
err = s.db.InsertAMPSubInvoiceHTLC(s.ctx, ampHtlcParams)
if err != nil {
return err
}
}
return nil
}
// ResolveHtlc marks an htlc as resolved with the given state.
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
state HtlcState, resolveTime time.Time) error {
return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
HtlcID: int64(circuitKey.HtlcID),
ChanID: strconv.FormatUint(
circuitKey.ChanID.ToUint64(), 10,
),
InvoiceID: int64(s.invoice.AddIndex),
State: int16(state),
ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
})
}
// AddAmpHtlcPreimage adds a preimage of an AMP htlc to the AMP sub invoice
// identified by the setID.
func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
circuitKey models.CircuitKey, preimage lntypes.Preimage) error {
result, err := s.db.UpdateAMPSubInvoiceHTLCPreimage(
s.ctx, sqlc.UpdateAMPSubInvoiceHTLCPreimageParams{
InvoiceID: int64(s.invoice.AddIndex),
SetID: setID[:],
HtlcID: int64(circuitKey.HtlcID),
Preimage: preimage[:],
ChanID: strconv.FormatUint(
circuitKey.ChanID.ToUint64(), 10,
),
},
)
if err != nil {
return err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
if rowsAffected == 0 {
return ErrInvoiceNotFound
}
return nil
}
// UpdateInvoiceState updates the invoice state to the new state.
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
newState ContractState, preimage *lntypes.Preimage) error {
var (
settleIndex sql.NullInt64
settledAt sql.NullTime
)
switch newState {
case ContractSettled:
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
if err != nil {
return err
}
settleIndex = sqldb.SQLInt64(nextSettleIndex)
// If the invoice is settled, we'll also update the settle time.
settledAt = sqldb.SQLTime(s.updateTime.UTC())
err = s.db.OnInvoiceSettled(
s.ctx, sqlc.OnInvoiceSettledParams{
AddedAt: s.updateTime.UTC(),
InvoiceID: int64(s.invoice.AddIndex),
},
)
if err != nil {
return err
}
case ContractCanceled:
err := s.db.OnInvoiceCanceled(
s.ctx, sqlc.OnInvoiceCanceledParams{
AddedAt: s.updateTime.UTC(),
InvoiceID: int64(s.invoice.AddIndex),
},
)
if err != nil {
return err
}
}
params := sqlc.UpdateInvoiceStateParams{
ID: int64(s.invoice.AddIndex),
State: int16(newState),
SettleIndex: settleIndex,
SettledAt: settledAt,
}
if preimage != nil {
params.Preimage = preimage[:]
}
result, err := s.db.UpdateInvoiceState(s.ctx, params)
if err != nil {
return err
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
if rowsAffected == 0 {
return ErrInvoiceNotFound
}
if settleIndex.Valid {
s.invoice.SettleIndex = uint64(settleIndex.Int64)
s.invoice.SettleDate = s.updateTime
}
return nil
}
// UpdateInvoiceAmtPaid updates the invoice amount paid to the new amount.
func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
amtPaid lnwire.MilliSatoshi) error {
_, err := s.db.UpdateInvoiceAmountPaid(
s.ctx, sqlc.UpdateInvoiceAmountPaidParams{
ID: int64(s.invoice.AddIndex),
AmountPaidMsat: int64(amtPaid),
},
)
return err
}
// UpdateAmpState updates the state of the AMP sub invoice identified by the
// setID.
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
newState InvoiceStateAMP, _ models.CircuitKey) error {
var (
settleIndex sql.NullInt64
settledAt sql.NullTime
)
switch newState.State {
case HtlcStateSettled:
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
if err != nil {
return err
}
settleIndex = sqldb.SQLInt64(nextSettleIndex)
// If the invoice is settled, we'll also update the settle time.
settledAt = sqldb.SQLTime(s.updateTime.UTC())
err = s.db.OnAMPSubInvoiceSettled(
s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
AddedAt: s.updateTime.UTC(),
InvoiceID: int64(s.invoice.AddIndex),
SetID: setID[:],
},
)
if err != nil {
return err
}
case HtlcStateCanceled:
err := s.db.OnAMPSubInvoiceCanceled(
s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
AddedAt: s.updateTime.UTC(),
InvoiceID: int64(s.invoice.AddIndex),
SetID: setID[:],
},
)
if err != nil {
return err
}
}
err := s.db.UpdateAMPSubInvoiceState(
s.ctx, sqlc.UpdateAMPSubInvoiceStateParams{
SetID: setID[:],
State: int16(newState.State),
SettleIndex: settleIndex,
SettledAt: settledAt,
},
)
if err != nil {
return err
}
if settleIndex.Valid {
updatedState := s.invoice.AMPState[setID]
updatedState.SettleIndex = uint64(settleIndex.Int64)
updatedState.SettleDate = s.updateTime.UTC()
s.invoice.AMPState[setID] = updatedState
}
return nil
}
// Finalize finalizes the update before it is written to the database. Note that
// we don't use this directly in the SQL implementation, so the function is just
// a stub.
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
return nil
}
// UpdateInvoice attempts to update an invoice corresponding to the passed
// reference. If an invoice matching the passed reference doesn't exist within
// the database, then the action will fail with ErrInvoiceNotFound error.
//
// The update is performed inside the same database transaction that fetches the
// invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback.
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
setID *SetID, callback InvoiceUpdateCallback) (
*Invoice, error) {
var updatedInvoice *Invoice
txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
switch {
// For the default case we fetch all HTLCs.
case setID == nil:
ref.refModifier = DefaultModifier
// If the setID is blank but NOT nil, we set the refModifier to
// HtlcSetBlankModifier to fetch no HTLCs for the AMP invoice.
case *setID == BlankPayAddr:
ref.refModifier = HtlcSetBlankModifier
// A setID is provided, we use the refModifier to fetch only
// the HTLCs for the given setID and also make sure we add the
// setID to the ref.
default:
var setIDBytes [32]byte
copy(setIDBytes[:], setID[:])
ref.setID = &setIDBytes
// We only fetch the HTLCs for the given setID.
ref.refModifier = HtlcSetOnlyModifier
}
invoice, err := fetchInvoice(ctx, db, ref)
if err != nil {
return err
}
updateTime := i.clock.Now()
updater := &sqlInvoiceUpdater{
db: db,
ctx: ctx,
invoice: invoice,
updateTime: updateTime,
}
payHash := ref.PayHash()
updatedInvoice, err = UpdateInvoice(
payHash, invoice, updateTime, callback, updater,
)
return err
}, func() {})
if txErr != nil {
// If the invoice is already settled, we'll return the
// (unchanged) invoice and the ErrInvoiceAlreadySettled error.
if errors.Is(txErr, ErrInvoiceAlreadySettled) {
return updatedInvoice, txErr
}
return nil, txErr
}
return updatedInvoice, nil
}
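// As an example, canceling a single HTLC could be expressed through the
// callback as sketched below. This assumes that CancelHtlcs is keyed by the
// circuit key of the HTLC to cancel, matching how cancelHTLCs iterates it,
// and that ref and key are defined by the caller.
//
//	updated, err := store.UpdateInvoice(ctx, ref, nil,
//		func(inv *Invoice) (*InvoiceUpdateDesc, error) {
//			return &InvoiceUpdateDesc{
//				UpdateType: CancelHTLCsUpdate,
//				CancelHtlcs: map[CircuitKey]struct{}{
//					key: {},
//				},
//			}, nil
//		},
//	)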
// DeleteInvoice attempts to delete the passed invoices and all their related
// data from the database in one transaction.
func (i *SQLStore) DeleteInvoice(ctx context.Context,
invoicesToDelete []InvoiceDeleteRef) error {
// All the InvoiceDeleteRef instances include the add index of the
// invoice. The other fields were added to ensure that the invoices were
// deleted properly in the kv database. Once we have fully migrated, we
// can remove those fields.
for _, ref := range invoicesToDelete {
if ref.AddIndex == 0 {
return fmt.Errorf("unable to delete invoice using a "+
"ref without AddIndex set: %v", ref)
}
}
var writeTxOpt SQLInvoiceQueriesTxOptions
err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
for _, ref := range invoicesToDelete {
params := sqlc.DeleteInvoiceParams{
AddIndex: sqldb.SQLInt64(ref.AddIndex),
}
if ref.SettleIndex != 0 {
params.SettleIndex = sqldb.SQLInt64(
ref.SettleIndex,
)
}
if ref.PayHash != lntypes.ZeroHash {
params.Hash = ref.PayHash[:]
}
result, err := db.DeleteInvoice(ctx, params)
if err != nil {
return fmt.Errorf("unable to delete "+
"invoice(%v): %w", ref.AddIndex, err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("unable to get rows "+
"affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("%w: %v",
ErrInvoiceNotFound, ref.AddIndex)
}
}
return nil
}, func() {})
if err != nil {
return fmt.Errorf("unable to delete invoices: %w", err)
}
return nil
}
// DeleteCanceledInvoices removes all canceled invoices from the database.
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
var writeTxOpt SQLInvoiceQueriesTxOptions
err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
_, err := db.DeleteCanceledInvoices(ctx)
if err != nil {
return fmt.Errorf("unable to delete canceled "+
"invoices: %w", err)
}
return nil
}, func() {})
if err != nil {
return fmt.Errorf("unable to delete invoices: %w", err)
}
return nil
}
// fetchInvoiceData fetches additional data for the given invoice. If the
// invoice is AMP and the setID is not nil, then the AMP state and HTLCs are
// fetched for the given setID only; otherwise they are fetched for all AMP
// sub invoices of the invoice. The AMP HTLCs are only fetched if
// fetchAmpHtlcs is true.
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
*Invoice, error) {
// Unmarshal the common data.
hash, invoice, err := unmarshalInvoice(row)
if err != nil {
return nil, nil, fmt.Errorf("unable to unmarshal "+
"invoice(id=%d) from db: %w", row.ID, err)
}
// Fetch the invoice features.
features, err := getInvoiceFeatures(ctx, db, row.ID)
if err != nil {
return nil, nil, err
}
invoice.Terms.Features = features
// If this is an AMP invoice, we'll need to fetch the AMP state along
// with the HTLCs (if requested).
if invoice.IsAMP() {
invoiceID := int64(invoice.AddIndex)
ampState, ampHtlcs, err := fetchAmpState(
ctx, db, invoiceID, setID, fetchAmpHtlcs,
)
if err != nil {
return nil, nil, err
}
invoice.AMPState = ampState
invoice.Htlcs = ampHtlcs
return hash, invoice, nil
}
// Otherwise simply fetch the invoice HTLCs.
htlcs, err := getInvoiceHtlcs(ctx, db, row.ID)
if err != nil {
return nil, nil, err
}
if len(htlcs) > 0 {
invoice.Htlcs = htlcs
}
return hash, invoice, nil
}
// getInvoiceFeatures fetches the invoice features for the given invoice id.
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
invoiceID int64) (*lnwire.FeatureVector, error) {
rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
if err != nil {
return nil, fmt.Errorf("unable to get invoice features: %w",
err)
}
features := lnwire.EmptyFeatureVector()
for _, feature := range rows {
features.Set(lnwire.FeatureBit(feature.Feature))
}
return features, nil
}
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
if err != nil {
return nil, fmt.Errorf("unable to get invoice htlcs: %w", err)
}
// We have no htlcs to unmarshal.
if len(htlcRows) == 0 {
return nil, nil
}
crRows, err := db.GetInvoiceHTLCCustomRecords(ctx, invoiceID)
if err != nil {
return nil, fmt.Errorf("unable to get custom records for "+
"invoice htlcs: %w", err)
}
cr := make(map[int64]record.CustomSet, len(crRows))
for _, row := range crRows {
if _, ok := cr[row.HtlcID]; !ok {
cr[row.HtlcID] = make(record.CustomSet)
}
value := row.Value
if value == nil {
value = []byte{}
}
cr[row.HtlcID][uint64(row.Key)] = value
}
htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
for _, row := range htlcRows {
circuitKey, htlc, err := unmarshalInvoiceHTLC(row)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal "+
"htlc(%d): %w", row.ID, err)
}
if customRecords, ok := cr[row.ID]; ok {
htlc.CustomRecords = customRecords
} else {
htlc.CustomRecords = make(record.CustomSet)
}
htlcs[circuitKey] = htlc
}
return htlcs, nil
}
// unmarshalInvoice converts an InvoiceRow to an Invoice.
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
error) {
var (
settleIndex int64
settledAt time.Time
memo []byte
paymentRequest []byte
preimage *lntypes.Preimage
paymentAddr [32]byte
)
hash, err := lntypes.MakeHash(row.Hash)
if err != nil {
return nil, nil, err
}
if row.SettleIndex.Valid {
settleIndex = row.SettleIndex.Int64
}
if row.SettledAt.Valid {
settledAt = row.SettledAt.Time.Local()
}
if row.Memo.Valid {
memo = []byte(row.Memo.String)
}
// Keysend payments will have this field empty.
if row.PaymentRequest.Valid {
paymentRequest = []byte(row.PaymentRequest.String)
} else {
paymentRequest = []byte{}
}
// We may not have the preimage if this is a hodl invoice.
if row.Preimage != nil {
preimage = &lntypes.Preimage{}
copy(preimage[:], row.Preimage)
}
copy(paymentAddr[:], row.PaymentAddr)
var cltvDelta int32
if row.CltvDelta.Valid {
cltvDelta = row.CltvDelta.Int32
}
expiry := time.Duration(row.Expiry) * time.Second
invoice := &Invoice{
SettleIndex: uint64(settleIndex),
SettleDate: settledAt,
Memo: memo,
PaymentRequest: paymentRequest,
CreationDate: row.CreatedAt.Local(),
Terms: ContractTerm{
FinalCltvDelta: cltvDelta,
Expiry: expiry,
PaymentPreimage: preimage,
Value: lnwire.MilliSatoshi(row.AmountMsat),
PaymentAddr: paymentAddr,
},
AddIndex: uint64(row.ID),
State: ContractState(row.State),
AmtPaid: lnwire.MilliSatoshi(row.AmountPaidMsat),
Htlcs: make(map[models.CircuitKey]*InvoiceHTLC),
AMPState: AMPInvoiceState{},
HodlInvoice: row.IsHodl,
}
return &hash, invoice, nil
}
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
*InvoiceHTLC, error) {
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
if err != nil {
return CircuitKey{}, nil, err
}
chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
if row.HtlcID < 0 {
return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
"value: %v", row.HtlcID)
}
htlcID := uint64(row.HtlcID)
circuitKey := CircuitKey{
ChanID: chanID,
HtlcID: htlcID,
}
htlc := &InvoiceHTLC{
Amt: lnwire.MilliSatoshi(row.AmountMsat),
AcceptHeight: uint32(row.AcceptHeight),
AcceptTime: row.AcceptTime.Local(),
Expiry: uint32(row.ExpiryHeight),
State: HtlcState(row.State),
}
if row.TotalMppMsat.Valid {
htlc.MppTotalAmt = lnwire.MilliSatoshi(row.TotalMppMsat.Int64)
}
if row.ResolveTime.Valid {
htlc.ResolveTime = row.ResolveTime.Time.Local()
}
return circuitKey, htlc, nil
}
// queryWithLimit is a helper method that can be used to query the database
// using a limit and offset. The passed query function should return the number
// of rows returned and an error if any.
func queryWithLimit(query func(int) (int, error), limit int) error {
offset := 0
for {
rows, err := query(offset)
if err != nil {
return err
}
if rows < limit {
return nil
}
offset += limit
}
}
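// For illustration, a caller pages through a filtered query like so. This is
// a minimal sketch where fetchPage is a hypothetical helper that returns one
// page of rows starting at the given offset:
//
//	err := queryWithLimit(func(offset int) (int, error) {
//		rows, err := fetchPage(offset)
//		if err != nil {
//			return 0, err
//		}
//		// Process rows here.
//		return len(rows), nil
//	}, limit)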
package invoices
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"sync"
"testing"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/zpay32"
"github.com/stretchr/testify/require"
)
type mockChainNotifier struct {
chainntnfs.ChainNotifier
blockChan chan *chainntnfs.BlockEpoch
}
func newMockNotifier() *mockChainNotifier {
return &mockChainNotifier{
blockChan: make(chan *chainntnfs.BlockEpoch),
}
}
// RegisterBlockEpochNtfn mocks a block epoch notification, using the mock's
// block channel to deliver blocks to the client.
func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) (
*chainntnfs.BlockEpochEvent, error) {
return &chainntnfs.BlockEpochEvent{
Epochs: m.blockChan,
Cancel: func() {},
}, nil
}
const (
testCurrentHeight = int32(1)
)
var (
testTimeout = 5 * time.Second
testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
testPrivKeyBytes, _ = hex.DecodeString(
"e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2d" +
"b734",
)
testPrivKey, _ = btcec.PrivKeyFromBytes(testPrivKeyBytes)
testInvoiceDescription = "coffee"
testInvoiceAmount = lnwire.MilliSatoshi(100000)
testNetParams = &chaincfg.MainNetParams
testMessageSigner = zpay32.MessageSigner{
SignCompact: func(msg []byte) ([]byte, error) {
hash := chainhash.HashB(msg)
sig := ecdsa.SignCompact(testPrivKey, hash, true)
return sig, nil
},
}
testFeatures = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
)
func newTestInvoice(t *testing.T, preimage lntypes.Preimage,
timestamp time.Time, expiry time.Duration) *Invoice {
t.Helper()
if expiry == 0 {
expiry = time.Hour
}
var payAddr [32]byte
if _, err := rand.Read(payAddr[:]); err != nil {
t.Fatalf("unable to generate payment addr: %v", err)
}
rawInvoice, err := zpay32.NewInvoice(
testNetParams,
preimage.Hash(),
timestamp,
zpay32.Amount(testInvoiceAmount),
zpay32.Description(testInvoiceDescription),
zpay32.Expiry(expiry),
zpay32.PaymentAddr(payAddr),
)
require.NoError(t, err, "Error while creating new invoice")
paymentRequest, err := rawInvoice.Encode(testMessageSigner)
require.NoError(t, err, "Error while encoding payment request")
return &Invoice{
Terms: ContractTerm{
PaymentPreimage: &preimage,
PaymentAddr: payAddr,
Value: testInvoiceAmount,
Expiry: expiry,
Features: testFeatures,
},
PaymentRequest: []byte(paymentRequest),
CreationDate: timestamp,
}
}
// invoiceExpiryTestData simply holds generated expired and pending invoices.
type invoiceExpiryTestData struct {
expiredInvoices map[lntypes.Hash]*Invoice
pendingInvoices map[lntypes.Hash]*Invoice
}
// generateInvoiceExpiryTestData generates the specified number of fake expired
// and pending invoices anchored to the passed now timestamp.
func generateInvoiceExpiryTestData(
t *testing.T, now time.Time,
offset, numExpired, numPending int) invoiceExpiryTestData {
t.Helper()
var testData invoiceExpiryTestData
testData.expiredInvoices = make(map[lntypes.Hash]*Invoice)
testData.pendingInvoices = make(map[lntypes.Hash]*Invoice)
expiredCreationDate := now.Add(-24 * time.Hour)
for i := 1; i <= numExpired; i++ {
var preimage lntypes.Preimage
binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i))
duration := (i + offset) % 24
expiry := time.Duration(duration) * time.Hour
invoice := newTestInvoice(
t, preimage, expiredCreationDate, expiry,
)
testData.expiredInvoices[preimage.Hash()] = invoice
}
for i := 1; i <= numPending; i++ {
var preimage lntypes.Preimage
binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i))
duration := (i + offset) % 24
expiry := time.Duration(duration) * time.Hour
invoice := newTestInvoice(t, preimage, now, expiry)
testData.pendingInvoices[preimage.Hash()] = invoice
}
return testData
}
type hodlExpiryTest struct {
hash lntypes.Hash
state ContractState
stateLock sync.Mutex
mockNotifier *mockChainNotifier
mockClock *clock.TestClock
cancelChan chan lntypes.Hash
watcher *InvoiceExpiryWatcher
}
func (h *hodlExpiryTest) setState(state ContractState) {
h.stateLock.Lock()
defer h.stateLock.Unlock()
h.state = state
}
func (h *hodlExpiryTest) announceBlock(t *testing.T, height uint32) {
t.Helper()
select {
case h.mockNotifier.blockChan <- &chainntnfs.BlockEpoch{
Height: int32(height),
}:
case <-time.After(testTimeout):
t.Fatalf("block %v not consumed", height)
}
}
func (h *hodlExpiryTest) assertCanceled(t *testing.T, expected lntypes.Hash) {
t.Helper()
select {
case actual := <-h.cancelChan:
require.Equal(t, expected, actual)
case <-time.After(testTimeout):
t.Fatalf("invoice: %v not canceled", h.hash)
}
}
// setupHodlExpiry creates a hodl invoice in our expiry watcher and runs an
// arbitrary update function which advances the invoice's state.
func setupHodlExpiry(t *testing.T, creationDate time.Time,
expiry time.Duration, heightDelta uint32,
startState ContractState,
startHtlcs []*InvoiceHTLC) *hodlExpiryTest {
t.Helper()
mockNotifier := newMockNotifier()
mockClock := clock.NewTestClock(testTime)
test := &hodlExpiryTest{
state: startState,
watcher: NewInvoiceExpiryWatcher(
mockClock, heightDelta, uint32(testCurrentHeight), nil,
mockNotifier,
),
cancelChan: make(chan lntypes.Hash),
mockNotifier: mockNotifier,
mockClock: mockClock,
}
// Use an unbuffered channel to block on cancel calls so that the test
// does not exit before we've processed all the invoices we expect.
cancelImpl := func(paymentHash lntypes.Hash, force bool) error {
test.stateLock.Lock()
currentState := test.state
test.stateLock.Unlock()
if currentState != ContractOpen && !force {
return nil
}
select {
case test.cancelChan <- paymentHash:
case <-time.After(testTimeout):
}
return nil
}
require.NoError(t, test.watcher.Start(cancelImpl))
// We set preimage and hash so that we can use our existing test
// helpers. In practice we would only have the hash, but this does not
// affect what we're testing at all.
preimage := lntypes.Preimage{1}
test.hash = preimage.Hash()
invoice := newTestInvoice(t, preimage, creationDate, expiry)
invoice.State = startState
invoice.HodlInvoice = true
invoice.Htlcs = make(map[CircuitKey]*InvoiceHTLC)
// If we have any htlcs, add them with unique circuit keys.
for i, htlc := range startHtlcs {
key := CircuitKey{
HtlcID: uint64(i),
}
invoice.Htlcs[key] = htlc
}
// Create an expiry entry for our invoice in its starting state. This
// mimics adding invoices to the watcher on start.
entry := makeInvoiceExpiry(test.hash, invoice)
test.watcher.AddInvoices(entry)
return test
}
package invoices
import (
"bytes"
"encoding/hex"
"errors"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/amp"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
)
// invoiceUpdateCtx is an object that describes the context for the invoice
// update to be carried out.
type invoiceUpdateCtx struct {
hash lntypes.Hash
circuitKey CircuitKey
amtPaid lnwire.MilliSatoshi
expiry uint32
currentHeight int32
finalCltvRejectDelta int32
// wireCustomRecords are the custom records that were included with the
// HTLC wire message.
wireCustomRecords lnwire.CustomRecords
// customRecords is a map of custom records that were included with the
// HTLC onion payload.
customRecords record.CustomSet
mpp *record.MPP
amp *record.AMP
metadata []byte
pathID *chainhash.Hash
totalAmtMsat lnwire.MilliSatoshi
}
// invoiceRef returns an identifier that can be used to lookup or update the
// invoice this HTLC is targeting.
func (i *invoiceUpdateCtx) invoiceRef() InvoiceRef {
switch {
case i.pathID != nil:
return InvoiceRefByHashAndAddr(i.hash, *i.pathID)
case i.amp != nil && i.mpp != nil:
payAddr := i.mpp.PaymentAddr()
return InvoiceRefByAddr(payAddr)
case i.mpp != nil:
payAddr := i.mpp.PaymentAddr()
return InvoiceRefByHashAndAddr(i.hash, payAddr)
default:
return InvoiceRefByHash(i.hash)
}
}
// setID returns an identifier that identifies other possible HTLCs that this
// particular one is related to. If nil is returned, this means the HTLC is an
// MPP or legacy payment; otherwise the HTLC belongs to an AMP payment.
func (i invoiceUpdateCtx) setID() *[32]byte {
if i.amp != nil {
setID := i.amp.SetID()
return &setID
}
return nil
}
// log logs a message specific to this update context.
func (i *invoiceUpdateCtx) log(s string) {
// Don't use %x in the log statement below, because it doesn't
// distinguish between nil and empty metadata.
metadata := "<nil>"
if i.metadata != nil {
metadata = hex.EncodeToString(i.metadata)
}
log.Debugf("Invoice%v: %v, amt=%v, expiry=%v, circuit=%v, mpp=%v, "+
"amp=%v, metadata=%v", i.invoiceRef(), s, i.amtPaid, i.expiry,
i.circuitKey, i.mpp, i.amp, metadata)
}
// failRes is a helper function which creates a failure resolution with
// the information contained in the invoiceUpdateCtx and the fail resolution
// result provided.
func (i invoiceUpdateCtx) failRes(outcome FailResolutionResult) *HtlcFailResolution {
return NewFailResolution(i.circuitKey, i.currentHeight, outcome)
}
// settleRes is a helper function which creates a settle resolution with
// the information contained in the invoiceUpdateCtx and the preimage and
// the settle resolution result provided.
func (i invoiceUpdateCtx) settleRes(preimage lntypes.Preimage,
outcome SettleResolutionResult) *HtlcSettleResolution {
return NewSettleResolution(
preimage, i.circuitKey, i.currentHeight, outcome,
)
}
// acceptRes is a helper function which creates an accept resolution with
// the information contained in the invoiceUpdateCtx and the accept resolution
// result provided.
func (i invoiceUpdateCtx) acceptRes(
outcome acceptResolutionResult) *htlcAcceptResolution {
return newAcceptResolution(i.circuitKey, outcome)
}
// resolveReplayedHtlc returns the HTLC resolution for a replayed HTLC. The
// returned boolean indicates whether the HTLC was replayed or not.
func resolveReplayedHtlc(ctx *invoiceUpdateCtx, inv *Invoice) (bool,
HtlcResolution, error) {
// Don't update the invoice when this is a replayed htlc.
htlc, replayedHTLC := inv.Htlcs[ctx.circuitKey]
if !replayedHTLC {
return false, nil, nil
}
switch htlc.State {
case HtlcStateCanceled:
return true, ctx.failRes(ResultReplayToCanceled), nil
case HtlcStateAccepted:
return true, ctx.acceptRes(resultReplayToAccepted), nil
case HtlcStateSettled:
pre := inv.Terms.PaymentPreimage
// Terms.PaymentPreimage will be nil for AMP invoices.
// Set it to the HTLC's AMP preimage instead.
if pre == nil {
pre = htlc.AMP.Preimage
}
return true, ctx.settleRes(
*pre,
ResultReplayToSettled,
), nil
default:
return true, nil, errors.New("unknown htlc state")
}
}
// updateInvoice is a callback for DB.UpdateInvoice that contains the invoice
// settlement logic. It returns an HTLC resolution that indicates what the
// outcome of the update was.
//
// NOTE: Make sure replayed HTLCs are always considered before calling this
// function.
func updateInvoice(ctx *invoiceUpdateCtx, inv *Invoice) (
*InvoiceUpdateDesc, HtlcResolution, error) {
// If no MPP payload was provided, then we expect this to be a keysend,
// or a payment to an invoice created before we started to require the
// MPP payload.
if ctx.mpp == nil && ctx.pathID == nil {
return updateLegacy(ctx, inv)
}
return updateMpp(ctx, inv)
}
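// Callers are expected to run the replay check first and only fall through
// to updateInvoice for fresh HTLCs. A minimal sketch of that ordering:
//
//	replayed, res, err := resolveReplayedHtlc(ctx, inv)
//	if err != nil {
//		return nil, nil, err
//	}
//	if replayed {
//		return nil, res, nil
//	}
//	return updateInvoice(ctx, inv)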
// updateMpp is a callback for DB.UpdateInvoice that contains the invoice
// settlement logic for mpp payments.
func updateMpp(ctx *invoiceUpdateCtx, inv *Invoice) (*InvoiceUpdateDesc,
HtlcResolution, error) {
// Reject HTLCs to AMP invoices if they are missing an AMP payload, and
// HTLCs to MPP invoices if they have an AMP payload.
switch {
case inv.Terms.Features.RequiresFeature(lnwire.AMPRequired) &&
ctx.amp == nil:
return nil, ctx.failRes(ResultHtlcInvoiceTypeMismatch), nil
case !inv.Terms.Features.RequiresFeature(lnwire.AMPRequired) &&
ctx.amp != nil:
return nil, ctx.failRes(ResultHtlcInvoiceTypeMismatch), nil
}
setID := ctx.setID()
var (
totalAmt = ctx.totalAmtMsat
paymentAddr []byte
)
// If an MPP record is present, then the payment address and total
// payment amount are extracted from it. Otherwise, the pathID is used
// to extract the payment address.
if ctx.mpp != nil {
totalAmt = ctx.mpp.TotalMsat()
payAddr := ctx.mpp.PaymentAddr()
paymentAddr = payAddr[:]
} else {
paymentAddr = ctx.pathID[:]
}
// For storage, we don't really care where the custom records came from.
// So we merge them together and store them in the same field.
customRecords := lnwire.CustomRecords(
ctx.customRecords,
).MergedCopy(ctx.wireCustomRecords)
// Start building the accept descriptor.
acceptDesc := &HtlcAcceptDesc{
Amt: ctx.amtPaid,
Expiry: ctx.expiry,
AcceptHeight: ctx.currentHeight,
MppTotalAmt: totalAmt,
CustomRecords: record.CustomSet(customRecords),
}
if ctx.amp != nil {
acceptDesc.AMP = &InvoiceHtlcAMPData{
Record: *ctx.amp,
Hash: ctx.hash,
Preimage: nil,
}
}
// Only accept payments to open invoices. This behaviour differs from
// non-mpp payments that are accepted even after the invoice is settled.
// Because non-mpp payments don't have a payment address, this is needed
// to thwart probing.
if inv.State != ContractOpen {
return nil, ctx.failRes(ResultInvoiceNotOpen), nil
}
// Check the payment address that authorizes the payment.
if !bytes.Equal(paymentAddr, inv.Terms.PaymentAddr[:]) {
return nil, ctx.failRes(ResultAddressMismatch), nil
}
// Don't accept zero-valued sets.
if totalAmt == 0 {
return nil, ctx.failRes(ResultHtlcSetTotalTooLow), nil
}
// Check that the total amt of the htlc set is high enough. In case this
// is a zero-valued invoice, it will always be enough.
if totalAmt < inv.Terms.Value {
return nil, ctx.failRes(ResultHtlcSetTotalTooLow), nil
}
htlcSet := inv.HTLCSet(setID, HtlcStateAccepted)
// Check whether total amt matches other HTLCs in the set.
var newSetTotal lnwire.MilliSatoshi
for _, htlc := range htlcSet {
if totalAmt != htlc.MppTotalAmt {
return nil, ctx.failRes(ResultHtlcSetTotalMismatch), nil
}
newSetTotal += htlc.Amt
}
// Add amount of new htlc.
newSetTotal += ctx.amtPaid
// The invoice is still open. Check the expiry.
if ctx.expiry < uint32(ctx.currentHeight+ctx.finalCltvRejectDelta) {
return nil, ctx.failRes(ResultExpiryTooSoon), nil
}
if ctx.expiry < uint32(ctx.currentHeight+inv.Terms.FinalCltvDelta) {
return nil, ctx.failRes(ResultExpiryTooSoon), nil
}
if setID != nil && *setID == BlankPayAddr {
return nil, ctx.failRes(ResultAmpError), nil
}
// Record HTLC in the invoice database.
newHtlcs := map[CircuitKey]*HtlcAcceptDesc{
ctx.circuitKey: acceptDesc,
}
update := InvoiceUpdateDesc{
UpdateType: AddHTLCsUpdate,
AddHtlcs: newHtlcs,
}
// If the invoice cannot be settled yet, only record the htlc.
setComplete := newSetTotal >= totalAmt
if !setComplete {
return &update, ctx.acceptRes(resultPartialAccepted), nil
}
// Check to see if we can settle, or whether this is a hodl invoice
// and we need to wait for the preimage.
if inv.HodlInvoice {
update.State = &InvoiceStateUpdateDesc{
NewState: ContractAccepted,
}
return &update, ctx.acceptRes(resultAccepted), nil
}
var (
htlcPreimages map[CircuitKey]lntypes.Preimage
htlcPreimage lntypes.Preimage
)
if ctx.amp != nil {
var failRes *HtlcFailResolution
htlcPreimages, failRes = reconstructAMPPreimages(ctx, htlcSet)
if failRes != nil {
update.UpdateType = CancelInvoiceUpdate
update.State = &InvoiceStateUpdateDesc{
NewState: ContractCanceled,
SetID: setID,
}
return &update, failRes, nil
}
// The preimage for _this_ HTLC will be the one keyed by the
// context's circuit key.
htlcPreimage = htlcPreimages[ctx.circuitKey]
} else {
htlcPreimage = *inv.Terms.PaymentPreimage
}
update.State = &InvoiceStateUpdateDesc{
NewState: ContractSettled,
Preimage: inv.Terms.PaymentPreimage,
HTLCPreimages: htlcPreimages,
SetID: setID,
}
return &update, ctx.settleRes(htlcPreimage, ResultSettled), nil
}
// HTLCSet is a map of CircuitKey to InvoiceHTLC.
type HTLCSet = map[CircuitKey]*InvoiceHTLC
// HTLCPreimages is a map of CircuitKey to preimage.
type HTLCPreimages = map[CircuitKey]lntypes.Preimage
// reconstructAMPPreimages reconstructs the root seed for an AMP HTLC set and
// verifies that all derived child hashes match the payment hashes of the HTLCs
// in the set. This method is meant to be called after receiving the full amount
// committed to via mpp_total_msat. This method will return a fail resolution if
// any of the child hashes fail to match their corresponding HTLCs.
func reconstructAMPPreimages(ctx *invoiceUpdateCtx,
htlcSet HTLCSet) (HTLCPreimages, *HtlcFailResolution) {
// Create a slice containing all the child descriptors to be used for
// reconstruction. This should include all HTLCs currently in the HTLC
// set, plus the incoming HTLC.
childDescs := make([]amp.ChildDesc, 0, 1+len(htlcSet))
// Add the new HTLC's child descriptor at index 0.
childDescs = append(childDescs, amp.ChildDesc{
Share: ctx.amp.RootShare(),
Index: ctx.amp.ChildIndex(),
})
// Next, construct an index mapping the position in childDescs to a
// circuit key for all preexisting HTLCs.
indexToCircuitKey := make(map[int]CircuitKey)
// Add the child descriptor for each HTLC in the HTLC set, recording
// its position within the slice.
var htlcSetIndex int
for circuitKey, htlc := range htlcSet {
childDescs = append(childDescs, amp.ChildDesc{
Share: htlc.AMP.Record.RootShare(),
Index: htlc.AMP.Record.ChildIndex(),
})
indexToCircuitKey[htlcSetIndex] = circuitKey
htlcSetIndex++
}
// Using the child descriptors, reconstruct the root seed and derive the
// child hash/preimage pairs for each of the HTLCs.
children := amp.ReconstructChildren(childDescs...)
// Validate that each derived child hash matches the corresponding
// HTLC's payment hash.
if ctx.hash != children[0].Hash {
return nil, ctx.failRes(ResultAmpReconstruction)
}
for idx, child := range children[1:] {
circuitKey := indexToCircuitKey[idx]
htlc := htlcSet[circuitKey]
if htlc.AMP.Hash != child.Hash {
return nil, ctx.failRes(ResultAmpReconstruction)
}
}
// Finally, construct the map of learned preimages indexed by circuit
// key, so that they can be persisted along with each HTLC when updating
// the invoice.
htlcPreimages := make(map[CircuitKey]lntypes.Preimage)
htlcPreimages[ctx.circuitKey] = children[0].Preimage
for idx, child := range children[1:] {
circuitKey := indexToCircuitKey[idx]
htlcPreimages[circuitKey] = child.Preimage
}
return htlcPreimages, nil
}
// updateLegacy is a callback for DB.UpdateInvoice that contains the invoice
// settlement logic for legacy payments.
//
// NOTE: This function is only kept in place in order to be able to handle
// keysend payments and any invoices we created in the past that are valid
// and still had the optional mpp bit set.
func updateLegacy(ctx *invoiceUpdateCtx,
inv *Invoice) (*InvoiceUpdateDesc, HtlcResolution, error) {
// If the invoice is already canceled, there is no further
// checking to do.
if inv.State == ContractCanceled {
return nil, ctx.failRes(ResultInvoiceAlreadyCanceled), nil
}
// If an invoice amount is specified, check that enough is paid. Also
// check this for duplicate payments if the invoice is already settled
// or accepted. In case this is a zero-valued invoice, it will always be
// enough.
if ctx.amtPaid < inv.Terms.Value {
return nil, ctx.failRes(ResultAmountTooLow), nil
}
// If the invoice has the required feature bit set, then reaching this
// method means that the remote party didn't supply the expected
// payload. However, if this is a keysend payment, then we'll permit it
// to pass.
_, isKeySend := ctx.customRecords[record.KeySendType]
invoiceFeatures := inv.Terms.Features
paymentAddrRequired := invoiceFeatures.RequiresFeature(
lnwire.PaymentAddrRequired,
)
if !isKeySend && paymentAddrRequired {
log.Warnf("Payment to pay_hash=%v doesn't include MPP "+
"payload, rejecting", ctx.hash)
return nil, ctx.failRes(ResultAddressMismatch), nil
}
// Don't allow settling the invoice with an old style
// htlc if we are already in the process of gathering an
// mpp set.
for _, htlc := range inv.HTLCSet(nil, HtlcStateAccepted) {
if htlc.MppTotalAmt > 0 {
return nil, ctx.failRes(ResultMppInProgress), nil
}
}
// The invoice is still open. Check the expiry.
if ctx.expiry < uint32(ctx.currentHeight+ctx.finalCltvRejectDelta) {
return nil, ctx.failRes(ResultExpiryTooSoon), nil
}
if ctx.expiry < uint32(ctx.currentHeight+inv.Terms.FinalCltvDelta) {
return nil, ctx.failRes(ResultExpiryTooSoon), nil
}
// For storage, we don't really care where the custom records came from.
// So we merge them together and store them in the same field.
customRecords := lnwire.CustomRecords(
ctx.customRecords,
).MergedCopy(ctx.wireCustomRecords)
// Record HTLC in the invoice database.
newHtlcs := map[CircuitKey]*HtlcAcceptDesc{
ctx.circuitKey: {
Amt: ctx.amtPaid,
Expiry: ctx.expiry,
AcceptHeight: ctx.currentHeight,
CustomRecords: record.CustomSet(customRecords),
},
}
update := InvoiceUpdateDesc{
AddHtlcs: newHtlcs,
UpdateType: AddHTLCsUpdate,
}
// Don't update the invoice state if we are accepting a duplicate
// payment. We still accept or settle the HTLC.
switch inv.State {
case ContractAccepted:
return &update, ctx.acceptRes(resultDuplicateToAccepted), nil
case ContractSettled:
return &update, ctx.settleRes(
*inv.Terms.PaymentPreimage, ResultDuplicateToSettled,
), nil
}
// Check to see if we can settle, or whether this is a hodl invoice and
// we need to wait for the preimage.
if inv.HodlInvoice {
update.State = &InvoiceStateUpdateDesc{
NewState: ContractAccepted,
}
return &update, ctx.acceptRes(resultAccepted), nil
}
update.State = &InvoiceStateUpdateDesc{
NewState: ContractSettled,
Preimage: inv.Terms.PaymentPreimage,
}
return &update, ctx.settleRes(
*inv.Terms.PaymentPreimage, ResultSettled,
), nil
}
package invoices
import (
"errors"
"fmt"
"time"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
)
// acceptHtlcsAmp takes an invoice, and a new HTLC to be added (along with its
// set ID), updates the internal AMP state of the invoice, and also tallies
// the set of HTLCs to be updated on disk.
func acceptHtlcsAmp(invoice *Invoice, setID SetID,
circuitKey models.CircuitKey, htlc *InvoiceHTLC,
updater InvoiceUpdater) error {
newAmpState, err := getUpdatedInvoiceAmpState(
invoice, setID, circuitKey, HtlcStateAccepted, htlc.Amt,
)
if err != nil {
return err
}
invoice.AMPState[setID] = newAmpState
// Mark the updates as needing to be written to disk.
return updater.UpdateAmpState(setID, newAmpState, circuitKey)
}
// cancelHtlcsAmp processes a cancellation of an HTLC that belongs to an AMP
// HTLC set. We'll need to update the metadata in the main invoice, and also
// apply the new update to the update map, since all the HTLCs for a given
// HTLC set need to be written in line with each other.
func cancelHtlcsAmp(invoice *Invoice, circuitKey models.CircuitKey,
htlc *InvoiceHTLC, updater InvoiceUpdater) error {
setID := htlc.AMP.Record.SetID()
// First, we'll update the state of the entire HTLC set
// to canceled.
newAmpState, err := getUpdatedInvoiceAmpState(
invoice, setID, circuitKey, HtlcStateCanceled,
htlc.Amt,
)
if err != nil {
return err
}
invoice.AMPState[setID] = newAmpState
// Mark the updates as needing to be written to disk.
err = updater.UpdateAmpState(setID, newAmpState, circuitKey)
if err != nil {
return err
}
// We'll only decrement the total amount paid if the invoice was
// already in the accepted state.
if invoice.AmtPaid != 0 {
return updateInvoiceAmtPaid(
invoice, invoice.AmtPaid-htlc.Amt, updater,
)
}
return nil
}
// settleHtlcsAmp processes a new settle operation on an HTLC set for an AMP
// invoice. We'll update some metadata in the main invoice, and also signal
// that this HTLC set needs to be re-written back to disk.
func settleHtlcsAmp(invoice *Invoice, circuitKey models.CircuitKey,
htlc *InvoiceHTLC, updater InvoiceUpdater) error {
setID := htlc.AMP.Record.SetID()
// Next, update the main AMP metadata to indicate that this HTLC set
// has been fully settled.
newAmpState, err := getUpdatedInvoiceAmpState(
invoice, setID, circuitKey, HtlcStateSettled, 0,
)
if err != nil {
return err
}
invoice.AMPState[setID] = newAmpState
// Mark the updates as needing to be written to disk.
return updater.UpdateAmpState(setID, newAmpState, circuitKey)
}
// UpdateInvoice fetches the invoice, obtains the update descriptor from the
// callback and applies the updates in a single db transaction.
func UpdateInvoice(hash *lntypes.Hash, invoice *Invoice,
updateTime time.Time, callback InvoiceUpdateCallback,
updater InvoiceUpdater) (*Invoice, error) {
// Create deep copy to prevent any accidental modification in the
// callback.
invoiceCopy, err := CopyInvoice(invoice)
if err != nil {
return nil, err
}
// Call the callback and obtain the update descriptor.
update, err := callback(invoiceCopy)
if err != nil {
return invoice, err
}
// If there is nothing to update, return early.
if update == nil {
return invoice, nil
}
switch update.UpdateType {
case CancelHTLCsUpdate:
err := cancelHTLCs(invoice, updateTime, update, updater)
if err != nil {
return nil, err
}
case AddHTLCsUpdate:
err := addHTLCs(invoice, hash, updateTime, update, updater)
if err != nil {
return nil, err
}
case SettleHodlInvoiceUpdate:
err := settleHodlInvoice(
invoice, hash, updateTime, update.State, updater,
)
if err != nil {
return nil, err
}
case CancelInvoiceUpdate:
err := cancelInvoice(
invoice, hash, updateTime, update.State, updater,
)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unknown update type: %s",
update.UpdateType)
}
if err := updater.Finalize(update.UpdateType); err != nil {
return nil, err
}
return invoice, nil
}
// cancelHTLCs tries to cancel the htlcs in the given InvoiceUpdateDesc.
//
// NOTE: cancelHTLCs updates will only use the `CancelHtlcs` field in the
// InvoiceUpdateDesc.
func cancelHTLCs(invoice *Invoice, updateTime time.Time,
update *InvoiceUpdateDesc, updater InvoiceUpdater) error {
for key := range update.CancelHtlcs {
htlc, exists := invoice.Htlcs[key]
// Verify that we don't get an action for htlcs that are not
// present on the invoice.
if !exists {
return fmt.Errorf("cancel of non-existent htlc")
}
err := canCancelSingleHtlc(htlc, invoice.State)
if err != nil {
return err
}
err = resolveHtlc(
key, htlc, HtlcStateCanceled, updateTime,
updater,
)
if err != nil {
return err
}
// Tally this into the set of HTLCs that need to be updated on
// disk, but once again, only if this is an AMP invoice.
if invoice.IsAMP() {
err := cancelHtlcsAmp(invoice, key, htlc, updater)
if err != nil {
return err
}
}
}
return nil
}
// addHTLCs tries to add the htlcs in the given InvoiceUpdateDesc.
//
//nolint:funlen
func addHTLCs(invoice *Invoice, hash *lntypes.Hash, updateTime time.Time,
update *InvoiceUpdateDesc, updater InvoiceUpdater) error {
var setID *[32]byte
invoiceIsAMP := invoice.IsAMP()
if invoiceIsAMP && update.State != nil {
setID = update.State.SetID
}
for key, htlcUpdate := range update.AddHtlcs {
if _, exists := invoice.Htlcs[key]; exists {
return fmt.Errorf("duplicate add of htlc %v", key)
}
// Force the caller to supply an htlc without custom records in a
// consistent way: an empty map rather than nil.
if htlcUpdate.CustomRecords == nil {
return errors.New("nil custom records map")
}
htlc := &InvoiceHTLC{
Amt: htlcUpdate.Amt,
MppTotalAmt: htlcUpdate.MppTotalAmt,
Expiry: htlcUpdate.Expiry,
AcceptHeight: uint32(htlcUpdate.AcceptHeight),
AcceptTime: updateTime,
State: HtlcStateAccepted,
CustomRecords: htlcUpdate.CustomRecords,
}
if invoiceIsAMP {
if htlcUpdate.AMP == nil {
return fmt.Errorf("unable to add htlc "+
"without AMP data to AMP invoice(%v)",
invoice.AddIndex)
}
htlc.AMP = htlcUpdate.AMP.Copy()
}
if err := updater.AddHtlc(key, htlc); err != nil {
return err
}
invoice.Htlcs[key] = htlc
// Collect the set of new HTLCs so we can write them properly
// below, but only if this is an AMP invoice.
if invoiceIsAMP {
err := acceptHtlcsAmp(
invoice, htlcUpdate.AMP.Record.SetID(), key,
htlc, updater,
)
if err != nil {
return err
}
}
}
// At this point, the set of accepted HTLCs should be fully
// populated with newly added HTLCs and cleared of canceled ones.
// Update the invoice state if the update descriptor indicates an
// invoice state change, which depends on having an accurate view of
// the accepted HTLCs.
if update.State != nil {
newState, err := getUpdatedInvoiceState(
invoice, hash, *update.State,
)
if err != nil {
return err
}
// If this isn't an AMP invoice, then we'll go ahead and update
// the invoice state directly here. For AMP invoices, we instead
// will keep the top-level invoice open, and update the state of
// each _htlc set_ instead. However, we'll allow the invoice to
// transition to the canceled state regardless.
if !invoiceIsAMP || *newState == ContractCanceled {
err := updater.UpdateInvoiceState(*newState, nil)
if err != nil {
return err
}
invoice.State = *newState
}
}
// The set of HTLC pre-images will only be set if we were actually able
// to reconstruct all the AMP pre-images.
var settleEligibleAMP bool
if update.State != nil {
settleEligibleAMP = len(update.State.HTLCPreimages) != 0
}
// With any invoice-level state transitions recorded, we'll now
// finalize the process by updating the state transitions for
// individual HTLCs.
var amtPaid lnwire.MilliSatoshi
for key, htlc := range invoice.Htlcs {
// Set the HTLC preimage for any AMP HTLCs.
if setID != nil && update.State != nil {
preimage, ok := update.State.HTLCPreimages[key]
switch {
// If we don't already have a preimage for this HTLC, we
// can set it now.
case ok && htlc.AMP.Preimage == nil:
err := updater.AddAmpHtlcPreimage(
htlc.AMP.Record.SetID(), key, preimage,
)
if err != nil {
return err
}
htlc.AMP.Preimage = &preimage
// Otherwise, prevent overwriting an existing
// preimage. Ignore the case where the preimage is
// identical.
case ok && *htlc.AMP.Preimage != preimage:
return ErrHTLCPreimageAlreadyExists
}
}
// The invoice state may have changed and this could have
// implications for the states of the individual htlcs. Align
// the htlc state with the current invoice state.
//
// If we have all the pre-images for an AMP invoice, then we'll
// act as if we're able to settle the entire invoice. We need
// to do this since it's possible for us to settle AMP invoices
// while the contract state (on disk) is still in the accept
// state.
htlcContextState := invoice.State
if settleEligibleAMP {
htlcContextState = ContractSettled
}
htlcStateChanged, htlcState, err := getUpdatedHtlcState(
htlc, htlcContextState, setID,
)
if err != nil {
return err
}
if htlcStateChanged {
err = resolveHtlc(
key, htlc, htlcState, updateTime, updater,
)
if err != nil {
return err
}
}
htlcSettled := htlcStateChanged &&
htlcState == HtlcStateSettled
// If the HTLC has been settled for the first time, and this
// is an AMP invoice, then we'll need to update some additional
// metadata state.
if htlcSettled && invoiceIsAMP {
err = settleHtlcsAmp(invoice, key, htlc, updater)
if err != nil {
return err
}
}
accepted := htlc.State == HtlcStateAccepted
settled := htlc.State == HtlcStateSettled
invoiceStateReady := accepted || settled
if !invoiceIsAMP {
// Update the running amount paid to this invoice. We
// don't include accepted htlcs when the invoice is
// still open.
if invoice.State != ContractOpen &&
invoiceStateReady {
amtPaid += htlc.Amt
}
} else {
// For AMP invoices, since we won't always be reading
// out the total invoice set each time, we'll instead
// accumulate newly added HTLCs to the total amount
// paid.
if _, ok := update.AddHtlcs[key]; !ok {
continue
}
// Update the running amount paid to this invoice. AMP
// invoices never go to the settled state, so if it's
// open, then we tally the HTLC.
if invoice.State == ContractOpen &&
invoiceStateReady {
amtPaid += htlc.Amt
}
}
}
// For non-AMP invoices we recalculate the amount paid from scratch
// each time, while for AMP invoices, we'll accumulate only based on
// newly added HTLCs.
if invoiceIsAMP {
amtPaid += invoice.AmtPaid
}
return updateInvoiceAmtPaid(invoice, amtPaid, updater)
}
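// exampleAmtPaidAccountingSketch is an illustrative sketch (not part of
// the original file) of the two accounting modes described above:
// relevantAmts is the full resolved HTLC set for non-AMP invoices, but
// only the newly added HTLCs for AMP invoices, which accumulate on top of
// the stored total.
func exampleAmtPaidAccountingSketch(invoice *Invoice,
	relevantAmts []lnwire.MilliSatoshi) lnwire.MilliSatoshi {

	var amtPaid lnwire.MilliSatoshi
	for _, amt := range relevantAmts {
		amtPaid += amt
	}

	// AMP invoices only add the delta; non-AMP invoices recompute the
	// total from scratch each time.
	if invoice.IsAMP() {
		amtPaid += invoice.AmtPaid
	}

	return amtPaid
}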
func resolveHtlc(circuitKey models.CircuitKey, htlc *InvoiceHTLC,
state HtlcState, resolveTime time.Time,
updater InvoiceUpdater) error {
err := updater.ResolveHtlc(circuitKey, state, resolveTime)
if err != nil {
return err
}
htlc.State = state
htlc.ResolveTime = resolveTime
return nil
}
func updateInvoiceAmtPaid(invoice *Invoice, amt lnwire.MilliSatoshi,
updater InvoiceUpdater) error {
err := updater.UpdateInvoiceAmtPaid(amt)
if err != nil {
return err
}
invoice.AmtPaid = amt
return nil
}
// settleHodlInvoice marks a hodl invoice as settled.
//
// NOTE: Currently it is not possible to have HODL AMP invoices.
func settleHodlInvoice(invoice *Invoice, hash *lntypes.Hash,
updateTime time.Time, update *InvoiceStateUpdateDesc,
updater InvoiceUpdater) error {
if !invoice.HodlInvoice {
return fmt.Errorf("unable to settle hodl invoice: %v is "+
"not a hodl invoice", invoice.AddIndex)
}
// TODO(positiveblue): because NewState can only be ContractSettled we
// can remove it from the API and set it here directly.
switch {
case update == nil:
fallthrough
case update.NewState != ContractSettled:
return fmt.Errorf("unable to settle hodl invoice: "+
"not valid InvoiceUpdateDesc.State: %v", update)
case update.Preimage == nil:
return fmt.Errorf("unable to settle hodl invoice: " +
"preimage is nil")
}
newState, err := getUpdatedInvoiceState(
invoice, hash, *update,
)
if err != nil {
return err
}
if newState == nil || *newState != ContractSettled {
return fmt.Errorf("unable to settle hodl invoice: "+
"new computed state is not settled: %s", newState)
}
err = updater.UpdateInvoiceState(
ContractSettled, update.Preimage,
)
if err != nil {
return err
}
invoice.State = ContractSettled
invoice.Terms.PaymentPreimage = update.Preimage
// TODO(positiveblue): this logic can be further simplified.
var amtPaid lnwire.MilliSatoshi
for key, htlc := range invoice.Htlcs {
settled, _, err := getUpdatedHtlcState(
htlc, ContractSettled, nil,
)
if err != nil {
return err
}
if settled {
err = resolveHtlc(
key, htlc, HtlcStateSettled, updateTime,
updater,
)
if err != nil {
return err
}
amtPaid += htlc.Amt
}
}
return updateInvoiceAmtPaid(invoice, amtPaid, updater)
}
// cancelInvoice attempts to cancel the given invoice. That includes changing
// the invoice state and the state of any relevant HTLC.
func cancelInvoice(invoice *Invoice, hash *lntypes.Hash,
updateTime time.Time, update *InvoiceStateUpdateDesc,
updater InvoiceUpdater) error {
switch {
case update == nil:
fallthrough
case update.NewState != ContractCanceled:
return fmt.Errorf("unable to cancel invoice: "+
"InvoiceUpdateDesc.State not valid: %v", update)
}
var setID *[32]byte
invoiceIsAMP := invoice.IsAMP()
if invoiceIsAMP {
setID = update.SetID
}
newState, err := getUpdatedInvoiceState(invoice, hash, *update)
if err != nil {
return err
}
if newState == nil || *newState != ContractCanceled {
return fmt.Errorf("unable to cancel invoice(%v): new "+
"computed state is not canceled: %s", invoice.AddIndex,
newState)
}
err = updater.UpdateInvoiceState(ContractCanceled, nil)
if err != nil {
return err
}
invoice.State = ContractCanceled
for key, htlc := range invoice.Htlcs {
// We might not have a setID here in case we are cancelling
// an AMP invoice. However, the setID is only important when
// settling an AMP HTLC.
canceled, _, err := getUpdatedHtlcState(
htlc, ContractCanceled, setID,
)
if err != nil {
return err
}
if canceled {
err = resolveHtlc(
key, htlc, HtlcStateCanceled, updateTime,
updater,
)
if err != nil {
return err
}
// If it's an AMP HTLC, we need to make sure we persist
// this new state, otherwise the HTLC won't be updated
// on disk; HTLCs for AMP invoices are stored
// separately.
if htlc.AMP != nil {
err := cancelHtlcsAmp(
invoice, key, htlc, updater,
)
if err != nil {
return err
}
}
}
}
return nil
}
// getUpdatedInvoiceState validates and processes an invoice state update. The
// new state to transition to is returned, so the caller is able to select
// exactly how the invoice state is updated. Note that for AMP invoices this
// function is only used to validate the state transition if we're cancelling
// the invoice.
func getUpdatedInvoiceState(invoice *Invoice, hash *lntypes.Hash,
update InvoiceStateUpdateDesc) (*ContractState, error) {
// Returning to open is never allowed from any state.
if update.NewState == ContractOpen {
return nil, ErrInvoiceCannotOpen
}
switch invoice.State {
// Once a contract is accepted, we can only transition to settled or
// canceled. Forbid transitioning back into this state. Otherwise this
// state is identical to ContractOpen, so we fallthrough to apply the
// same checks that we apply to open invoices.
case ContractAccepted:
if update.NewState == ContractAccepted {
return nil, ErrInvoiceCannotAccept
}
fallthrough
// If a contract is open, permit a state transition to accepted, settled
// or canceled. The only restriction is on transitioning to settled
// where we ensure the preimage is valid.
case ContractOpen:
if update.NewState == ContractCanceled {
return &update.NewState, nil
}
// Sanity check that the user isn't trying to settle or accept a
// non-existent HTLC set.
set := invoice.HTLCSet(update.SetID, HtlcStateAccepted)
if len(set) == 0 {
return nil, ErrEmptyHTLCSet
}
// For AMP invoices, there are no invoice-level preimage checks.
// However, we still sanity check that we aren't trying to
// settle an AMP invoice with a preimage.
if update.SetID != nil {
if update.Preimage != nil {
return nil, errors.New("AMP set cannot have " +
"preimage")
}
return &update.NewState, nil
}
switch {
// If an invoice-level preimage was supplied, but the InvoiceRef
// doesn't specify a hash (e.g. AMP invoices) we fail.
case update.Preimage != nil && hash == nil:
return nil, ErrUnexpectedInvoicePreimage
// Validate the supplied preimage for non-AMP invoices.
case update.Preimage != nil:
if update.Preimage.Hash() != *hash {
return nil, ErrInvoicePreimageMismatch
}
// Permit non-AMP invoices to be accepted without knowing the
// preimage. When trying to settle we'll have to pass through
// the above check in order to not hit the one below.
case update.NewState == ContractAccepted:
// Fail if we still don't have a preimage when transitioning to
// settle the non-AMP invoice.
case update.NewState == ContractSettled &&
invoice.Terms.PaymentPreimage == nil:
return nil, errors.New("unknown preimage")
}
return &update.NewState, nil
// Once settled, we are in a terminal state.
case ContractSettled:
return nil, ErrInvoiceAlreadySettled
// Once canceled, we are in a terminal state.
case ContractCanceled:
return nil, ErrInvoiceAlreadyCanceled
default:
return nil, errors.New("unknown state transition")
}
}
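// exampleCancelTransitionSketch is an illustrative sketch (not part of the
// original file): cancellation from the open state is validated without
// inspecting the HTLC set, so a bare update descriptor suffices.
func exampleCancelTransitionSketch(openInvoice *Invoice) (*ContractState,
	error) {

	// Assumes openInvoice.State == ContractOpen; no hash is needed since
	// no preimage is supplied.
	return getUpdatedInvoiceState(openInvoice, nil, InvoiceStateUpdateDesc{
		NewState: ContractCanceled,
	})
}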
// getUpdatedInvoiceAmpState returns the new AMP state of an invoice (without
// applying it), given the new HTLC state and the amount of the HTLC that is
// being updated.
func getUpdatedInvoiceAmpState(invoice *Invoice, setID SetID,
circuitKey models.CircuitKey, state HtlcState,
amt lnwire.MilliSatoshi) (InvoiceStateAMP, error) {
// Retrieve the AMP state for this set ID.
ampState, ok := invoice.AMPState[setID]
// If the state is accepted then we may need to create a new entry for
// this set ID, otherwise we expect that the entry already exists and
// we can update it.
if !ok && state != HtlcStateAccepted {
return InvoiceStateAMP{},
fmt.Errorf("unable to update AMP state for setID=%x ",
setID)
}
switch state {
case HtlcStateAccepted:
if !ok {
// If an entry for this set ID doesn't already exist,
// then we'll need to create it.
ampState = InvoiceStateAMP{
State: HtlcStateAccepted,
InvoiceKeys: make(
map[models.CircuitKey]struct{},
),
}
}
ampState.AmtPaid += amt
case HtlcStateCanceled:
ampState.State = HtlcStateCanceled
ampState.AmtPaid -= amt
case HtlcStateSettled:
ampState.State = HtlcStateSettled
}
ampState.InvoiceKeys[circuitKey] = struct{}{}
return ampState, nil
}
// canCancelSingleHtlc validates cancellation of a single HTLC. If nil is
// returned, then the HTLC can be cancelled.
func canCancelSingleHtlc(htlc *InvoiceHTLC,
invoiceState ContractState) error {
// It is only possible to cancel individual htlcs on an open invoice.
if invoiceState != ContractOpen {
return fmt.Errorf("htlc canceled on invoice in state %v",
invoiceState)
}
// It is only possible if the htlc is still pending.
if htlc.State != HtlcStateAccepted {
return fmt.Errorf("htlc canceled in state %v", htlc.State)
}
return nil
}
// getUpdatedHtlcState aligns the state of an htlc with the given invoice
// state. A boolean indicating whether the HTLC's state needs to be updated,
// along with the new state (or the old state if no change is needed), is
// returned.
func getUpdatedHtlcState(htlc *InvoiceHTLC,
invoiceState ContractState, setID *[32]byte) (
bool, HtlcState, error) {
trySettle := func(persist bool) (bool, HtlcState, error) {
if htlc.State != HtlcStateAccepted {
return false, htlc.State, nil
}
// Settle the HTLC if it matches the settled set id. If
// there're other HTLCs with distinct setIDs, then we'll leave
// them, as they may eventually be settled as we permit
// multiple settles to a single pay_addr for AMP.
settled := false
if htlc.IsInHTLCSet(setID) {
// Non-AMP HTLCs can be settled immediately since we
// already know the preimage is valid due to checks at
// the invoice level. For AMP HTLCs, verify that the
// per-HTLC preimage-hash pair is valid.
switch {
// Non-AMP HTLCs can be settled immediately since we
// already know the preimage is valid due to checks at
// the invoice level.
case setID == nil:
// At this point, the setID is non-nil, meaning this is
// an AMP HTLC. We know that htlc.AMP cannot be nil,
// otherwise IsInHTLCSet would have returned false.
//
// Fail if an accepted AMP HTLC has no preimage.
case htlc.AMP.Preimage == nil:
return false, htlc.State,
ErrHTLCPreimageMissing
// Fail if the accepted AMP HTLC has an invalid
// preimage.
case !htlc.AMP.Preimage.Matches(htlc.AMP.Hash):
return false, htlc.State,
ErrHTLCPreimageMismatch
}
settled = true
}
// Only persist the changes if the invoice is moving to the
// settled state, and we're actually updating the state to
// settled.
newState := htlc.State
if settled {
newState = HtlcStateSettled
}
return persist && settled, newState, nil
}
if invoiceState == ContractSettled {
// Check that we can settle the HTLCs. For legacy and MPP HTLCs
// this will be a NOP, but for AMP HTLCs this asserts that we
// have a valid hash/preimage pair. Passing true permits the
// method to update the HTLC to HtlcStateSettled.
return trySettle(true)
}
// We should never find a settled HTLC on an invoice that isn't in
// ContractSettled.
if htlc.State == HtlcStateSettled {
return false, htlc.State, ErrHTLCAlreadySettled
}
switch invoiceState {
case ContractCanceled:
htlcAlreadyCanceled := htlc.State == HtlcStateCanceled
return !htlcAlreadyCanceled, HtlcStateCanceled, nil
// TODO(roasbeef): never fully passed thru now?
case ContractAccepted:
// Check that we can settle the HTLCs. For legacy and MPP HTLCs
// this will be a NOP, but for AMP HTLCs this asserts that we
// have a valid hash/preimage pair. Passing false prevents the
// method from putting the HTLC in HtlcStateSettled, leaving it
// in HtlcStateAccepted.
return trySettle(false)
case ContractOpen:
return false, htlc.State, nil
default:
return false, htlc.State, errors.New("unknown state transition")
}
}
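// exampleHtlcAlignmentSketch is an illustrative sketch (not part of the
// original file) of how getUpdatedHtlcState aligns an accepted, non-AMP
// HTLC with the invoice state. A nil setID marks the HTLC as non-AMP; the
// exact IsInHTLCSet semantics for nil setIDs are assumed here.
func exampleHtlcAlignmentSketch() error {
	htlc := &InvoiceHTLC{State: HtlcStateAccepted}

	// A settled invoice settles the accepted HTLC.
	changed, state, err := getUpdatedHtlcState(htlc, ContractSettled, nil)
	if err != nil || !changed || state != HtlcStateSettled {
		return errors.New("expected transition to settled")
	}

	// A canceled invoice cancels it instead.
	changed, state, err = getUpdatedHtlcState(htlc, ContractCanceled, nil)
	if err != nil || !changed || state != HtlcStateCanceled {
		return errors.New("expected transition to canceled")
	}

	// An open invoice leaves the HTLC untouched.
	changed, _, err = getUpdatedHtlcState(htlc, ContractOpen, nil)
	if err != nil || changed {
		return errors.New("expected no state change")
	}

	return nil
}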
package keychain
import (
"crypto/sha256"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/walletdb"
)
const (
// CoinTypeBitcoin specifies the BIP44 coin type for Bitcoin key
// derivation.
CoinTypeBitcoin uint32 = 0
// CoinTypeTestnet specifies the BIP44 coin type for all testnet key
// derivation.
CoinTypeTestnet = 1
)
var (
// lightningAddrSchema is the scope addr schema for all keys that we
// derive. We'll treat them all as p2wkh addresses, as at the moment we
// must specify a particular type.
lightningAddrSchema = waddrmgr.ScopeAddrSchema{
ExternalAddrType: waddrmgr.WitnessPubKey,
InternalAddrType: waddrmgr.WitnessPubKey,
}
// waddrmgrNamespaceKey is the namespace key that the waddrmgr state is
// stored within the top-level walletdb buckets of btcwallet.
waddrmgrNamespaceKey = []byte("waddrmgr")
)
// BtcWalletKeyRing is an implementation of both the KeyRing and SecretKeyRing
// interfaces backed by btcwallet's internal root waddrmgr. Internally, we'll
// be using a ScopedKeyManager to do all of our derivations, using the key
// scope and scope addr schema defined above. Re-using the existing key scope
// construction means that all key derivation will be protected under the root
// seed of the wallet, making each derived key fully deterministic.
type BtcWalletKeyRing struct {
// wallet is a pointer to the active instance of the btcwallet core.
// This is required as we'll need to manually open database
// transactions in order to derive addresses and lookup relevant keys
wallet *wallet.Wallet
// chainKeyScope defines the purpose and coin type to be used when generating
// keys for this keyring.
chainKeyScope waddrmgr.KeyScope
// lightningScope is a pointer to the scope that we'll be using as a
// sub key manager to derive all the keys that we require.
lightningScope *waddrmgr.ScopedKeyManager
}
// NewBtcWalletKeyRing creates a new implementation of the
// keychain.SecretKeyRing interface backed by btcwallet.
//
// NOTE: The passed waddrmgr.Manager MUST be unlocked in order for the keychain
// to function.
func NewBtcWalletKeyRing(w *wallet.Wallet, coinType uint32) SecretKeyRing {
// Construct the key scope that will be used within the waddrmgr to
// create an HD chain for deriving all of our required keys. A different
// scope is used for each specific coin type.
chainKeyScope := waddrmgr.KeyScope{
Purpose: BIP0043Purpose,
Coin: coinType,
}
return &BtcWalletKeyRing{
wallet: w,
chainKeyScope: chainKeyScope,
}
}
// keyScope attempts to return the key scope that we'll use to derive all of
// our keys. If the scope has already been fetched from the database, then a
// cached version will be returned. Otherwise, we'll fetch it from the database
// and cache it for subsequent accesses.
func (b *BtcWalletKeyRing) keyScope() (*waddrmgr.ScopedKeyManager, error) {
// If the scope has already been populated, then we'll return it
// directly.
if b.lightningScope != nil {
return b.lightningScope, nil
}
// Otherwise, we'll first do a check to ensure that the root manager
// isn't locked, as otherwise we won't be able to *use* the scope.
if !b.wallet.Manager.WatchOnly() && b.wallet.Manager.IsLocked() {
return nil, fmt.Errorf("cannot create BtcWalletKeyRing with " +
"locked waddrmgr.Manager")
}
// If the manager is indeed unlocked, then we'll fetch the scope, cache
// it, and return to the caller.
lnScope, err := b.wallet.Manager.FetchScopedKeyManager(b.chainKeyScope)
if err != nil {
return nil, err
}
b.lightningScope = lnScope
return lnScope, nil
}
// createAccountIfNotExists will create the corresponding account for a key
// family if it doesn't already exist in the database.
func (b *BtcWalletKeyRing) createAccountIfNotExists(
addrmgrNs walletdb.ReadWriteBucket, keyFam KeyFamily,
scope *waddrmgr.ScopedKeyManager) error {
// If this is the multi-sig key family, then we can return early as
// this is the default account that's created.
if keyFam == KeyFamilyMultiSig {
return nil
}
// Otherwise, we'll check if the account already exists, if so, we can
// once again bail early.
_, err := scope.AccountName(addrmgrNs, uint32(keyFam))
if err == nil {
return nil
}
// If we reach this point, then the account hasn't yet been created, so
// we'll need to create it before we can proceed.
return scope.NewRawAccount(addrmgrNs, uint32(keyFam))
}
// DeriveNextKey attempts to derive the *next* key within the key family
// (account in BIP43) specified. This method should return the next external
// child within this branch.
//
// NOTE: This is part of the keychain.KeyRing interface.
func (b *BtcWalletKeyRing) DeriveNextKey(keyFam KeyFamily) (KeyDescriptor, error) {
var (
pubKey *btcec.PublicKey
keyLoc KeyLocator
)
db := b.wallet.Database()
err := walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
scope, err := b.keyScope()
if err != nil {
return err
}
// If the account doesn't exist, then we may need to create it
// for the first time in order to derive the keys that we
// require.
err = b.createAccountIfNotExists(addrmgrNs, keyFam, scope)
if err != nil {
return err
}
addrs, err := scope.NextExternalAddresses(
addrmgrNs, uint32(keyFam), 1,
)
if err != nil {
return err
}
// Extract the first address, ensuring that it is of the proper
// interface type, otherwise we can't manipulate it below.
addr, ok := addrs[0].(waddrmgr.ManagedPubKeyAddress)
if !ok {
return fmt.Errorf("address is not a managed pubkey " +
"addr")
}
pubKey = addr.PubKey()
_, pathInfo, _ := addr.DerivationInfo()
keyLoc = KeyLocator{
Family: keyFam,
Index: pathInfo.Index,
}
return nil
})
if err != nil {
return KeyDescriptor{}, err
}
return KeyDescriptor{
PubKey: pubKey,
KeyLocator: keyLoc,
}, nil
}
// DeriveKey attempts to derive an arbitrary key specified by the passed
// KeyLocator. This may be used in several recovery scenarios, or when manually
// rotating something like our current default node key.
//
// NOTE: This is part of the keychain.KeyRing interface.
func (b *BtcWalletKeyRing) DeriveKey(keyLoc KeyLocator) (KeyDescriptor, error) {
var keyDesc KeyDescriptor
db := b.wallet.Database()
err := walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
scope, err := b.keyScope()
if err != nil {
return err
}
// If the account doesn't exist, then we may need to create it
// for the first time in order to derive the keys that we
// require. We skip this if we're using a remote signer in which
// case we _need_ to create all accounts when creating the
// wallet, so it must exist now.
if !b.wallet.Manager.WatchOnly() {
err = b.createAccountIfNotExists(
addrmgrNs, keyLoc.Family, scope,
)
if err != nil {
return err
}
}
path := waddrmgr.DerivationPath{
InternalAccount: uint32(keyLoc.Family),
Branch: 0,
Index: keyLoc.Index,
}
addr, err := scope.DeriveFromKeyPath(addrmgrNs, path)
if err != nil {
return err
}
keyDesc.KeyLocator = keyLoc
keyDesc.PubKey = addr.(waddrmgr.ManagedPubKeyAddress).PubKey()
return nil
})
if err != nil {
return keyDesc, err
}
return keyDesc, nil
}
// DerivePrivKey attempts to derive the private key that corresponds to the
// passed key descriptor.
//
// NOTE: This is part of the keychain.SecretKeyRing interface.
func (b *BtcWalletKeyRing) DerivePrivKey(keyDesc KeyDescriptor) (
*btcec.PrivateKey, error) {
var key *btcec.PrivateKey
scope, err := b.keyScope()
if err != nil {
return nil, err
}
// First, attempt to see if we can read the key directly from
// btcwallet's internal cache. If we can, then we can skip all the
// operations below (fast path).
if keyDesc.PubKey == nil {
keyPath := waddrmgr.DerivationPath{
InternalAccount: uint32(keyDesc.Family),
Account: uint32(keyDesc.Family),
Branch: 0,
Index: keyDesc.Index,
}
privKey, err := scope.DeriveFromKeyPathCache(keyPath)
if err == nil {
return privKey, nil
}
}
db := b.wallet.Database()
err = walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
// If the account doesn't exist, then we may need to create it
// for the first time in order to derive the keys that we
// require. We skip this if we're using a remote signer in which
// case we _need_ to create all accounts when creating the
// wallet, so it must exist now.
if !b.wallet.Manager.WatchOnly() {
err = b.createAccountIfNotExists(
addrmgrNs, keyDesc.Family, scope,
)
if err != nil {
return err
}
}
// If the public key isn't set, or the descriptor has a non-zero
// index, then we know that the caller instead knows the
// derivation path for the key.
if keyDesc.PubKey == nil || keyDesc.Index > 0 {
// Now that we know the account exists, we can safely
// derive the full private key from the given path.
path := waddrmgr.DerivationPath{
InternalAccount: uint32(keyDesc.Family),
Branch: 0,
Index: keyDesc.Index,
}
addr, err := scope.DeriveFromKeyPath(addrmgrNs, path)
if err != nil {
return err
}
key, err = addr.(waddrmgr.ManagedPubKeyAddress).PrivKey()
if err != nil {
return err
}
return nil
}
// If the public key isn't nil, then this indicates that we
// need to scan for the private key, assuming that we know the
// valid key family.
nextPath := waddrmgr.DerivationPath{
InternalAccount: uint32(keyDesc.Family),
Branch: 0,
Index: 0,
}
// We'll now iterate through our key range in an attempt to
// find the target public key.
//
// TODO(roasbeef): possibly move scanning into wallet to allow
// to be parallelized
for i := 0; i < MaxKeyRangeScan; i++ {
// Derive the next key in the range and fetch its
// managed address.
addr, err := scope.DeriveFromKeyPath(
addrmgrNs, nextPath,
)
if err != nil {
return err
}
managedAddr := addr.(waddrmgr.ManagedPubKeyAddress)
// If this is the target public key, then we'll return
// it directly back to the caller.
if managedAddr.PubKey().IsEqual(keyDesc.PubKey) {
key, err = managedAddr.PrivKey()
if err != nil {
return err
}
return nil
}
// This wasn't the target key, so roll forward and try
// the next one.
nextPath.Index++
}
// If we reach this point, then we were unable to derive the
// private key, so return an error back to the user.
return ErrCannotDerivePrivKey
})
if err != nil {
return nil, err
}
return key, nil
}
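// exampleDerivePrivKeyByPubKeySketch is an illustrative sketch (not part
// of the original file): with only a public key and its key family known,
// DerivePrivKey falls back to scanning up to MaxKeyRangeScan indices under
// that family. KeyFamilyNodeKey is just an example family here.
func exampleDerivePrivKeyByPubKeySketch(ring *BtcWalletKeyRing,
	pub *btcec.PublicKey) (*btcec.PrivateKey, error) {

	return ring.DerivePrivKey(KeyDescriptor{
		// The family MUST be set for the scan to target the right
		// account; the index is left at zero because it is unknown.
		KeyLocator: KeyLocator{Family: KeyFamilyNodeKey},
		PubKey:     pub,
	})
}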
// ECDH performs a scalar multiplication (ECDH-like operation) between the
// target key descriptor and remote public key. The output returned will be
// the sha256 of the resulting shared point serialized in compressed format. If
// k is our private key, and P is the public key, we perform the following
// operation:
//
// sx := k*P
// s := sha256(sx.SerializeCompressed())
//
// NOTE: This is part of the keychain.ECDHRing interface.
func (b *BtcWalletKeyRing) ECDH(keyDesc KeyDescriptor,
pub *btcec.PublicKey) ([32]byte, error) {
privKey, err := b.DerivePrivKey(keyDesc)
if err != nil {
return [32]byte{}, err
}
var (
pubJacobian btcec.JacobianPoint
s btcec.JacobianPoint
)
pub.AsJacobian(&pubJacobian)
btcec.ScalarMultNonConst(&privKey.Key, &pubJacobian, &s)
s.ToAffine()
sPubKey := btcec.NewPublicKey(&s.X, &s.Y)
h := sha256.Sum256(sPubKey.SerializeCompressed())
return h, nil
}
// SignMessage signs the given message, single or double SHA256 hashing it
// first, with the private key described in the key locator.
//
// NOTE: This is part of the keychain.MessageSignerRing interface.
func (b *BtcWalletKeyRing) SignMessage(keyLoc KeyLocator,
msg []byte, doubleHash bool) (*ecdsa.Signature, error) {
privKey, err := b.DerivePrivKey(KeyDescriptor{
KeyLocator: keyLoc,
})
if err != nil {
return nil, err
}
var digest []byte
if doubleHash {
digest = chainhash.DoubleHashB(msg)
} else {
digest = chainhash.HashB(msg)
}
return ecdsa.Sign(privKey, digest), nil
}
// SignMessageCompact signs the given message, single or double SHA256 hashing
// it first, with the private key described in the key locator and returns
// the signature in the compact, public key recoverable format.
//
// NOTE: This is part of the keychain.MessageSignerRing interface.
func (b *BtcWalletKeyRing) SignMessageCompact(keyLoc KeyLocator,
msg []byte, doubleHash bool) ([]byte, error) {
privKey, err := b.DerivePrivKey(KeyDescriptor{
KeyLocator: keyLoc,
})
if err != nil {
return nil, err
}
var digest []byte
if doubleHash {
digest = chainhash.DoubleHashB(msg)
} else {
digest = chainhash.HashB(msg)
}
return ecdsa.SignCompact(privKey, digest, true), nil
}
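// exampleCompactRecoverySketch is an illustrative sketch (not part of the
// original file): compact signatures embed recovery information, so the
// signing public key can be recovered from the signature and digest alone.
func exampleCompactRecoverySketch(ring *BtcWalletKeyRing, keyLoc KeyLocator,
	msg []byte) (*btcec.PublicKey, error) {

	sig, err := ring.SignMessageCompact(keyLoc, msg, true)
	if err != nil {
		return nil, err
	}

	// Recover the public key from the compact signature. The digest
	// must match the hashing mode used when signing (double-SHA256
	// here).
	pub, _, err := ecdsa.RecoverCompact(sig, chainhash.DoubleHashB(msg))
	return pub, err
}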
// SignMessageSchnorr uses the Schnorr signature algorithm to sign the given
// message, single or double SHA256 hashing it first, with the private key
// described in the key locator and the optional tweak applied to the private
// key.
//
// NOTE: This is part of the keychain.MessageSignerRing interface.
func (b *BtcWalletKeyRing) SignMessageSchnorr(keyLoc KeyLocator,
msg []byte, doubleHash bool, taprootTweak []byte,
tag []byte) (*schnorr.Signature, error) {
privKey, err := b.DerivePrivKey(KeyDescriptor{
KeyLocator: keyLoc,
})
if err != nil {
return nil, err
}
if len(taprootTweak) > 0 {
privKey = txscript.TweakTaprootPrivKey(*privKey, taprootTweak)
}
// If a tag was provided, we need to take the tagged hash of the input.
var digest []byte
switch {
case len(tag) > 0:
taggedHash := chainhash.TaggedHash(tag, msg)
digest = taggedHash[:]
case doubleHash:
digest = chainhash.DoubleHashB(msg)
default:
digest = chainhash.HashB(msg)
}
return schnorr.Sign(privKey, digest)
}
package keychain
import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
)
const (
// KeyDerivationVersionLegacy is the previous version of the key
// derivation schema defined below. We use a version as this means that
// we'll be able to accept new seeds in the future and be able to discern
// if the software is compatible with the version of the seed.
KeyDerivationVersionLegacy = 0
// KeyDerivationVersionTaproot is the most recent version of the key
// derivation scheme that marks the introduction of the Taproot
// derivation with BIP0086 support.
KeyDerivationVersionTaproot = 1
// CurrentKeyDerivationVersion is the current default key derivation
// version that is used for new seeds.
CurrentKeyDerivationVersion = KeyDerivationVersionTaproot
// BIP0043Purpose is the "purpose" value that we'll use for the first
// version of our key derivation scheme. All keys are expected to be
// derived from this purpose, then the particular coin type of the
// chain where the keys are to be used. Slightly adhering to BIP0043
// allows us to not deviate too far from a widely used standard, and
// also fits into existing implementations of the BIP's template.
//
// NOTE: BRICK SQUUUUUAD.
BIP0043Purpose = 1017
)
// IsKnownVersion returns true if the given version is one of the known
// derivation scheme versions as defined by this package.
func IsKnownVersion(internalVersion uint8) bool {
return internalVersion == KeyDerivationVersionLegacy ||
internalVersion == KeyDerivationVersionTaproot
}
var (
// MaxKeyRangeScan is the maximum number of keys that we'll attempt to
// scan when a caller knows the public key, but not the KeyLocator,
// and wishes to derive a private key.
MaxKeyRangeScan = 100000
// ErrCannotDerivePrivKey is returned when DerivePrivKey is unable to
// derive a private key given only the public key and target key
// family.
ErrCannotDerivePrivKey = fmt.Errorf("unable to derive private key")
)
// KeyFamily represents a "family" of keys that will be used within various
// contracts created by lnd. These families are meant to be distinct branches
// within the HD key chain of the backing wallet. Usage of key families within
// the interface below is strict in order to promote interoperability and the
// ability to restore all keys given a user master seed backup.
//
// The key derivation in this file follows the following hierarchy based on
// BIP43:
//
// - m/1017'/coinType'/keyFamily'/0/index
type KeyFamily uint32
const (
// KeyFamilyMultiSig are keys to be used within multi-sig scripts.
KeyFamilyMultiSig KeyFamily = 0
// KeyFamilyRevocationBase are keys that are used within channels to
// create revocation basepoints that the remote party will use to
// create revocation keys for us.
KeyFamilyRevocationBase KeyFamily = 1
// KeyFamilyHtlcBase are keys used within channels that will be
// combined with per-state randomness to produce public keys that will
// be used in HTLC scripts.
KeyFamilyHtlcBase KeyFamily = 2
// KeyFamilyPaymentBase are keys used within channels that will be
// combined with per-state randomness to produce public keys that will
// be used in scripts that pay directly to us without any delay.
KeyFamilyPaymentBase KeyFamily = 3
// KeyFamilyDelayBase are keys used within channels that will be
// combined with per-state randomness to produce public keys that will
// be used in scripts that pay to us, but require a CSV delay before we
// can sweep the funds.
KeyFamilyDelayBase KeyFamily = 4
// KeyFamilyRevocationRoot is a family of keys which will be used to
// derive the root of a revocation tree for a particular channel.
KeyFamilyRevocationRoot KeyFamily = 5
// KeyFamilyNodeKey is a family of keys that will be used to derive
// keys that will be advertised on the network to represent our current
// "identity" within the network. Peers will need our latest node key
// in order to establish a transport session with us on the Lightning
// p2p level (BOLT-0008).
KeyFamilyNodeKey KeyFamily = 6
// KeyFamilyBaseEncryption is the family of keys that will be used to
// derive keys that we use to encrypt and decrypt any general blob data
// like static channel backups and the TLS private key. Often used when
// encrypting files on disk.
KeyFamilyBaseEncryption KeyFamily = 7
// KeyFamilyTowerSession is the family of keys that will be used to
// derive session keys when negotiating sessions with watchtowers. The
// session keys are limited to the lifetime of the session and are used
// to increase privacy in the watchtower protocol.
KeyFamilyTowerSession KeyFamily = 8
// KeyFamilyTowerID is the family of keys used to derive the public key
// of a watchtower. This is made distinct from the node key to offer a form
// of rudimentary whitelisting, i.e. via knowledge of the pubkey,
// preventing others from having full access to the tower just as a
// result of knowing the node key.
KeyFamilyTowerID KeyFamily = 9
)
// VersionZeroKeyFamilies is a slice of all the known key families for the
// first version of the key derivation schema defined in this package.
var VersionZeroKeyFamilies = []KeyFamily{
KeyFamilyMultiSig,
KeyFamilyRevocationBase,
KeyFamilyHtlcBase,
KeyFamilyPaymentBase,
KeyFamilyDelayBase,
KeyFamilyRevocationRoot,
KeyFamilyNodeKey,
KeyFamilyBaseEncryption,
KeyFamilyTowerSession,
KeyFamilyTowerID,
}
// KeyLocator is a two-tuple that can be used to derive *any* key that has ever
// been used under the key derivation mechanisms described in this file.
// Version 0 of our key derivation schema uses the following BIP43-like
// derivation:
//
// - m/1017'/coinType'/keyFamily'/0/index
//
// Our purpose is 1017 (chosen arbitrarily for now), and the coin type will
// vary based on which coin/chain the channels are being created on. Key
// families are actually just individual "accounts" in the nomenclature of
// BIP43. By
// default we assume a branch of 0 (external). Finally, the key index (which
// will vary per channel and use case) is the final element which allows us to
// deterministically derive keys.
type KeyLocator struct {
// TODO(roasbeef): add the key scope as well??
// Family is the family of key being identified.
Family KeyFamily
// Index is the precise index of the key being identified.
Index uint32
}
// IsEmpty returns true if a KeyLocator is "empty". This may be the case where
// we learn of a key from a remote party for a contract, but don't know the
// precise details of its derivation (as we don't know the private key!).
func (k KeyLocator) IsEmpty() bool {
return k.Family == 0 && k.Index == 0
}
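// exampleKeyLocatorPathSketch is an illustrative sketch (not part of the
// original file) rendering the BIP43-style path that a KeyLocator resolves
// to under version 0 of this schema, for a given coin type. The purpose,
// coin type and family are hardened; the branch is fixed to 0 (external).
func exampleKeyLocatorPathSketch(coinType uint32, loc KeyLocator) string {
	return fmt.Sprintf("m/%d'/%d'/%d'/0/%d",
		BIP0043Purpose, coinType, uint32(loc.Family), loc.Index)
}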
// KeyDescriptor wraps a KeyLocator and also optionally includes a public key.
// Either the KeyLocator must be non-empty, or the public key pointer must be
// non-nil. This will be used by the KeyRing interface to lookup arbitrary
// private keys, and also within the SignDescriptor struct to locate precisely
// which keys should be used for signing.
type KeyDescriptor struct {
// KeyLocator is the internal KeyLocator of the descriptor.
KeyLocator
// PubKey is an optional public key that fully describes a target key.
// If this is nil, the KeyLocator MUST NOT be empty.
PubKey *btcec.PublicKey
}
// KeyRing is the primary interface that will be used to perform public
// derivation of various keys used within the peer-to-peer network, and also
// within any created contracts. All derivation required by the KeyRing is
// based off of public derivation, so a system with only an extended public key
// (for the particular purpose+family) can derive this set of keys.
type KeyRing interface {
// DeriveNextKey attempts to derive the *next* key within the key
// family (account in BIP43) specified. This method should return the
// next external child within this branch.
DeriveNextKey(keyFam KeyFamily) (KeyDescriptor, error)
// DeriveKey attempts to derive an arbitrary key specified by the
// passed KeyLocator. This may be used in several recovery scenarios,
// or when manually rotating something like our current default node
// key.
DeriveKey(keyLoc KeyLocator) (KeyDescriptor, error)
}
// SecretKeyRing is a ring similar to the regular KeyRing interface, but it is
// also able to derive *private keys*. As this is a super-set of the regular
// KeyRing, we also expect the SecretKeyRing to implement the full KeyRing
// interface. The methods in this interface may be used to extract the node
// key in
// order to accept inbound network connections, or to do manual signing for
// recovery purposes.
type SecretKeyRing interface {
KeyRing
ECDHRing
MessageSignerRing
// DerivePrivKey attempts to derive the private key that corresponds to
// the passed key descriptor. If the public key is set, then this
// method will perform an in-order scan over the key set, with a max of
// MaxKeyRangeScan keys. In order for this to work, the caller MUST set
// the KeyFamily within the partially populated KeyLocator.
DerivePrivKey(keyDesc KeyDescriptor) (*btcec.PrivateKey, error)
}
// MessageSignerRing is an interface that abstracts away basic low-level ECDSA
// signing on keys within a key ring.
type MessageSignerRing interface {
// SignMessage signs the given message, single or double SHA256 hashing
// it first, with the private key described in the key locator.
SignMessage(keyLoc KeyLocator, msg []byte,
doubleHash bool) (*ecdsa.Signature, error)
// SignMessageCompact signs the given message, single or double SHA256
// hashing it first, with the private key described in the key locator
// and returns the signature in the compact, public key recoverable
// format.
SignMessageCompact(keyLoc KeyLocator, msg []byte,
doubleHash bool) ([]byte, error)
// SignMessageSchnorr signs the given message, single or double SHA256
// hashing it first, with the private key described in the key locator
// and the optional Taproot tweak applied to the private key.
SignMessageSchnorr(keyLoc KeyLocator, msg []byte,
doubleHash bool, taprootTweak []byte,
tag []byte) (*schnorr.Signature, error)
}
// SingleKeyMessageSigner is an abstraction interface that hides the
// implementation of the low-level ECDSA signing operations by wrapping a
// single, specific private key.
type SingleKeyMessageSigner interface {
// PubKey returns the public key of the wrapped private key.
PubKey() *btcec.PublicKey
// KeyLocator returns the locator that describes the wrapped private
// key.
KeyLocator() KeyLocator
// SignMessage signs the given message, single or double SHA256 hashing
// it first, with the wrapped private key.
SignMessage(message []byte, doubleHash bool) (*ecdsa.Signature, error)
// SignMessageCompact signs the given message, single or double SHA256
// hashing it first, with the wrapped private key and returns the
// signature in the compact, public key recoverable format.
SignMessageCompact(message []byte, doubleHash bool) ([]byte, error)
}
// ECDHRing is an interface that abstracts away basic low-level ECDH shared key
// generation on keys within a key ring.
type ECDHRing interface {
// ECDH performs a scalar multiplication (ECDH-like operation) between
// the target key descriptor and remote public key. The output
// returned will be the sha256 of the resulting shared point serialized
// in compressed format. If k is our private key, and P is the public
// key, we perform the following operation:
//
// sx := k*P
// s := sha256(sx.SerializeCompressed())
ECDH(keyDesc KeyDescriptor, pubKey *btcec.PublicKey) ([32]byte, error)
}
// SingleKeyECDH is an abstraction interface that hides the implementation of an
// ECDH operation by wrapping a single, specific private key.
type SingleKeyECDH interface {
// PubKey returns the public key of the wrapped private key.
PubKey() *btcec.PublicKey
// ECDH performs a scalar multiplication (ECDH-like operation) between
// the wrapped private key and remote public key. The output returned
// will be the sha256 of the resulting shared point serialized in
// compressed format.
ECDH(pubKey *btcec.PublicKey) ([32]byte, error)
}
// TODO(roasbeef): extend to actually support scalar mult of key?
// * would allow to push in initial handshake auth into interface as well
package keychain
import (
"crypto/sha256"
"github.com/btcsuite/btcd/btcec/v2"
)
// NewPubKeyECDH wraps the given key of the key ring so it adheres to the
// SingleKeyECDH interface.
func NewPubKeyECDH(keyDesc KeyDescriptor, ecdh ECDHRing) *PubKeyECDH {
return &PubKeyECDH{
keyDesc: keyDesc,
ecdh: ecdh,
}
}
// PubKeyECDH is an implementation of the SingleKeyECDH interface. It wraps an
// ECDH key ring so it can perform ECDH shared key generation against a single
// abstracted away private key.
type PubKeyECDH struct {
keyDesc KeyDescriptor
ecdh ECDHRing
}
// PubKey returns the public key of the private key that is abstracted away by
// the interface.
//
// NOTE: This is part of the SingleKeyECDH interface.
func (p *PubKeyECDH) PubKey() *btcec.PublicKey {
return p.keyDesc.PubKey
}
// ECDH performs a scalar multiplication (ECDH-like operation) between the
// abstracted private key and a remote public key. The output returned will be
// the sha256 of the resulting shared point serialized in compressed format. If
// k is our private key, and P is the public key, we perform the following
// operation:
//
// sx := k*P
// s := sha256(sx.SerializeCompressed())
//
// NOTE: This is part of the SingleKeyECDH interface.
func (p *PubKeyECDH) ECDH(pubKey *btcec.PublicKey) ([32]byte, error) {
return p.ecdh.ECDH(p.keyDesc, pubKey)
}
// PrivKeyECDH is an implementation of the SingleKeyECDH interface in which
// we have the full private key. This can be used to wrap a temporary key to
// conform to the SingleKeyECDH interface.
type PrivKeyECDH struct {
// PrivKey is the private key that is used for the ECDH operation.
PrivKey *btcec.PrivateKey
}
// PubKey returns the public key of the private key that is abstracted away by
// the interface.
//
// NOTE: This is part of the SingleKeyECDH interface.
func (p *PrivKeyECDH) PubKey() *btcec.PublicKey {
return p.PrivKey.PubKey()
}
// ECDH performs a scalar multiplication (ECDH-like operation) between the
// abstracted private key and a remote public key. The output returned will be
// the sha256 of the resulting shared point serialized in compressed format. If
// k is our private key, and P is the public key, we perform the following
// operation:
//
// sx := k*P
// s := sha256(sx.SerializeCompressed())
//
// NOTE: This is part of the SingleKeyECDH interface.
func (p *PrivKeyECDH) ECDH(pub *btcec.PublicKey) ([32]byte, error) {
var (
pubJacobian btcec.JacobianPoint
s btcec.JacobianPoint
)
pub.AsJacobian(&pubJacobian)
btcec.ScalarMultNonConst(&p.PrivKey.Key, &pubJacobian, &s)
s.ToAffine()
sPubKey := btcec.NewPublicKey(&s.X, &s.Y)
return sha256.Sum256(sPubKey.SerializeCompressed()), nil
}
var _ SingleKeyECDH = (*PubKeyECDH)(nil)
var _ SingleKeyECDH = (*PrivKeyECDH)(nil)
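// examplePrivKeyECDHSketch is an illustrative sketch (not part of the
// original file): because ECDH is symmetric, both parties compute the same
// 32-byte secret, i.e. sha256((a*B).SerializeCompressed()) ==
// sha256((b*A).SerializeCompressed()).
func examplePrivKeyECDHSketch() (bool, error) {
	alicePriv, err := btcec.NewPrivateKey()
	if err != nil {
		return false, err
	}
	bobPriv, err := btcec.NewPrivateKey()
	if err != nil {
		return false, err
	}

	aliceECDH := &PrivKeyECDH{PrivKey: alicePriv}
	bobECDH := &PrivKeyECDH{PrivKey: bobPriv}

	s1, err := aliceECDH.ECDH(bobPriv.PubKey())
	if err != nil {
		return false, err
	}
	s2, err := bobECDH.ECDH(alicePriv.PubKey())
	if err != nil {
		return false, err
	}

	// Holds for any pair of keys.
	return s1 == s2, nil
}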
package keychain
import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// NewPubKeyMessageSigner wraps the given public key and key locator
// together with a MessageSignerRing so the result adheres to the
// SingleKeyMessageSigner interface.
func NewPubKeyMessageSigner(pubKey *btcec.PublicKey, keyLoc KeyLocator,
signer MessageSignerRing) *PubKeyMessageSigner {
return &PubKeyMessageSigner{
pubKey: pubKey,
keyLoc: keyLoc,
digestSigner: signer,
}
}
// PubKeyMessageSigner is an implementation of the SingleKeyMessageSigner
// interface that abstracts away the private key by delegating all signing
// operations to a backing MessageSignerRing.
type PubKeyMessageSigner struct {
pubKey *btcec.PublicKey
keyLoc KeyLocator
digestSigner MessageSignerRing
}
// PubKey returns the public key of the wrapped private key.
func (p *PubKeyMessageSigner) PubKey() *btcec.PublicKey {
return p.pubKey
}
// KeyLocator returns the locator that describes the wrapped private key.
func (p *PubKeyMessageSigner) KeyLocator() KeyLocator {
return p.keyLoc
}
// SignMessage signs the given message, single or double SHA256 hashing it
// first, with the wrapped private key.
func (p *PubKeyMessageSigner) SignMessage(message []byte,
doubleHash bool) (*ecdsa.Signature, error) {
return p.digestSigner.SignMessage(p.keyLoc, message, doubleHash)
}
// SignMessageCompact signs the given message, single or double SHA256
// hashing it first, with the wrapped private key, and returns the signature
// in the compact, public key recoverable format.
func (p *PubKeyMessageSigner) SignMessageCompact(msg []byte,
doubleHash bool) ([]byte, error) {
return p.digestSigner.SignMessageCompact(p.keyLoc, msg, doubleHash)
}
// NewPrivKeyMessageSigner wraps the given raw private key and key locator
// so the result adheres to the SingleKeyMessageSigner interface.
func NewPrivKeyMessageSigner(privKey *btcec.PrivateKey,
keyLoc KeyLocator) *PrivKeyMessageSigner {
return &PrivKeyMessageSigner{
privKey: privKey,
keyLoc: keyLoc,
}
}
// PrivKeyMessageSigner is an implementation of the SingleKeyMessageSigner
// interface that holds the full private key and signs directly with it.
type PrivKeyMessageSigner struct {
keyLoc KeyLocator
privKey *btcec.PrivateKey
}
// PubKey returns the public key of the wrapped private key.
func (p *PrivKeyMessageSigner) PubKey() *btcec.PublicKey {
return p.privKey.PubKey()
}
// KeyLocator returns the locator that describes the wrapped private key.
func (p *PrivKeyMessageSigner) KeyLocator() KeyLocator {
return p.keyLoc
}
// SignMessage signs the given message, single or double SHA256 hashing it
// first, with the wrapped private key.
func (p *PrivKeyMessageSigner) SignMessage(msg []byte,
doubleHash bool) (*ecdsa.Signature, error) {
var digest []byte
if doubleHash {
digest = chainhash.DoubleHashB(msg)
} else {
digest = chainhash.HashB(msg)
}
return ecdsa.Sign(p.privKey, digest), nil
}
// SignMessageCompact signs the given message, single or double SHA256
// hashing it first, with the wrapped private key, and returns the signature
// in the compact, public key recoverable format.
func (p *PrivKeyMessageSigner) SignMessageCompact(msg []byte,
doubleHash bool) ([]byte, error) {
var digest []byte
if doubleHash {
digest = chainhash.DoubleHashB(msg)
} else {
digest = chainhash.HashB(msg)
}
return ecdsa.SignCompact(p.privKey, digest, true), nil
}
var _ SingleKeyMessageSigner = (*PubKeyMessageSigner)(nil)
var _ SingleKeyMessageSigner = (*PrivKeyMessageSigner)(nil)
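// exampleSignAndVerifySketch is an illustrative sketch (not part of the
// original file): it wraps a freshly generated private key, signs a
// message with double-SHA256 hashing, and verifies the signature against
// the signer's public key.
func exampleSignAndVerifySketch(msg []byte) (bool, error) {
	privKey, err := btcec.NewPrivateKey()
	if err != nil {
		return false, err
	}

	// The KeyLocator is arbitrary here since the key is ephemeral.
	signer := NewPrivKeyMessageSigner(privKey, KeyLocator{})

	sig, err := signer.SignMessage(msg, true)
	if err != nil {
		return false, err
	}

	// Verification must use the same digest that was signed.
	digest := chainhash.DoubleHashB(msg)
	return sig.Verify(digest, signer.PubKey()), nil
}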
//go:build !js
// +build !js
package kvdb
import (
"context"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"os"
"path/filepath"
"time"
_ "github.com/btcsuite/btcwallet/walletdb/bdb" // Import to register backend.
)
const (
// DefaultTempDBFileName is the default name of the temporary bolt DB
// file that we'll use to atomically compact the primary DB file on
// startup.
DefaultTempDBFileName = "temp-dont-use.db"
// LastCompactionFileNameSuffix is the suffix we append to the file name
// of a database file to record the timestamp when the last compaction
// occurred.
LastCompactionFileNameSuffix = ".last-compacted"
)
var (
// byteOrder is the byte order used to (de)serialize the last-compaction
// timestamp to disk.
byteOrder = binary.BigEndian
)
// fileExists returns true if the file exists, and false otherwise.
func fileExists(path string) bool {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
// BoltBackendConfig is a struct that holds settings specific to the bolt
// database backend.
type BoltBackendConfig struct {
// DBPath is the directory path in which the database file should be
// stored.
DBPath string
// DBFileName is the name of the database file.
DBFileName string
// NoFreelistSync, if true, prevents the database from syncing its
// freelist to disk, resulting in improved performance at the expense of
// increased startup time.
NoFreelistSync bool
// AutoCompact specifies if a Bolt based database backend should be
// automatically compacted on startup (if the minimum age of the
// database file is reached). This will require additional disk space
// for the compacted copy of the database but will result in an overall
// lower database size after the compaction.
AutoCompact bool
// AutoCompactMinAge specifies the minimum time that must have passed
// since a bolt database file was last compacted for the compaction to
// be considered again.
AutoCompactMinAge time.Duration
// DBTimeout specifies the timeout value to use when opening the wallet
// database.
DBTimeout time.Duration
// ReadOnly specifies if the database should be opened in read-only
// mode.
ReadOnly bool
}
// GetBoltBackend opens (or creates if it doesn't exist) a bbolt backed database
// and returns a kvdb.Backend wrapping it.
func GetBoltBackend(cfg *BoltBackendConfig) (Backend, error) {
dbFilePath := filepath.Join(cfg.DBPath, cfg.DBFileName)
// Is this a new database?
if !fileExists(dbFilePath) {
if !fileExists(cfg.DBPath) {
if err := os.MkdirAll(cfg.DBPath, 0700); err != nil {
return nil, err
}
}
return Create(
BoltBackendName, dbFilePath,
cfg.NoFreelistSync, cfg.DBTimeout, cfg.ReadOnly,
)
}
// This is an existing database. We might want to compact it on startup
// to free up some space.
if cfg.AutoCompact {
if err := compactAndSwap(cfg); err != nil {
return nil, err
}
}
return Open(
BoltBackendName, dbFilePath,
cfg.NoFreelistSync, cfg.DBTimeout, cfg.ReadOnly,
)
}
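// exampleBoltConfigSketch is an illustrative sketch (not part of the
// original file) of a typical BoltBackendConfig: auto-compaction enabled,
// but only if the last compaction is at least a week old. The concrete
// values here are assumptions, not recommended defaults.
func exampleBoltConfigSketch(dir string) (Backend, error) {
	return GetBoltBackend(&BoltBackendConfig{
		DBPath:            dir,
		DBFileName:        "sketch.db",
		NoFreelistSync:    true,
		AutoCompact:       true,
		AutoCompactMinAge: 7 * 24 * time.Hour,
		DBTimeout:         10 * time.Second,
	})
}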
// compactAndSwap will attempt to write a new temporary DB file to disk with
// the compacted database content, then atomically swap (via rename) the old
// file for the new file by updating the name of the new file to the old.
func compactAndSwap(cfg *BoltBackendConfig) error {
sourceName := cfg.DBFileName
// If the main DB file isn't set, then we can't proceed.
if sourceName == "" {
return fmt.Errorf("cannot compact DB with empty name")
}
sourceFilePath := filepath.Join(cfg.DBPath, sourceName)
tempDestFilePath := filepath.Join(cfg.DBPath, DefaultTempDBFileName)
// Let's find out how long ago the last compaction of the source file
// occurred and possibly skip compacting it again now.
lastCompactionDate, err := lastCompactionDate(sourceFilePath)
if err != nil {
return fmt.Errorf("cannot determine last compaction date of "+
"source DB file: %v", err)
}
compactAge := time.Since(lastCompactionDate)
if cfg.AutoCompactMinAge != 0 && compactAge <= cfg.AutoCompactMinAge {
log.Infof("Not compacting database file at %v, it was last "+
"compacted at %v (%v ago), min age is set to %v",
sourceFilePath, lastCompactionDate,
compactAge.Truncate(time.Second), cfg.AutoCompactMinAge)
return nil
}
log.Infof("Compacting database file at %v", sourceFilePath)
// If the old temporary DB file still exists, then we'll delete it
// before proceeding.
if _, err := os.Stat(tempDestFilePath); err == nil {
log.Infof("Found old temp DB @ %v, removing before swap",
tempDestFilePath)
err = os.Remove(tempDestFilePath)
if err != nil {
return fmt.Errorf("unable to remove old temp DB file: "+
"%v", err)
}
}
// Now that we know the staging area is clear, we'll create the new
// temporary DB file and close it before we write the new DB to it.
tempFile, err := os.Create(tempDestFilePath)
if err != nil {
return fmt.Errorf("unable to create temp DB file: %w", err)
}
if err := tempFile.Close(); err != nil {
return fmt.Errorf("unable to close file: %w", err)
}
// With the file created, we'll start the compaction and remove the
// temporary file altogether once this method exits.
defer func() {
// This will only succeed if the rename below fails. If the
// compaction is successful, the file won't exist on exit
// anymore so no need to log an error here.
_ = os.Remove(tempDestFilePath)
}()
c := &compacter{
srcPath: sourceFilePath,
dstPath: tempDestFilePath,
dbTimeout: cfg.DBTimeout,
}
initialSize, newSize, err := c.execute()
if err != nil {
return fmt.Errorf("error during compact: %w", err)
}
log.Infof("DB compaction of %v successful, %d -> %d bytes (gain=%.2fx)",
sourceFilePath, initialSize, newSize,
float64(initialSize)/float64(newSize))
// We try to store the current timestamp in a file with the suffix
// .last-compacted so we can figure out how long ago the last compaction
// was. But since this shouldn't fail the compaction process itself, we
// only log the error. Worst case if this file cannot be written is that
// we compact on every startup.
err = updateLastCompactionDate(sourceFilePath)
if err != nil {
log.Warnf("Could not update last compaction timestamp in "+
"%s%s: %v", sourceFilePath,
LastCompactionFileNameSuffix, err)
}
log.Infof("Swapping old DB file from %v to %v", tempDestFilePath,
sourceFilePath)
// Finally, we'll attempt to atomically rename the temporary file to
// the main DB file. If this succeeds, then we'll only have a
// single file on disk once this method exits.
return os.Rename(tempDestFilePath, sourceFilePath)
}
// lastCompactionDate returns the date the given database file was last
// compacted or a zero time.Time if no compaction was recorded before. The
// compaction date is read from a file in the same directory and with the same
// name as the DB file, but with the suffix ".last-compacted".
func lastCompactionDate(dbFile string) (time.Time, error) {
zeroTime := time.Unix(0, 0)
tsFile := fmt.Sprintf("%s%s", dbFile, LastCompactionFileNameSuffix)
if !fileExists(tsFile) {
return zeroTime, nil
}
tsBytes, err := os.ReadFile(tsFile)
if err != nil {
return zeroTime, err
}
tsNano := byteOrder.Uint64(tsBytes)
return time.Unix(0, int64(tsNano)), nil
}
// updateLastCompactionDate stores the current time as a timestamp in a file
// in the same directory and with the same name as the DB file, but with the
// suffix ".last-compacted".
func updateLastCompactionDate(dbFile string) error {
var tsBytes [8]byte
byteOrder.PutUint64(tsBytes[:], uint64(time.Now().UnixNano()))
tsFile := fmt.Sprintf("%s%s", dbFile, LastCompactionFileNameSuffix)
return os.WriteFile(tsFile, tsBytes[:], 0600)
}
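// exampleCompactionTimestampSketch is an illustrative sketch (not part of
// the original file) of the timestamp round trip: after recording a
// compaction, lastCompactionDate reads back the stored wall-clock time
// from its 8-byte big-endian nanosecond encoding.
func exampleCompactionTimestampSketch(dbFile string) (time.Duration, error) {
	if err := updateLastCompactionDate(dbFile); err != nil {
		return 0, err
	}

	last, err := lastCompactionDate(dbFile)
	if err != nil {
		return 0, err
	}

	// The age computed here is what compactAndSwap compares against
	// AutoCompactMinAge.
	return time.Since(last), nil
}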
// GetTestBackend opens (or creates if it doesn't exist) a bbolt or etcd
// backed database (for testing), and returns a kvdb.Backend and a cleanup
// func. Whether to create/open a bbolt or an embedded etcd database is
// based on the TestBackend constant which is conditionally compiled with a
// build tag.
// The passed path is used to hold all db files, while the name is only used
// for bbolt.
func GetTestBackend(path, name string) (Backend, func(), error) {
empty := func() {}
// Note that for tests, we expect only one db backend build flag
// (or none) to be set at a time, and thus only one of the
// following switch cases should ever be true.
switch {
case PostgresBackend:
key := filepath.Join(path, name)
keyHash := sha256.Sum256([]byte(key))
f, err := NewPostgresFixture("test_" + hex.EncodeToString(
keyHash[:]),
)
if err != nil {
return nil, func() {}, err
}
return f.DB(), func() {
_ = f.DB().Close()
}, nil
case EtcdBackend:
etcdConfig, cancel, err := StartEtcdTestBackend(path, 0, 0, "")
if err != nil {
return nil, empty, err
}
backend, err := Open(
EtcdBackendName, context.TODO(), etcdConfig,
)
return backend, cancel, err
case SqliteBackend:
dbPath := filepath.Join(path, name)
keyHash := sha256.Sum256([]byte(dbPath))
sqliteDb, err := StartSqliteTestBackend(
path, name, "test_"+hex.EncodeToString(keyHash[:]),
)
if err != nil {
return nil, empty, err
}
return sqliteDb, func() {
_ = sqliteDb.Close()
}, nil
default:
db, err := GetBoltBackend(&BoltBackendConfig{
DBPath: path,
DBFileName: name,
NoFreelistSync: true,
DBTimeout: DefaultDBTimeout,
ReadOnly: false,
})
if err != nil {
return nil, nil, err
}
return db, empty, nil
}
}
// The code in this file is an adapted version of the bbolt compact command
// implemented in this file:
// https://github.com/etcd-io/bbolt/blob/master/cmd/bbolt/main.go
//go:build !js
// +build !js
package kvdb
import (
"encoding/hex"
"fmt"
"os"
"path"
"time"
"github.com/lightningnetwork/lnd/healthcheck"
"go.etcd.io/bbolt"
)
const (
// defaultResultFileSizeMultiplier is the default multiplier we apply to
// the current database size to calculate how big it could possibly get
// after compacting, in case the database is already at its optimal size
// and compaction causes it to grow. This should normally not be the
// case but we really want to avoid not having enough disk space for the
// compaction, so we apply a safety margin of 10%.
defaultResultFileSizeMultiplier = float64(1.1)
// defaultTxMaxSize is the default maximum amount of key/value data, in
// bytes, that is copied within a single transaction before committing
// and starting a new one.
// bucketFillSize is the fill size setting that is used for each new
// bucket that is created in the compacted database. This setting is not
// persisted and is therefore only effective for the compaction itself.
// Because during the compaction we only append data, a fill percent of
// 100% is optimal for performance.
bucketFillSize = 1.0
)
type compacter struct {
srcPath string
dstPath string
txMaxSize int64
// dbTimeout specifies the timeout value used when opening the db.
dbTimeout time.Duration
}
// execute opens the source and destination databases and then compacts the
// source into destination and returns the size of both files as a result.
func (cmd *compacter) execute() (int64, int64, error) {
if cmd.txMaxSize == 0 {
cmd.txMaxSize = defaultTxMaxSize
}
// Ensure source file exists.
fi, err := os.Stat(cmd.srcPath)
if err != nil {
return 0, 0, fmt.Errorf("error determining source database "+
"size: %v", err)
}
initialSize := fi.Size()
marginSize := float64(initialSize) * defaultResultFileSizeMultiplier
// Before opening any of the databases, let's first make sure we have
// enough free space on the destination file system to create a full
// copy of the source DB (worst-case scenario if the compaction doesn't
// actually shrink the file size).
destFolder := path.Dir(cmd.dstPath)
freeSpace, err := healthcheck.AvailableDiskSpace(destFolder)
if err != nil {
return 0, 0, fmt.Errorf("error determining free disk space on "+
"%s: %v", destFolder, err)
}
log.Debugf("Free disk space on compaction destination file system: "+
"%d bytes", freeSpace)
if freeSpace < uint64(marginSize) {
return 0, 0, fmt.Errorf("could not start compaction, "+
"destination folder %s only has %d bytes of free disk "+
"space available while we need at least %d for worst-"+
"case compaction", destFolder, freeSpace, uint64(marginSize))
}
// Open source database. We open it in read only mode to avoid (and fix)
// possible freelist sync problems.
src, err := bbolt.Open(cmd.srcPath, 0444, &bbolt.Options{
ReadOnly: true,
Timeout: cmd.dbTimeout,
})
if err != nil {
return 0, 0, fmt.Errorf("error opening source database: %w",
err)
}
defer func() {
if err := src.Close(); err != nil {
log.Errorf("Compact error: closing source DB: %v", err)
}
}()
// Open destination database.
dst, err := bbolt.Open(cmd.dstPath, fi.Mode(), &bbolt.Options{
Timeout: cmd.dbTimeout,
})
if err != nil {
return 0, 0, fmt.Errorf("error opening destination database: "+
"%w", err)
}
defer func() {
if err := dst.Close(); err != nil {
log.Errorf("Compact error: closing dest DB: %v", err)
}
}()
// Run compaction.
if err := cmd.compact(dst, src); err != nil {
return 0, 0, fmt.Errorf("error running compaction: %w", err)
}
// Report stats on new size.
fi, err = os.Stat(cmd.dstPath)
if err != nil {
return 0, 0, fmt.Errorf("error determining destination "+
"database size: %w", err)
} else if fi.Size() == 0 {
return 0, 0, fmt.Errorf("zero db size")
}
return initialSize, fi.Size(), nil
}
// compact tries to create a compacted copy of the source database in a new
// destination database.
func (cmd *compacter) compact(dst, src *bbolt.DB) error {
// Commit regularly, or we'll run out of memory for large datasets if
// using one transaction.
var size int64
tx, err := dst.Begin(true)
if err != nil {
return err
}
defer func() {
_ = tx.Rollback()
}()
if err := cmd.walk(src, func(keys [][]byte, k, v []byte, seq uint64) error {
// On each key/value, check if we have exceeded tx size.
sz := int64(len(k) + len(v))
if size+sz > cmd.txMaxSize && cmd.txMaxSize != 0 {
// Commit previous transaction.
if err := tx.Commit(); err != nil {
return err
}
// Start new transaction.
tx, err = dst.Begin(true)
if err != nil {
return err
}
size = 0
}
size += sz
// Create bucket on the root transaction if this is the first
// level.
nk := len(keys)
if nk == 0 {
bkt, err := tx.CreateBucket(k)
if err != nil {
return err
}
if err := bkt.SetSequence(seq); err != nil {
return err
}
return nil
}
// Create buckets on subsequent levels, if necessary.
b := tx.Bucket(keys[0])
if nk > 1 {
for _, k := range keys[1:] {
b = b.Bucket(k)
}
}
// Fill the entire page for best compaction.
b.FillPercent = bucketFillSize
// If there is no value then this is a bucket call.
if v == nil {
bkt, err := b.CreateBucket(k)
if err != nil {
return err
}
if err := bkt.SetSequence(seq); err != nil {
return err
}
return nil
}
// Otherwise treat it as a key/value pair.
return b.Put(k, v)
}); err != nil {
return err
}
return tx.Commit()
}
// walkFunc is the type of the function called for keys (buckets and "normal"
// values) discovered by walk. keys is the list of keys used to descend to the
// bucket owning the discovered key/value pair k/v.
type walkFunc func(keys [][]byte, k, v []byte, seq uint64) error
// walk walks recursively the bolt database db, calling walkFn for each key it
// finds.
func (cmd *compacter) walk(db *bbolt.DB, walkFn walkFunc) error {
return db.View(func(tx *bbolt.Tx) error {
return tx.ForEach(func(name []byte, b *bbolt.Bucket) error {
// This will log the top level buckets only to give the
// user some sense of progress.
log.Debugf("Compacting top level bucket '%s'",
LoggableKeyName(name))
return cmd.walkBucket(
b, nil, name, nil, b.Sequence(), walkFn,
)
})
})
}
// LoggableKeyName returns a printable name of the given key.
func LoggableKeyName(key []byte) string {
strKey := string(key)
if hasSpecialChars(strKey) {
return hex.EncodeToString(key)
}
return strKey
}
// hasSpecialChars returns true if any of the characters in the given string
// cannot be printed.
func hasSpecialChars(s string) bool {
for _, b := range s {
if !(b >= 'a' && b <= 'z') && !(b >= 'A' && b <= 'Z') &&
!(b >= '0' && b <= '9') && b != '-' && b != '_' {
return true
}
}
return false
}
// walkBucket recursively walks through a bucket.
func (cmd *compacter) walkBucket(b *bbolt.Bucket, keyPath [][]byte, k, v []byte,
seq uint64, fn walkFunc) error {
// Execute callback.
if err := fn(keyPath, k, v, seq); err != nil {
return err
}
// If this is not a bucket then stop.
if v != nil {
return nil
}
// Iterate over each child key/value.
keyPath = append(keyPath, k)
return b.ForEach(func(k, v []byte) error {
if v == nil {
bkt := b.Bucket(k)
return cmd.walkBucket(
bkt, keyPath, k, nil, bkt.Sequence(), fn,
)
}
return cmd.walkBucket(b, keyPath, k, v, b.Sequence(), fn)
})
}
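// The helper below is an illustrative sketch (hypothetical, not part of the
// package API) that shows how the compacter above might be driven end to
// end; the 60 second timeout is an assumption.
func compactDatabase(srcPath, dstPath string) error {
	c := &compacter{
		srcPath:   srcPath,
		dstPath:   dstPath,
		txMaxSize: defaultTxMaxSize,
		dbTimeout: 60 * time.Second,
	}

	// execute verifies free disk space, copies all buckets with a 100%
	// fill percent and reports the before/after file sizes.
	initialSize, newSize, err := c.execute()
	if err != nil {
		return err
	}

	log.Debugf("Compacted database from %d to %d bytes", initialSize,
		newSize)

	return nil
}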
package kvdb
import (
"path/filepath"
"testing"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/stretchr/testify/require"
)
type boltFixture struct {
t *testing.T
tempDir string
}
func NewBoltFixture(t *testing.T) *boltFixture {
return &boltFixture{
t: t,
tempDir: t.TempDir(),
}
}
func (b *boltFixture) NewBackend() walletdb.DB {
dbPath := filepath.Join(b.tempDir)
db, err := GetBoltBackend(&BoltBackendConfig{
DBPath: dbPath,
DBFileName: "test.db",
NoFreelistSync: true,
DBTimeout: DefaultDBTimeout,
ReadOnly: false,
})
require.NoError(b.t, err)
return db
}
package etcd
import "fmt"
// Config holds etcd configuration alongside configuration related to our higher level interface.
//
//nolint:ll
type Config struct {
Embedded bool `long:"embedded" description:"Use embedded etcd instance instead of the external one. Note: use for testing only."`
EmbeddedClientPort uint16 `long:"embedded_client_port" description:"Client port to use for the embedded instance. Note: use for testing only."`
EmbeddedPeerPort uint16 `long:"embedded_peer_port" description:"Peer port to use for the embedded instance. Note: use for testing only."`
EmbeddedLogFile string `long:"embedded_log_file" description:"Optional log file to use for embedded instance logs. Note: use for testing only."`
Host string `long:"host" description:"Etcd database host. Supports multiple hosts separated by a comma."`
User string `long:"user" description:"Etcd database user."`
Pass string `long:"pass" description:"Password for the database user."`
Namespace string `long:"namespace" description:"The etcd namespace to use."`
DisableTLS bool `long:"disabletls" description:"Disable TLS for etcd connection. Caution: use for development only."`
CertFile string `long:"cert_file" description:"Path to the TLS certificate for etcd RPC."`
KeyFile string `long:"key_file" description:"Path to the TLS private key for etcd RPC."`
InsecureSkipVerify bool `long:"insecure_skip_verify" description:"Whether we intend to skip TLS verification"`
CollectStats bool `long:"collect_stats" description:"Whether to collect etcd commit stats."`
MaxMsgSize int `long:"max_msg_size" description:"The maximum message size in bytes that we may send to etcd."`
// SingleWriter should be set to true if we intend to only allow a
// single writer to the database at a time.
SingleWriter bool
}
// CloneWithSubNamespace clones the current configuration and returns a new
// instance with the given sub namespace applied by appending it to the main
// namespace.
func (c *Config) CloneWithSubNamespace(subNamespace string) *Config {
ns := c.Namespace
if len(ns) == 0 {
ns = subNamespace
} else {
ns = fmt.Sprintf("%s/%s", ns, subNamespace)
}
return &Config{
Embedded: c.Embedded,
EmbeddedClientPort: c.EmbeddedClientPort,
EmbeddedPeerPort: c.EmbeddedPeerPort,
Host: c.Host,
User: c.User,
Pass: c.Pass,
Namespace: ns,
DisableTLS: c.DisableTLS,
CertFile: c.CertFile,
KeyFile: c.KeyFile,
InsecureSkipVerify: c.InsecureSkipVerify,
CollectStats: c.CollectStats,
MaxMsgSize: c.MaxMsgSize,
SingleWriter: c.SingleWriter,
}
}
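// An illustrative sketch (with hypothetical values) of how namespaces
// compose: cloning with a sub namespace appends it to the main one with a
// "/".
func exampleCloneWithSubNamespace() string {
	cfg := &Config{Namespace: "lnd"}
	sub := cfg.CloneWithSubNamespace("invoices")

	// Returns "lnd/invoices".
	return sub.Namespace
}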
// CloneWithSingleWriter clones the current configuration and returns a new
// instance with the single writer property set to true.
func (c *Config) CloneWithSingleWriter() *Config {
return &Config{
Embedded: c.Embedded,
EmbeddedClientPort: c.EmbeddedClientPort,
EmbeddedPeerPort: c.EmbeddedPeerPort,
Host: c.Host,
User: c.User,
Pass: c.Pass,
Namespace: c.Namespace,
DisableTLS: c.DisableTLS,
CertFile: c.CertFile,
KeyFile: c.KeyFile,
InsecureSkipVerify: c.InsecureSkipVerify,
CollectStats: c.CollectStats,
MaxMsgSize: c.MaxMsgSize,
SingleWriter: true,
}
}
package kvdb
import (
"github.com/btcsuite/btcwallet/walletdb"
)
// Update opens a database read/write transaction and executes the function f
// with the transaction passed as a parameter. After f exits, if f did not
// error, the transaction is committed. Otherwise, if f did error, the
// transaction is rolled back. If the rollback fails, the original error
// returned by f is still returned. If the commit fails, the commit error is
// returned. As callers may expect retries of the f closure (depending on the
// database backend used), the reset function will be called before each retry
// respectively.
func Update(db Backend, f func(tx RwTx) error, reset func()) error {
return db.Update(f, reset)
}
// View opens a database read transaction and executes the function f with the
// transaction passed as a parameter. After f exits, the transaction is rolled
// back. If f errors, its error is returned, not a rollback error (if any
// occur). The passed reset function is called before the start of the
// transaction and can be used to reset intermediate state. As callers may
// expect retries of the f closure (depending on the database backend used), the
// reset function will be called before each retry respectively.
func View(db Backend, f func(tx RTx) error, reset func()) error {
return db.View(f, reset)
}
// Batch is identical to the Update call, but it attempts to combine several
// individual Update transactions into a single write database transaction on
// an optimistic basis. This only has benefits if multiple goroutines call
// Batch. For etcd Batch simply does an Update since combination is more complex
// in that case due to STM retries.
func Batch(db Backend, f func(tx RwTx) error) error {
// Fall back to the normal Update method if the backend doesn't support
// batching.
if _, ok := db.(walletdb.BatchDB); !ok {
// Since Batch calls handle external state reset, we can safely
// pass in an empty reset closure.
return db.Update(f, func() {})
}
return walletdb.Batch(db, f)
}
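// The sketch below illustrates the typical Update/View pattern against a
// Backend; the bucket and key names are hypothetical.
func examplePutAndGet(db Backend) ([]byte, error) {
	// Write under a read/write transaction. The final closure is the
	// reset function that is invoked before each retry.
	err := Update(db, func(tx RwTx) error {
		bucket, err := tx.CreateTopLevelBucket([]byte("example-bucket"))
		if err != nil {
			return err
		}

		return bucket.Put([]byte("key"), []byte("value"))
	}, func() {})
	if err != nil {
		return nil, err
	}

	// Read the value back under a read-only transaction, resetting any
	// intermediate state before each (re)try.
	var value []byte
	err = View(db, func(tx RTx) error {
		bucket := tx.ReadBucket([]byte("example-bucket"))
		if bucket == nil {
			return ErrBucketNotFound
		}

		value = bucket.Get([]byte("key"))
		return nil
	}, func() {
		value = nil
	})

	return value, err
}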
// Create initializes and opens a database for the specified type. The
// arguments are specific to the database type driver. See the documentation
// for the database driver for further details.
//
// ErrDbUnknownType will be returned if the database type is not registered.
var Create = walletdb.Create
// Backend represents an ACID database. All database access is performed
// through read or read+write transactions.
type Backend = walletdb.DB
// Open opens an existing database for the specified type. The arguments are
// specific to the database type driver. See the documentation for the database
// driver for further details.
//
// ErrDbUnknownType will be returned if the database type is not registered.
var Open = walletdb.Open
// Driver defines a structure for backend drivers to use when registering
// themselves as a backend that implements the Backend interface.
type Driver = walletdb.Driver
// RBucket represents a bucket (a hierarchical structure within the
// database) that is only allowed to perform read operations.
type RBucket = walletdb.ReadBucket
// RCursor represents a bucket cursor that can be positioned at the start or
// end of the bucket's key/value pairs and iterate over pairs in the bucket.
// This type is only allowed to perform database read operations.
type RCursor = walletdb.ReadCursor
// RTx represents a database transaction that can only be used for reads. If
// a database update must occur, use a RwTx.
type RTx = walletdb.ReadTx
// RwBucket represents a bucket (a hierarchical structure within the database)
// that is allowed to perform both read and write operations.
type RwBucket = walletdb.ReadWriteBucket
// RwCursor represents a bucket cursor that can be positioned at the start or
// end of the bucket's key/value pairs and iterate over pairs in the bucket.
// This abstraction is allowed to perform both database read and write
// operations.
type RwCursor = walletdb.ReadWriteCursor
// RwTx represents a database transaction that can be used for both reads and
// writes. When only reads are necessary, consider using a RTx instead.
type RwTx = walletdb.ReadWriteTx
// ExtendedRTx is an extension to walletdb.ReadTx to allow prefetching of keys.
type ExtendedRTx interface {
RTx
// RootBucket returns the "root bucket" which is a pseudo bucket used
// when prefetching (keys from) top level buckets.
RootBucket() RBucket
}
// ExtendedRBucket is an extension to walletdb.ReadBucket to allow prefetching
// of all values inside buckets.
type ExtendedRBucket interface {
RBucket
// Prefetch will attempt to prefetch all values under a path.
Prefetch(paths ...[]string)
// ForAll is an optimized version of ForEach.
//
// NOTE: ForAll differs from ForEach in that no additional queries can
// be executed within the callback.
ForAll(func(k, v []byte) error) error
}
// Prefetch will attempt to prefetch all values under a path from the passed
// bucket.
func Prefetch(b RBucket, paths ...[]string) {
if bucket, ok := b.(ExtendedRBucket); ok {
bucket.Prefetch(paths...)
}
}
// ForAll is an optimized version of ForEach with the limitation that no
// additional queries can be executed within the callback.
func ForAll(b RBucket, cb func(k, v []byte) error) error {
if bucket, ok := b.(ExtendedRBucket); ok {
return bucket.ForAll(cb)
}
return b.ForEach(cb)
}
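// An illustrative sketch of using ForAll: extended backends take the
// optimized path while others degrade gracefully to a plain ForEach.
func exampleSumValueSizes(b RBucket) (int, error) {
	var total int
	err := ForAll(b, func(k, v []byte) error {
		// NOTE: no additional queries may be executed within this
		// callback on extended backends.
		total += len(v)
		return nil
	})

	return total, err
}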
// RootBucket is a wrapper to ExtendedRTx.RootBucket which does nothing if
// the implementation doesn't have ExtendedRTx.
func RootBucket(t RTx) RBucket {
if tx, ok := t.(ExtendedRTx); ok {
return tx.RootBucket()
}
return nil
}
var (
// ErrBucketNotFound is returned when trying to access a bucket that
// has not been created yet.
ErrBucketNotFound = walletdb.ErrBucketNotFound
// ErrBucketExists is returned when creating a bucket that already
// exists.
ErrBucketExists = walletdb.ErrBucketExists
// ErrDatabaseNotOpen is returned when a database instance is accessed
// before it is opened or after it is closed.
ErrDatabaseNotOpen = walletdb.ErrDbNotOpen
// ErrDbDoesNotExist is returned when a database instance is opened
// but it does not exist.
ErrDbDoesNotExist = walletdb.ErrDbDoesNotExist
)
//go:build !kvdb_etcd
// +build !kvdb_etcd
package kvdb
import (
"fmt"
"github.com/lightningnetwork/lnd/kvdb/etcd"
)
// EtcdBackend is conditionally set to false when the kvdb_etcd build tag is not
// defined. This will allow testing of other database backends.
const EtcdBackend = false
var errEtcdNotAvailable = fmt.Errorf("etcd backend not available")
// StartEtcdTestBackend is a stub that returns a nil config and the
// errEtcdNotAvailable error.
func StartEtcdTestBackend(path string, clientPort, peerPort uint16,
logFile string) (*etcd.Config, func(), error) {
return nil, func() {}, errEtcdNotAvailable
}
//go:build !kvdb_postgres
// +build !kvdb_postgres
package kvdb
import (
"errors"
"github.com/lightningnetwork/lnd/kvdb/postgres"
)
// PostgresBackend is conditionally set to false when the kvdb_postgres build
// tag is not defined. This will allow testing of other database backends.
const PostgresBackend = false

// NewPostgresFixture is a stub that returns a nil fixture and an error when
// the postgres backend is not available.
func NewPostgresFixture(dbName string) (postgres.Fixture, error) {
	return nil, errors.New("postgres backend not available")
}

// StartEmbeddedPostgres is a stub that returns a nil stop closure and an
// error when the postgres backend is not available.
func StartEmbeddedPostgres() (func() error, error) {
	return nil, errors.New("postgres backend not available")
}
//go:build !kvdb_sqlite || (windows && (arm || 386)) || (linux && (ppc64 || mips || mipsle || mips64))
package kvdb
import (
"fmt"
"runtime"
"github.com/btcsuite/btcwallet/walletdb"
)
var errSqliteNotAvailable = fmt.Errorf("sqlite backend not available either "+
"due to the `kvdb_sqlite` build tag not being set, or due to this "+
"OS(%s) and/or architecture(%s) not being supported", runtime.GOOS,
runtime.GOARCH)
// SqliteBackend is conditionally set to false when the kvdb_sqlite build tag is
// not defined. This will allow testing of other database backends.
const SqliteBackend = false
// StartSqliteTestBackend is a stub that returns a nil DB and the
// errSqliteNotAvailable error.
func StartSqliteTestBackend(path, name, table string) (walletdb.DB, error) {
return nil, errSqliteNotAvailable
}
package kvdb
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/kvdb/sqlbase"
)
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
sqlbase.UseLogger(log)
}
package sqlbase
import (
"database/sql"
"fmt"
"sync"
_ "github.com/jackc/pgx/v4/stdlib"
)
// dbConn stores the actual connection and a user count.
type dbConn struct {
db *sql.DB
count int
}
// dbConnSet stores a set of connections.
type dbConnSet struct {
dbConn map[string]*dbConn
maxConnections int
// mu is used to guard access to the dbConn map.
mu sync.Mutex
}
// newDbConnSet initializes a new set of connections.
func newDbConnSet(maxConnections int) *dbConnSet {
return &dbConnSet{
dbConn: make(map[string]*dbConn),
maxConnections: maxConnections,
}
}
// Open opens a new database connection. If a connection already exists for the
// given dsn, the existing connection is returned.
func (d *dbConnSet) Open(driver, dsn string) (*sql.DB, error) {
d.mu.Lock()
defer d.mu.Unlock()
if dbConn, ok := d.dbConn[dsn]; ok {
dbConn.count++
return dbConn.db, nil
}
db, err := sql.Open(driver, dsn)
if err != nil {
return nil, err
}
// Limit maximum number of open connections. This is useful to prevent
// the server from running out of connections and returning an error.
// With this client-side limit in place, lnd will wait for a connection
// to become available.
if d.maxConnections != 0 {
db.SetMaxOpenConns(d.maxConnections)
}
d.dbConn[dsn] = &dbConn{
db: db,
count: 1,
}
return db, nil
}
// Close closes the connection with the given dsn. If there are still other
// users of the same connection, this function does nothing.
func (d *dbConnSet) Close(dsn string) error {
d.mu.Lock()
defer d.mu.Unlock()
dbConn, ok := d.dbConn[dsn]
if !ok {
return fmt.Errorf("connection not found: %v", dsn)
}
// Reduce user count.
dbConn.count--
// Do not close if there are other users.
if dbConn.count > 0 {
return nil
}
// Close connection.
delete(d.dbConn, dsn)
return dbConn.db.Close()
}
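// The sketch below illustrates the reference counting behavior of the
// connection set; the DSN is a hypothetical placeholder.
func exampleConnSet() error {
	set := newDbConnSet(10)
	dsn := "postgres://user:pass@localhost:5432/lnd"

	// Two Opens with the same DSN share a single *sql.DB handle.
	db1, err := set.Open("pgx", dsn)
	if err != nil {
		return err
	}
	db2, err := set.Open("pgx", dsn)
	if err != nil {
		return err
	}
	if db1 != db2 {
		return fmt.Errorf("expected shared connection")
	}

	// The first Close only decrements the user count...
	if err := set.Close(dsn); err != nil {
		return err
	}

	// ...while the second one tears down the underlying connection.
	return set.Close(dsn)
}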
package sqlbase
import "github.com/btcsuite/btclog/v2"
// log is a logger that is initialized as disabled. This means the package will
// not perform any logging by default until a logger is set.
var log = btclog.Disabled
// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
log = logger
}
//go:build !kvdb_postgres && (!kvdb_sqlite || (windows && (arm || 386)) || (linux && (ppc64 || mips || mipsle || mips64)))
package sqlbase
func Init(maxConnections int) {}
package kvdb
// KV holds a key/value pair as plain strings.
type KV struct {
	key string
	val string
}

// reverseKVs reverses the order of the passed KV slice in place and returns
// it.
func reverseKVs(a []KV) []KV {
	for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 {
		a[i], a[j] = a[j], a[i]
	}
	return a
}
package kvdb
import (
"fmt"
"os"
"testing"
)
// RunTests is a helper function to run the tests in a package with
// initialization and tear-down of a test kvdb backend.
func RunTests(m *testing.M) {
var close func() error
if PostgresBackend {
var err error
close, err = StartEmbeddedPostgres()
if err != nil {
fmt.Printf("Error: %v\n", err)
os.Exit(1)
}
}
// os.Exit() does not respect defer statements
code := m.Run()
if close != nil {
err := close()
if err != nil {
fmt.Printf("Error: %v\n", err)
}
}
os.Exit(code)
}
// Package labels contains labels used to label transactions broadcast by lnd.
// These labels are used across packages, so they are declared in a separate
// package to avoid dependency issues.
//
// Labels for transactions broadcast by lnd have two set fields followed by an
// optional set of labelled data values, all separated by colons.
// - Label version: an integer that indicates the version lnd used
// - Label type: the type of transaction we are labelling
// - {field name}-{value}: a named field followed by its value, these items are
// optional, and there may be more than one field present.
//
// For version 0 we have the following optional data fields defined:
// - shortchanid: the short channel ID that a transaction is associated with,
// with its value set to the uint64 short channel id.
package labels
import (
"fmt"
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/lightningnetwork/lnd/lnwire"
)
// External labels a transaction as user initiated via the api. This
// label is only used when a custom user provided label is not given.
const External = "external"
// ValidateAPI returns the generic api label if the label provided is empty.
// This allows us to label all transactions published by the api, even if
// no label is provided. If a label is provided, it is validated against
// the known restrictions.
func ValidateAPI(label string) (string, error) {
if len(label) > wtxmgr.TxLabelLimit {
return "", fmt.Errorf("label length: %v exceeds "+
"limit of %v", len(label), wtxmgr.TxLabelLimit)
}
// If no label was provided by the user, add the generic user
// send label.
if len(label) == 0 {
return External, nil
}
return label, nil
}
// LabelVersion versions our labels so they can easily be updated to contain
// new data while still being easily string matched.
type LabelVersion uint8
// LabelVersionZero is the label version for labels that contain label type and
// channel ID (where available).
const LabelVersionZero LabelVersion = iota
// LabelType indicates the type of label we are creating. It is a string rather
// than an int for easy string matching and human-readability.
type LabelType string
const (
// LabelTypeChannelOpen is used to label channel opens.
LabelTypeChannelOpen LabelType = "openchannel"
// LabelTypeChannelClose is used to label channel closes.
LabelTypeChannelClose LabelType = "closechannel"
// LabelTypeJusticeTransaction is used to label justice transactions.
LabelTypeJusticeTransaction LabelType = "justicetx"
// LabelTypeSweepTransaction is used to label sweeps.
LabelTypeSweepTransaction LabelType = "sweep"
)
// LabelField is used to tag a value within a label.
type LabelField string
const (
// ShortChanID is used to tag short channel id values in our labels.
ShortChanID LabelField = "shortchanid"
)
// MakeLabel creates a label with the provided type and short channel id. If
// our short channel ID is not known, we simply return version:label_type. If
// we do have a short channel ID set, the label will also contain its value:
// shortchanid-{uint64 chan ID}.
func MakeLabel(labelType LabelType, channelID *lnwire.ShortChannelID) string {
if channelID == nil {
return fmt.Sprintf("%v:%v", LabelVersionZero, labelType)
}
return fmt.Sprintf("%v:%v:%v-%v", LabelVersionZero, labelType,
ShortChanID, channelID.ToUint64())
}
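// An illustrative sketch of the two label shapes MakeLabel produces; the
// short channel id value is hypothetical.
func exampleMakeLabel() (string, string) {
	// Without a short channel ID we get "0:sweep".
	noChan := MakeLabel(LabelTypeSweepTransaction, nil)

	// With one we get "0:openchannel:shortchanid-123456".
	scid := lnwire.NewShortChanIDFromInt(123456)
	withChan := MakeLabel(LabelTypeChannelOpen, &scid)

	return noChan, withChan
}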
package lnmock
import (
"github.com/btcsuite/btcd/btcjson"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/chain"
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/stretchr/testify/mock"
)
// MockChain is a mock implementation of the Chain interface.
type MockChain struct {
mock.Mock
}
// Compile-time constraint to ensure MockChain implements the Chain interface.
var _ chain.Interface = (*MockChain)(nil)
func (m *MockChain) Start() error {
args := m.Called()
return args.Error(0)
}
func (m *MockChain) Stop() {
m.Called()
}
func (m *MockChain) WaitForShutdown() {
m.Called()
}
func (m *MockChain) GetBestBlock() (*chainhash.Hash, int32, error) {
args := m.Called()
if args.Get(0) == nil {
return nil, args.Get(1).(int32), args.Error(2)
}
return args.Get(0).(*chainhash.Hash), args.Get(1).(int32), args.Error(2)
}
func (m *MockChain) GetBlock(hash *chainhash.Hash) (*wire.MsgBlock, error) {
args := m.Called(hash)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*wire.MsgBlock), args.Error(1)
}
func (m *MockChain) GetBlockHash(height int64) (*chainhash.Hash, error) {
args := m.Called(height)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*chainhash.Hash), args.Error(1)
}
func (m *MockChain) GetBlockHeader(hash *chainhash.Hash) (
*wire.BlockHeader, error) {
args := m.Called(hash)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*wire.BlockHeader), args.Error(1)
}
func (m *MockChain) GetUtxo(op *wire.OutPoint, pkScript []byte,
heightHint uint32, cancel <-chan struct{}) (*wire.TxOut, error) {
args := m.Called(op, pkScript, heightHint, cancel)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*wire.TxOut), args.Error(1)
}
func (m *MockChain) IsCurrent() bool {
args := m.Called()
return args.Bool(0)
}
func (m *MockChain) FilterBlocks(req *chain.FilterBlocksRequest) (
*chain.FilterBlocksResponse, error) {
args := m.Called(req)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*chain.FilterBlocksResponse), args.Error(1)
}
func (m *MockChain) BlockStamp() (*waddrmgr.BlockStamp, error) {
args := m.Called()
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*waddrmgr.BlockStamp), args.Error(1)
}
func (m *MockChain) SendRawTransaction(tx *wire.MsgTx, allowHighFees bool) (
*chainhash.Hash, error) {
args := m.Called(tx, allowHighFees)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*chainhash.Hash), args.Error(1)
}
func (m *MockChain) Rescan(startHash *chainhash.Hash, addrs []btcutil.Address,
outPoints map[wire.OutPoint]btcutil.Address) error {
args := m.Called(startHash, addrs, outPoints)
return args.Error(0)
}
func (m *MockChain) NotifyReceived(addrs []btcutil.Address) error {
args := m.Called(addrs)
return args.Error(0)
}
func (m *MockChain) NotifyBlocks() error {
args := m.Called()
return args.Error(0)
}
func (m *MockChain) Notifications() <-chan interface{} {
args := m.Called()
return args.Get(0).(<-chan interface{})
}
func (m *MockChain) BackEnd() string {
args := m.Called()
return args.String(0)
}
func (m *MockChain) TestMempoolAccept(txns []*wire.MsgTx, maxFeeRate float64) (
[]*btcjson.TestMempoolAcceptResult, error) {
args := m.Called(txns, maxFeeRate)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]*btcjson.TestMempoolAcceptResult), args.Error(1)
}
func (m *MockChain) MapRPCErr(err error) error {
args := m.Called(err)
return args.Error(0)
}
// NOTE: forcetypeassert is skipped for the mock because the test would fail if
// the returned value doesn't match the type.
package lnmock
import (
"time"
"github.com/lightningnetwork/lnd/clock"
"github.com/stretchr/testify/mock"
)
// MockClock implements the `clock.Clock` interface.
type MockClock struct {
mock.Mock
}
// Compile time assertion that MockClock implements clock.Clock.
var _ clock.Clock = (*MockClock)(nil)
func (m *MockClock) Now() time.Time {
args := m.Called()
return args.Get(0).(time.Time)
}
func (m *MockClock) TickAfter(d time.Duration) <-chan time.Time {
args := m.Called(d)
return args.Get(0).(chan time.Time)
}
package lnmock
import (
"bytes"
"github.com/lightningnetwork/lnd/lnwire"
)
// MockOnion returns a mock onion payload.
func MockOnion() [lnwire.OnionPacketSize]byte {
var onion [lnwire.OnionPacketSize]byte
onionBlob := bytes.Repeat([]byte{1}, lnwire.OnionPacketSize)
copy(onion[:], onionBlob)
return onion
}
package lnpeer
import (
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/mock"
)
// MockPeer implements the `lnpeer.Peer` interface.
type MockPeer struct {
mock.Mock
}
// Compile time assertion that MockPeer implements lnpeer.Peer.
var _ Peer = (*MockPeer)(nil)
func (m *MockPeer) SendMessage(sync bool, msgs ...lnwire.Message) error {
args := m.Called(sync, msgs)
return args.Error(0)
}
func (m *MockPeer) SendMessageLazy(sync bool, msgs ...lnwire.Message) error {
args := m.Called(sync, msgs)
return args.Error(0)
}
func (m *MockPeer) AddNewChannel(channel *NewChannel,
cancel <-chan struct{}) error {
args := m.Called(channel, cancel)
return args.Error(0)
}
func (m *MockPeer) AddPendingChannel(cid lnwire.ChannelID,
cancel <-chan struct{}) error {
args := m.Called(cid, cancel)
return args.Error(0)
}
func (m *MockPeer) RemovePendingChannel(cid lnwire.ChannelID) error {
args := m.Called(cid)
return args.Error(0)
}
func (m *MockPeer) WipeChannel(op *wire.OutPoint) {
m.Called(op)
}
func (m *MockPeer) PubKey() [33]byte {
args := m.Called()
return args.Get(0).([33]byte)
}
func (m *MockPeer) IdentityKey() *btcec.PublicKey {
args := m.Called()
return args.Get(0).(*btcec.PublicKey)
}
func (m *MockPeer) Address() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockPeer) QuitSignal() <-chan struct{} {
args := m.Called()
return args.Get(0).(<-chan struct{})
}
func (m *MockPeer) LocalFeatures() *lnwire.FeatureVector {
args := m.Called()
return args.Get(0).(*lnwire.FeatureVector)
}
func (m *MockPeer) RemoteFeatures() *lnwire.FeatureVector {
args := m.Called()
return args.Get(0).(*lnwire.FeatureVector)
}
func (m *MockPeer) Disconnect(err error) {}
package mock
import (
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
)
// ChainIO is a mock implementation of the BlockChainIO interface.
type ChainIO struct {
BestHeight int32
}
// GetBestBlock currently returns dummy values.
func (c *ChainIO) GetBestBlock() (*chainhash.Hash, int32, error) {
return chaincfg.TestNet3Params.GenesisHash, c.BestHeight, nil
}
// GetUtxo currently returns dummy values.
func (c *ChainIO) GetUtxo(op *wire.OutPoint, _ []byte,
heightHint uint32, _ <-chan struct{}) (*wire.TxOut, error) {
return nil, nil
}
// GetBlockHash currently returns dummy values.
func (c *ChainIO) GetBlockHash(blockHeight int64) (*chainhash.Hash, error) {
return nil, nil
}
// GetBlock currently returns dummy values.
func (c *ChainIO) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) {
return nil, nil
}
// GetBlockHeader currently returns dummy values.
func (c *ChainIO) GetBlockHeader(blockHash *chainhash.Hash) (*wire.BlockHeader,
error) {
return nil, nil
}
package mock
import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
)
// ChainNotifier is a mock implementation of the ChainNotifier interface.
type ChainNotifier struct {
SpendChan chan *chainntnfs.SpendDetail
EpochChan chan *chainntnfs.BlockEpoch
ConfChan chan *chainntnfs.TxConfirmation
}
// RegisterConfirmationsNtfn returns a ConfirmationEvent that contains a channel
// that the tx confirmation will go over.
func (c *ChainNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash,
pkScript []byte, numConfs, heightHint uint32,
opts ...chainntnfs.NotifierOption) (*chainntnfs.ConfirmationEvent, error) {
return &chainntnfs.ConfirmationEvent{
Confirmed: c.ConfChan,
Cancel: func() {},
}, nil
}
// RegisterSpendNtfn returns a SpendEvent that contains a channel that the spend
// details will go over.
func (c *ChainNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
pkScript []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) {
return &chainntnfs.SpendEvent{
Spend: c.SpendChan,
Cancel: func() {},
}, nil
}
// RegisterBlockEpochNtfn returns a BlockEpochEvent that contains a channel that
// block epochs will go over.
func (c *ChainNotifier) RegisterBlockEpochNtfn(blockEpoch *chainntnfs.BlockEpoch) (
*chainntnfs.BlockEpochEvent, error) {
return &chainntnfs.BlockEpochEvent{
Epochs: c.EpochChan,
Cancel: func() {},
}, nil
}
// Start currently returns a dummy value.
func (c *ChainNotifier) Start() error {
return nil
}
// Started currently returns a dummy value.
func (c *ChainNotifier) Started() bool {
return true
}
// Stop currently returns a dummy value.
func (c *ChainNotifier) Stop() error {
return nil
}
package mock
import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/lightningnetwork/lnd/keychain"
)
// SecretKeyRing is a mock implementation of the SecretKeyRing interface.
type SecretKeyRing struct {
RootKey *btcec.PrivateKey
}
// DeriveNextKey currently returns dummy values.
func (s *SecretKeyRing) DeriveNextKey(
_ keychain.KeyFamily) (keychain.KeyDescriptor, error) {
return keychain.KeyDescriptor{
PubKey: s.RootKey.PubKey(),
}, nil
}
// DeriveKey currently returns dummy values.
func (s *SecretKeyRing) DeriveKey(
_ keychain.KeyLocator) (keychain.KeyDescriptor, error) {
return keychain.KeyDescriptor{
PubKey: s.RootKey.PubKey(),
}, nil
}
// DerivePrivKey currently returns dummy values.
func (s *SecretKeyRing) DerivePrivKey(
_ keychain.KeyDescriptor) (*btcec.PrivateKey, error) {
return s.RootKey, nil
}
// ECDH currently returns dummy values.
func (s *SecretKeyRing) ECDH(_ keychain.KeyDescriptor,
_ *btcec.PublicKey) ([32]byte, error) {
return [32]byte{}, nil
}
// SignMessage signs the passed message and ignores the KeyDescriptor.
func (s *SecretKeyRing) SignMessage(_ keychain.KeyLocator,
msg []byte, doubleHash bool) (*ecdsa.Signature, error) {
var digest []byte
if doubleHash {
digest = chainhash.DoubleHashB(msg)
} else {
digest = chainhash.HashB(msg)
}
return ecdsa.Sign(s.RootKey, digest), nil
}
// SignMessageCompact signs the passed message.
func (s *SecretKeyRing) SignMessageCompact(_ keychain.KeyLocator,
msg []byte, doubleHash bool) ([]byte, error) {
var digest []byte
if doubleHash {
digest = chainhash.DoubleHashB(msg)
} else {
digest = chainhash.HashB(msg)
}
return ecdsa.SignCompact(s.RootKey, digest, true), nil
}
// SignMessageSchnorr signs the passed message and ignores the KeyDescriptor.
func (s *SecretKeyRing) SignMessageSchnorr(_ keychain.KeyLocator,
msg []byte, doubleHash bool, taprootTweak []byte,
tag []byte) (*schnorr.Signature, error) {
var digest []byte
switch {
case len(tag) > 0:
taggedHash := chainhash.TaggedHash(tag, msg)
digest = taggedHash[:]
case doubleHash:
digest = chainhash.DoubleHashB(msg)
default:
digest = chainhash.HashB(msg)
}
privKey := s.RootKey
if len(taprootTweak) > 0 {
privKey = txscript.TweakTaprootPrivKey(*privKey, taprootTweak)
}
return schnorr.Sign(privKey, digest)
}
package mock
import (
"crypto/sha256"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
)
var (
idKeyLoc = keychain.KeyLocator{Family: keychain.KeyFamilyNodeKey}
)
// DummySignature is a dummy Signature implementation.
type DummySignature struct{}
// Serialize returns an empty byte slice.
func (d *DummySignature) Serialize() []byte {
return []byte{}
}
// Verify always returns true.
func (d *DummySignature) Verify(_ []byte, _ *btcec.PublicKey) bool {
return true
}
// DummySigner is an implementation of the Signer interface that returns
// dummy values when called.
type DummySigner struct{}
// SignOutputRaw returns a dummy signature.
func (d *DummySigner) SignOutputRaw(tx *wire.MsgTx,
signDesc *input.SignDescriptor) (input.Signature, error) {
return &DummySignature{}, nil
}
// ComputeInputScript returns nil for both values.
func (d *DummySigner) ComputeInputScript(tx *wire.MsgTx,
signDesc *input.SignDescriptor) (*input.Script, error) {
return &input.Script{}, nil
}
// MuSig2CreateSession creates a new MuSig2 signing session using the local
// key identified by the key locator. The complete list of all public keys of
// all signing parties must be provided, including the public key of the local
// signing key. If nonces of other parties are already known, they can be
// submitted as well to reduce the number of method calls necessary later on.
func (d *DummySigner) MuSig2CreateSession(input.MuSig2Version,
keychain.KeyLocator, []*btcec.PublicKey, *input.MuSig2Tweaks,
[][musig2.PubNonceSize]byte, *musig2.Nonces,
) (*input.MuSig2SessionInfo, error) {
return nil, nil
}
// MuSig2RegisterNonces registers one or more public nonces of other signing
// participants for a session identified by its ID. This method returns true
// once we have all nonces for all other signing participants.
func (d *DummySigner) MuSig2RegisterNonces(input.MuSig2SessionID,
[][musig2.PubNonceSize]byte) (bool, error) {
return false, nil
}
// MuSig2Sign creates a partial signature using the local signing key
// that was specified when the session was created. This can only be
// called when all public nonces of all participants are known and have
// been registered with the session. If this node isn't responsible for
// combining all the partial signatures, then the cleanup parameter
// should be set, indicating that the session can be removed from memory
// once the signature was produced.
func (d *DummySigner) MuSig2Sign(input.MuSig2SessionID,
[sha256.Size]byte, bool) (*musig2.PartialSignature, error) {
return nil, nil
}
// MuSig2CombineSig combines the given partial signature(s) with the
// local one, if it already exists. Once a partial signature of all
// participants is registered, the final signature will be combined and
// returned.
func (d *DummySigner) MuSig2CombineSig(input.MuSig2SessionID,
[]*musig2.PartialSignature) (*schnorr.Signature, bool, error) {
return nil, false, nil
}
// MuSig2Cleanup removes a session from memory to free up resources.
func (d *DummySigner) MuSig2Cleanup(input.MuSig2SessionID) error {
return nil
}
// SingleSigner is an implementation of the Signer interface that signs
// everything with a single private key.
type SingleSigner struct {
Privkey *btcec.PrivateKey
KeyLoc keychain.KeyLocator
*input.MusigSessionManager
}
// NewSingleSigner creates a new SingleSigner backed by the given private
// key.
func NewSingleSigner(privkey *btcec.PrivateKey) *SingleSigner {
signer := &SingleSigner{
Privkey: privkey,
KeyLoc: idKeyLoc,
}
keyFetcher := func(*keychain.KeyDescriptor) (*btcec.PrivateKey, error) {
return signer.Privkey, nil
}
signer.MusigSessionManager = input.NewMusigSessionManager(keyFetcher)
return signer
}
// SignOutputRaw generates a signature for the passed transaction using the
// stored private key.
func (s *SingleSigner) SignOutputRaw(tx *wire.MsgTx,
signDesc *input.SignDescriptor) (input.Signature, error) {
amt := signDesc.Output.Value
witnessScript := signDesc.WitnessScript
privKey := s.Privkey
if !privKey.PubKey().IsEqual(signDesc.KeyDesc.PubKey) {
return nil, fmt.Errorf("incorrect key passed")
}
switch {
case signDesc.SingleTweak != nil:
privKey = input.TweakPrivKey(privKey,
signDesc.SingleTweak)
case signDesc.DoubleTweak != nil:
privKey = input.DeriveRevocationPrivKey(privKey,
signDesc.DoubleTweak)
}
sig, err := txscript.RawTxInWitnessSignature(tx, signDesc.SigHashes,
signDesc.InputIndex, amt, witnessScript, signDesc.HashType,
privKey)
if err != nil {
return nil, err
}
return ecdsa.ParseDERSignature(sig[:len(sig)-1])
}
// ComputeInputScript computes an input script with the stored private key
// given a transaction and a SignDescriptor.
func (s *SingleSigner) ComputeInputScript(tx *wire.MsgTx,
signDesc *input.SignDescriptor) (*input.Script, error) {
privKey := s.Privkey
switch {
case signDesc.SingleTweak != nil:
privKey = input.TweakPrivKey(privKey,
signDesc.SingleTweak)
case signDesc.DoubleTweak != nil:
privKey = input.DeriveRevocationPrivKey(privKey,
signDesc.DoubleTweak)
}
witnessScript, err := txscript.WitnessSignature(tx, signDesc.SigHashes,
signDesc.InputIndex, signDesc.Output.Value, signDesc.Output.PkScript,
signDesc.HashType, privKey, true)
if err != nil {
return nil, err
}
return &input.Script{
Witness: witnessScript,
}, nil
}
// SignMessage takes a public key and a message and only signs the message
// with the stored private key if the public key matches the private key.
func (s *SingleSigner) SignMessage(keyLoc keychain.KeyLocator,
msg []byte, doubleHash bool) (*ecdsa.Signature, error) {
mockKeyLoc := s.KeyLoc
if s.KeyLoc.IsEmpty() {
mockKeyLoc = idKeyLoc
}
if keyLoc != mockKeyLoc {
return nil, fmt.Errorf("unknown public key")
}
var digest []byte
if doubleHash {
digest = chainhash.DoubleHashB(msg)
} else {
digest = chainhash.HashB(msg)
}
return ecdsa.Sign(s.Privkey, digest), nil
}
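// The sketch below shows one way a test might exercise SingleSigner:
// signing a message and verifying it against the matching public key.
func exampleSingleSigner() (bool, error) {
	privKey, err := btcec.NewPrivateKey()
	if err != nil {
		return false, err
	}
	signer := NewSingleSigner(privKey)

	// With doubleHash set, the message is double-SHA256'd before
	// signing, so we verify against the same digest.
	msg := []byte("test message")
	sig, err := signer.SignMessage(idKeyLoc, msg, true)
	if err != nil {
		return false, err
	}

	return sig.Verify(chainhash.DoubleHashB(msg), privKey.PubKey()), nil
}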
package mock
import (
"sync"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
)
// SpendNotifier extends the mock.ChainNotifier so that spend
// notifications can be triggered and delivered to subscribers.
type SpendNotifier struct {
*ChainNotifier
spendMap map[wire.OutPoint][]chan *chainntnfs.SpendDetail
spends map[wire.OutPoint]*chainntnfs.SpendDetail
mtx sync.Mutex
}
// MakeMockSpendNotifier creates a SpendNotifier.
func MakeMockSpendNotifier() *SpendNotifier {
return &SpendNotifier{
ChainNotifier: &ChainNotifier{
SpendChan: make(chan *chainntnfs.SpendDetail),
EpochChan: make(chan *chainntnfs.BlockEpoch),
ConfChan: make(chan *chainntnfs.TxConfirmation),
},
spendMap: make(map[wire.OutPoint][]chan *chainntnfs.SpendDetail),
spends: make(map[wire.OutPoint]*chainntnfs.SpendDetail),
}
}
// RegisterSpendNtfn registers a spend notification for a specified outpoint.
func (s *SpendNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
_ []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) {
s.mtx.Lock()
defer s.mtx.Unlock()
spendChan := make(chan *chainntnfs.SpendDetail, 1)
if detail, ok := s.spends[*outpoint]; ok {
// Deliver spend immediately if details are already known.
spendChan <- &chainntnfs.SpendDetail{
SpentOutPoint: detail.SpentOutPoint,
SpendingHeight: detail.SpendingHeight,
SpendingTx: detail.SpendingTx,
SpenderTxHash: detail.SpenderTxHash,
SpenderInputIndex: detail.SpenderInputIndex,
}
} else {
// Otherwise, queue the notification for delivery if the spend
// is ever received.
s.spendMap[*outpoint] = append(s.spendMap[*outpoint], spendChan)
}
return &chainntnfs.SpendEvent{
Spend: spendChan,
Cancel: func() {},
}, nil
}
// Spend dispatches SpendDetails to all subscribers of the outpoint. The
// details will include the transaction and height provided by the caller.
func (s *SpendNotifier) Spend(outpoint *wire.OutPoint, height int32,
txn *wire.MsgTx) {
s.mtx.Lock()
defer s.mtx.Unlock()
var inputIndex uint32
for i, in := range txn.TxIn {
if in.PreviousOutPoint == *outpoint {
inputIndex = uint32(i)
}
}
txnHash := txn.TxHash()
details := &chainntnfs.SpendDetail{
SpentOutPoint: outpoint,
SpendingHeight: height,
SpendingTx: txn,
SpenderTxHash: &txnHash,
SpenderInputIndex: inputIndex,
}
// Cache details in case of late registration.
if _, ok := s.spends[*outpoint]; !ok {
s.spends[*outpoint] = details
}
// Deliver any backlogged spend notifications.
if spendChans, ok := s.spendMap[*outpoint]; ok {
delete(s.spendMap, *outpoint)
for _, spendChan := range spendChans {
spendChan <- &chainntnfs.SpendDetail{
SpentOutPoint: details.SpentOutPoint,
SpendingHeight: details.SpendingHeight,
SpendingTx: details.SpendingTx,
SpenderTxHash: details.SpenderTxHash,
SpenderInputIndex: details.SpenderInputIndex,
}
}
}
}
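// An illustrative test sketch: register for a spend and then trigger it.
// The outpoint index and height are hypothetical values.
func exampleSpendNotifier() (*chainntnfs.SpendDetail, error) {
	notifier := MakeMockSpendNotifier()
	op := wire.OutPoint{Index: 1}

	event, err := notifier.RegisterSpendNtfn(&op, nil, 0)
	if err != nil {
		return nil, err
	}

	// Build a transaction spending the outpoint and dispatch it.
	tx := wire.NewMsgTx(2)
	tx.AddTxIn(&wire.TxIn{PreviousOutPoint: op})
	notifier.Spend(&op, 100, tx)

	// The registered channel is buffered, so the detail is ready for
	// consumption immediately.
	return <-event.Spend, nil
}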
package mock
import (
"encoding/hex"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/hdkeychain"
"github.com/btcsuite/btcd/btcutil/psbt"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
base "github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/wallet/txauthor"
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
var (
CoinPkScript, _ = hex.DecodeString("001431df1bde03c074d0cf21ea2529427e1499b8f1de")
)
// WalletController is a mock implementation of the WalletController
// interface. It lets us mock the interaction with the bitcoin network.
type WalletController struct {
RootKey *btcec.PrivateKey
PublishedTransactions chan *wire.MsgTx
index uint32
Utxos []*lnwallet.Utxo
}
// A compile time check to ensure this mocked WalletController implements the
// WalletController interface.
var _ lnwallet.WalletController = (*WalletController)(nil)
// BackEnd returns "mock" to signify a mock wallet controller.
func (w *WalletController) BackEnd() string {
return "mock"
}
// FetchOutpointInfo will be called to get info about the inputs to the funding
// transaction.
func (w *WalletController) FetchOutpointInfo(
prevOut *wire.OutPoint) (*lnwallet.Utxo, error) {
utxo := &lnwallet.Utxo{
AddressType: lnwallet.WitnessPubKey,
Value: 10 * btcutil.SatoshiPerBitcoin,
PkScript: []byte("dummy"),
Confirmations: 1,
OutPoint: *prevOut,
}
return utxo, nil
}
// ScriptForOutput returns the address, witness program and redeem script for a
// given UTXO. An error is returned if the UTXO does not belong to our wallet or
// it is not a managed pubKey address.
func (w *WalletController) ScriptForOutput(*wire.TxOut) (
waddrmgr.ManagedPubKeyAddress, []byte, []byte, error) {
return nil, nil, nil, nil
}
// ConfirmedBalance currently returns dummy values.
func (w *WalletController) ConfirmedBalance(int32, string) (btcutil.Amount,
error) {
return 0, nil
}
// NewAddress is called to get new addresses for delivery, change etc.
func (w *WalletController) NewAddress(lnwallet.AddressType, bool,
string) (btcutil.Address, error) {
pkh := btcutil.Hash160(w.RootKey.PubKey().SerializeCompressed())
addr, _ := btcutil.NewAddressPubKeyHash(pkh, &chaincfg.MainNetParams)
return addr, nil
}
// LastUnusedAddress currently returns dummy values.
func (w *WalletController) LastUnusedAddress(lnwallet.AddressType,
string) (btcutil.Address, error) {
return nil, nil
}
// IsOurAddress currently returns a dummy value.
func (w *WalletController) IsOurAddress(btcutil.Address) bool {
return false
}
// AddressInfo currently returns a dummy value.
func (w *WalletController) AddressInfo(
btcutil.Address) (waddrmgr.ManagedAddress, error) {
return nil, nil
}
// ListAccounts currently returns a dummy value.
func (w *WalletController) ListAccounts(string,
*waddrmgr.KeyScope) ([]*waddrmgr.AccountProperties, error) {
return nil, nil
}
// RequiredReserve currently returns a dummy value.
func (w *WalletController) RequiredReserve(uint32) btcutil.Amount {
return 0
}
// ListAddresses currently returns a dummy value.
func (w *WalletController) ListAddresses(string,
bool) (lnwallet.AccountAddressMap, error) {
return nil, nil
}
// ImportAccount currently returns a dummy value.
func (w *WalletController) ImportAccount(string, *hdkeychain.ExtendedKey,
uint32, *waddrmgr.AddressType, bool) (*waddrmgr.AccountProperties,
[]btcutil.Address, []btcutil.Address, error) {
return nil, nil, nil, nil
}
// ImportPublicKey currently returns a dummy value.
func (w *WalletController) ImportPublicKey(*btcec.PublicKey,
waddrmgr.AddressType) error {
return nil
}
// ImportTaprootScript currently returns a dummy value.
func (w *WalletController) ImportTaprootScript(waddrmgr.KeyScope,
*waddrmgr.Tapscript) (waddrmgr.ManagedAddress, error) {
return nil, nil
}
// SendOutputs currently returns dummy values.
func (w *WalletController) SendOutputs(fn.Set[wire.OutPoint], []*wire.TxOut,
chainfee.SatPerKWeight, int32, string, base.CoinSelectionStrategy) (
*wire.MsgTx, error) {
return nil, nil
}
// CreateSimpleTx currently returns dummy values.
func (w *WalletController) CreateSimpleTx(fn.Set[wire.OutPoint], []*wire.TxOut,
chainfee.SatPerKWeight, int32, base.CoinSelectionStrategy,
bool) (*txauthor.AuthoredTx, error) {
return nil, nil
}
// ListUnspentWitness is called by the wallet when doing coin selection. We just
// need one unspent for the funding transaction.
func (w *WalletController) ListUnspentWitness(int32, int32,
string) ([]*lnwallet.Utxo, error) {
// If the mock already has a list of utxos, return it.
if w.Utxos != nil {
return w.Utxos, nil
}
// Otherwise create one to return.
utxo := &lnwallet.Utxo{
AddressType: lnwallet.WitnessPubKey,
Value: btcutil.Amount(10 * btcutil.SatoshiPerBitcoin),
PkScript: CoinPkScript,
OutPoint: wire.OutPoint{
Hash: chainhash.Hash{},
Index: w.index,
},
}
atomic.AddUint32(&w.index, 1)
var ret []*lnwallet.Utxo
ret = append(ret, utxo)
return ret, nil
}
// ListTransactionDetails currently returns dummy values.
func (w *WalletController) ListTransactionDetails(int32, int32,
string, uint32, uint32) ([]*lnwallet.TransactionDetail,
uint64, uint64, error) {
return nil, 0, 0, nil
}
// LeaseOutput returns the current time and a nil error.
func (w *WalletController) LeaseOutput(wtxmgr.LockID, wire.OutPoint,
time.Duration) (time.Time, error) {
return time.Now(), nil
}
// ReleaseOutput currently does nothing.
func (w *WalletController) ReleaseOutput(wtxmgr.LockID, wire.OutPoint) error {
return nil
}
// ListLeasedOutputs currently returns dummy values.
func (w *WalletController) ListLeasedOutputs() ([]*base.ListLeasedOutputResult,
	error) {
	return nil, nil
}
// FundPsbt currently does nothing.
func (w *WalletController) FundPsbt(*psbt.Packet, int32, chainfee.SatPerKWeight,
string, *waddrmgr.KeyScope, base.CoinSelectionStrategy,
func(utxo wtxmgr.Credit) bool) (int32, error) {
return 0, nil
}
// SignPsbt currently does nothing.
func (w *WalletController) SignPsbt(*psbt.Packet) ([]uint32, error) {
return nil, nil
}
// FinalizePsbt currently does nothing.
func (w *WalletController) FinalizePsbt(_ *psbt.Packet, _ string) error {
return nil
}
// DecorateInputs currently does nothing.
func (w *WalletController) DecorateInputs(*psbt.Packet, bool) error {
return nil
}
// PublishTransaction sends a transaction to the PublishedTransactions chan.
func (w *WalletController) PublishTransaction(tx *wire.MsgTx, _ string) error {
w.PublishedTransactions <- tx
return nil
}
// GetTransactionDetails currently does nothing.
func (w *WalletController) GetTransactionDetails(
txHash *chainhash.Hash) (*lnwallet.TransactionDetail, error) {
return nil, nil
}
// LabelTransaction currently does nothing.
func (w *WalletController) LabelTransaction(chainhash.Hash, string,
bool) error {
return nil
}
// SubscribeTransactions currently does nothing.
func (w *WalletController) SubscribeTransactions() (lnwallet.TransactionSubscription,
error) {
return nil, nil
}
// IsSynced currently returns dummy values.
func (w *WalletController) IsSynced() (bool, int64, error) {
return true, int64(0), nil
}
// GetRecoveryInfo currently returns dummy values.
func (w *WalletController) GetRecoveryInfo() (bool, float64, error) {
return true, float64(1), nil
}
// Start currently does nothing.
func (w *WalletController) Start() error {
return nil
}
// Stop currently does nothing.
func (w *WalletController) Stop() error {
return nil
}
// FetchTx currently returns dummy values.
func (w *WalletController) FetchTx(chainhash.Hash) (*wire.MsgTx, error) {
	return nil, nil
}

// RemoveDescendants currently does nothing.
func (w *WalletController) RemoveDescendants(*wire.MsgTx) error {
	return nil
}

// CheckMempoolAcceptance currently does nothing.
func (w *WalletController) CheckMempoolAcceptance(tx *wire.MsgTx) error {
	return nil
}
// FetchDerivationInfo queries for the wallet's knowledge of the passed
// pkScript and constructs the derivation info and returns it.
func (w *WalletController) FetchDerivationInfo(
pkScript []byte) (*psbt.Bip32Derivation, error) {
return nil, nil
}
package wait
import (
"fmt"
"time"
)
// PollInterval is a constant specifying a 200 ms interval.
const PollInterval = 200 * time.Millisecond
// Predicate is a helper test function that will wait for a timeout period of
// time until the passed predicate returns true. This function is helpful as
// timing doesn't always line up well when running integration tests with
// several running lnd nodes. This function gives callers a way to assert that
// some property is upheld within a particular time frame.
//
// TODO(yy): build a counter here so we know how many times we've tried the
// `pred`.
func Predicate(pred func() bool, timeout time.Duration) error {
exitTimer := time.After(timeout)
result := make(chan bool, 1)
for {
<-time.After(PollInterval)
go func() {
result <- pred()
}()
// Each time we call pred(), we expect a result to be
// returned, otherwise it will time out.
select {
case <-exitTimer:
return fmt.Errorf("predicate not satisfied after " +
"time out")
case succeed := <-result:
if succeed {
return nil
}
}
}
}
// NoError is a wrapper around Predicate that waits for the passed method f to
// execute without error, and returns the last error encountered if this doesn't
// happen within the timeout.
func NoError(f func() error, timeout time.Duration) error {
var predErr error
pred := func() bool {
if err := f(); err != nil {
predErr = err
return false
}
return true
}
// If f() doesn't succeed within the timeout, return the last
// encountered error.
if err := Predicate(pred, timeout); err != nil {
// Handle the case where the passed in method, f, hangs for the
// full timeout.
if predErr == nil {
return fmt.Errorf("method did not return within the " +
"timeout")
}
return predErr
}
return nil
}
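// An illustrative sketch of NoError: polling an operation that only
// succeeds after a few attempts, with a hypothetical timeout.
func exampleWaitNoError() error {
	attempts := 0
	return NoError(func() error {
		attempts++
		if attempts < 3 {
			return fmt.Errorf("not ready yet")
		}
		return nil
	}, 5*time.Second)
}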
// Invariant is a helper test function that will wait for a timeout period of
// time, verifying that a statement remains true for the entire duration. This
// function is helpful as timing doesn't always line up well when running
// integration tests with several running lnd nodes. This function gives callers
// a way to assert that some property is maintained over a particular time
// frame.
func Invariant(statement func() bool, timeout time.Duration) error {
const pollInterval = 20 * time.Millisecond
exitTimer := time.After(timeout)
for {
<-time.After(pollInterval)
// Fail if the invariant is broken while polling.
if !statement() {
return fmt.Errorf("invariant broken before time out")
}
select {
case <-exitTimer:
return nil
default:
}
}
}
// InvariantNoError is a wrapper around Invariant that waits out the duration
// specified by timeout. It fails if the predicate ever returns an error during
// that time.
func InvariantNoError(f func() error, timeout time.Duration) error {
var predErr error
pred := func() bool {
if err := f(); err != nil {
predErr = err
return false
}
return true
}
if err := Invariant(pred, timeout); err != nil {
return predErr
}
return nil
}
package lntypes
import "fmt"
// ChannelParty is a type used to have an unambiguous description of which node
// is being referred to. This eliminates the need to describe a node as "local"
// or "remote" using a bool.
type ChannelParty uint8
const (
// Local is a ChannelParty constructor that is used to refer to the
// node that is running.
Local ChannelParty = iota
// Remote is a ChannelParty constructor that is used to refer to the
// node on the other end of the peer connection.
Remote
)
// String provides a string representation of ChannelParty (useful for logging).
func (p ChannelParty) String() string {
switch p {
case Local:
return "Local"
case Remote:
return "Remote"
default:
panic(fmt.Sprintf("invalid ChannelParty value: %d", p))
}
}
// CounterParty inverts the role of the ChannelParty.
func (p ChannelParty) CounterParty() ChannelParty {
switch p {
case Local:
return Remote
case Remote:
return Local
default:
panic(fmt.Sprintf("invalid ChannelParty value: %v", p))
}
}
// IsLocal returns true if the ChannelParty is Local.
func (p ChannelParty) IsLocal() bool {
return p == Local
}
// IsRemote returns true if the ChannelParty is Remote.
func (p ChannelParty) IsRemote() bool {
return p == Remote
}
// Dual represents a structure when we are tracking the same parameter for both
// the Local and Remote parties.
type Dual[A any] struct {
// Local is the value tracked for the Local ChannelParty.
Local A
// Remote is the value tracked for the Remote ChannelParty.
Remote A
}
// GetForParty gives Dual an access method that takes a ChannelParty as an
// argument. It is included for ergonomics in cases where the ChannelParty is
// held in a variable that determines which field of the Dual we want to
// access.
func (d *Dual[A]) GetForParty(p ChannelParty) A {
switch p {
case Local:
return d.Local
case Remote:
return d.Remote
default:
		panic(fmt.Sprintf(
			"switch default triggered in GetForParty: %v", p,
		))
}
}
// SetForParty sets the value in the Dual for the given ChannelParty,
// mutating the Dual in place.
func (d *Dual[A]) SetForParty(p ChannelParty, value A) {
switch p {
case Local:
d.Local = value
case Remote:
d.Remote = value
default:
		panic(fmt.Sprintf(
			"switch default triggered in SetForParty: %v", p,
		))
}
}
// ModifyForParty applies the function argument to the given ChannelParty's
// field in place and returns the field's updated value.
func (d *Dual[A]) ModifyForParty(p ChannelParty, f func(A) A) A {
switch p {
case Local:
d.Local = f(d.Local)
return d.Local
case Remote:
d.Remote = f(d.Remote)
return d.Remote
default:
		panic(fmt.Sprintf(
			"switch default triggered in ModifyForParty: %v", p,
		))
}
}
// MapDual applies the function argument to both the Local and Remote fields of
// the Dual[A] and returns a Dual[B] with that function applied.
func MapDual[A, B any](d Dual[A], f func(A) B) Dual[B] {
return Dual[B]{
Local: f(d.Local),
Remote: f(d.Remote),
}
}
// BothParties is a convenience slice containing both channel parties, Local
// and Remote.
var BothParties = []ChannelParty{Local, Remote}
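// An illustrative sketch of Dual combined with GetForParty and
// CounterParty; the balance values are hypothetical.
func exampleDual() uint64 {
	balances := Dual[uint64]{Local: 100, Remote: 50}

	// Read the value tracked for the other side of the channel,
	// returning 50.
	p := Local
	return balances.GetForParty(p.CounterParty())
}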
package lntypes
import (
"encoding/hex"
"fmt"
)
// HashSize of array used to store hashes.
const HashSize = 32
// ZeroHash is a predefined hash containing all zeroes.
var ZeroHash Hash
// Hash is used in several of the lightning messages and common structures. It
// typically represents a payment hash.
type Hash [HashSize]byte
// String returns the Hash as a hexadecimal string.
func (hash Hash) String() string {
return hex.EncodeToString(hash[:])
}
// MakeHash returns a new Hash from a byte slice. An error is returned if
// the number of bytes passed in is not HashSize.
func MakeHash(newHash []byte) (Hash, error) {
nhlen := len(newHash)
if nhlen != HashSize {
return Hash{}, fmt.Errorf("invalid hash length of %v, want %v",
nhlen, HashSize)
}
var hash Hash
copy(hash[:], newHash)
return hash, nil
}
// MakeHashFromStr creates a Hash from a hex hash string.
func MakeHashFromStr(newHash string) (Hash, error) {
// Return error if hash string is of incorrect length.
if len(newHash) != HashSize*2 {
return Hash{}, fmt.Errorf("invalid hash string length of %v, "+
"want %v", len(newHash), HashSize*2)
}
hash, err := hex.DecodeString(newHash)
if err != nil {
return Hash{}, err
}
return MakeHash(hash)
}
package lntypes
import (
"crypto/sha256"
"encoding/hex"
"fmt"
)
// PreimageSize of array used to store preimages.
const PreimageSize = 32
// Preimage is used in several of the lightning messages and common structures.
// It represents a payment preimage.
type Preimage [PreimageSize]byte
// String returns the Preimage as a hexadecimal string.
func (p Preimage) String() string {
return hex.EncodeToString(p[:])
}
// MakePreimage returns a new Preimage from a byte slice. An error is returned
// if the number of bytes passed in is not PreimageSize.
func MakePreimage(newPreimage []byte) (Preimage, error) {
nhlen := len(newPreimage)
if nhlen != PreimageSize {
return Preimage{}, fmt.Errorf("invalid preimage length of %v, "+
"want %v", nhlen, PreimageSize)
}
var preimage Preimage
copy(preimage[:], newPreimage)
return preimage, nil
}
// MakePreimageFromStr creates a Preimage from a hex preimage string.
func MakePreimageFromStr(newPreimage string) (Preimage, error) {
// Return error if preimage string is of incorrect length.
if len(newPreimage) != PreimageSize*2 {
return Preimage{}, fmt.Errorf("invalid preimage string length "+
"of %v, want %v", len(newPreimage), PreimageSize*2)
}
preimage, err := hex.DecodeString(newPreimage)
if err != nil {
return Preimage{}, err
}
return MakePreimage(preimage)
}
// Hash returns the sha256 hash of the preimage.
func (p *Preimage) Hash() Hash {
return Hash(sha256.Sum256(p[:]))
}
// Matches returns whether this preimage is the preimage of the given hash.
func (p *Preimage) Matches(h Hash) bool {
return h == p.Hash()
}
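// examplePreimageFlow is an illustrative sketch of the preimage/hash round
// trip used throughout the codebase: a payment hash is the SHA-256 of its
// preimage, and Matches verifies that relationship. The hex value is a
// hypothetical placeholder.
func examplePreimageFlow() error {
	preimage, err := MakePreimageFromStr(
		"0101010101010101010101010101010101010101010101010101010101010101",
	)
	if err != nil {
		return err
	}

	// Hash computes sha256(preimage); Matches checks the pair.
	hash := preimage.Hash()
	if !preimage.Matches(hash) {
		return fmt.Errorf("preimage does not match its own hash")
	}

	return nil
}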
package lntypes
import (
"fmt"
"math"
)
// WeightUnit defines a unit to express the transaction size. One weight unit
// is 1/4_000_000 of the max block size. The tx weight is calculated using
// `Base tx size * 3 + Total tx size`.
// - Base tx size is size of the transaction serialized without the witness
// data.
// - Total tx size is the transaction size in bytes serialized according to
// BIP144.
type WeightUnit uint64
// ToVB converts a value expressed in weight units to virtual bytes.
func (wu WeightUnit) ToVB() VByte {
// According to BIP141: Virtual transaction size is defined as
// Transaction weight / 4 (rounded up to the next integer).
return VByte(math.Ceil(float64(wu) / 4))
}
// String returns the string representation of the weight unit.
func (wu WeightUnit) String() string {
return fmt.Sprintf("%d wu", wu)
}
// VByte defines a unit to express the transaction size. One virtual byte is
// equivalent to four weight units. The tx virtual byte size is calculated
// using `TxWeight / 4`, rounded up to the next integer.
type VByte uint64
// ToWU converts a value expressed in virtual bytes to weight units.
func (vb VByte) ToWU() WeightUnit {
return WeightUnit(vb * 4)
}
// String returns the string representation of the virtual byte.
func (vb VByte) String() string {
return fmt.Sprintf("%d vb", vb)
}
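// exampleUnitConversion is an illustrative sketch of the WeightUnit/VByte
// conversions above, including the BIP141 round-up when the weight is not a
// multiple of four. The 561 WU figure is an assumed example value.
func exampleUnitConversion() {
	// 561 WU rounds up to 141 vbytes (561 / 4 = 140.25).
	wu := WeightUnit(561)
	fmt.Println(wu.ToVB()) // 141 vb

	// Converting vbytes back to weight units multiplies by four, so a
	// round trip may gain up to three weight units.
	fmt.Println(VByte(141).ToWU()) // 564 wu
}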
package lnutils
import (
"fmt"
"time"
)
// RecvOrTimeout attempts to recv over chan c, returning the value. If the
// timeout passes before the recv succeeds, an error is returned.
func RecvOrTimeout[T any](c <-chan T, timeout time.Duration) (*T, error) {
select {
case m := <-c:
return &m, nil
case <-time.After(timeout):
return nil, fmt.Errorf("timeout hit")
}
}
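// exampleRecvOrTimeout is an illustrative sketch of guarding a channel
// receive with a deadline using RecvOrTimeout; the channel and durations are
// hypothetical.
func exampleRecvOrTimeout() {
	results := make(chan int, 1)
	results <- 42

	// The buffered value is received well before the deadline.
	if val, err := RecvOrTimeout(results, time.Second); err == nil {
		fmt.Println(*val) // 42
	}

	// An empty channel hits the timeout and yields an error instead.
	if _, err := RecvOrTimeout(results, 10*time.Millisecond); err != nil {
		fmt.Println(err) // timeout hit
	}
}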
package lnutils
import "errors"
// ErrorAs behaves the same as `errors.As` except there's no need to declare
// the target error as a variable first.
// Instead of writing:
//
// var targetErr *TargetErr
// errors.As(err, &targetErr)
//
// We can write:
//
// lnutils.ErrorAs[*TargetErr](err)
//
// To save us from declaring the target error variable.
func ErrorAs[Target error](err error) bool {
var targetErr Target
return errors.As(err, &targetErr)
}
package lnutils
import (
"errors"
"fmt"
"os"
)
// CreateDir creates a directory if it doesn't exist and also handles
// symlink-related errors with user-friendly messages. It creates all necessary
// parent directories with the specified permissions.
func CreateDir(dir string, perm os.FileMode) error {
err := os.MkdirAll(dir, perm)
if err == nil {
return nil
}
// Show a nicer error message if it's because a symlink
// is linked to a directory that does not exist
// (probably because it's not mounted).
var pathErr *os.PathError
if errors.As(err, &pathErr) && os.IsExist(err) {
link, lerr := os.Readlink(pathErr.Path)
if lerr == nil {
return fmt.Errorf("is symlink %s -> %s "+
"mounted?", pathErr.Path, link)
}
}
return fmt.Errorf("failed to create directory '%s': %w", dir, err)
}
package lnutils
import (
"log/slog"
"strings"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btclog/v2"
"github.com/davecgh/go-spew/spew"
)
// LogClosure is used to provide a closure over expensive logging operations
// so that they don't have to be performed when the logging level doesn't
// warrant it.
type LogClosure func() string
// String invokes the underlying function and returns the result.
func (c LogClosure) String() string {
return c()
}
// NewLogClosure returns a new closure over a function that returns a string
// which itself provides a Stringer interface so that it can be used with the
// logging system.
func NewLogClosure(c func() string) LogClosure {
return LogClosure(c)
}
// SpewLogClosure takes any value and returns a LogClosure whose String method
// renders it with `spew.Sdump`.
func SpewLogClosure(a any) LogClosure {
return func() string {
return spew.Sdump(a)
}
}
// NewSeparatorClosure returns a new closure that logs a separator line.
func NewSeparatorClosure() LogClosure {
return func() string {
return strings.Repeat("=", 80)
}
}
// LogPubKey returns a slog attribute for logging a public key in hex format.
func LogPubKey(key string, pubKey *btcec.PublicKey) slog.Attr {
// Handle nil pubkey gracefully, although callers should ideally prevent
// this.
if pubKey == nil {
return btclog.Fmt(key, "<nil>")
}
return btclog.Hex6(key, pubKey.SerializeCompressed())
}
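// exampleDeferredLogging is an illustrative sketch of why the closures above
// exist: the spew dump is only computed if the logger's level actually emits
// the record. The logger parameter and state value are assumptions for
// demonstration.
func exampleDeferredLogging(log btclog.Logger, state any) {
	// spew.Sdump runs only when debug logging is enabled, since the
	// closure's String method is invoked lazily by the formatter.
	log.Debugf("Current state: %v", SpewLogClosure(state))

	// The 80-character separator line is likewise built lazily.
	log.Debugf("%v", NewSeparatorClosure())
}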
package lnutils
// Ptr returns the pointer of the given value. This is useful in instances
// where a function returns the value, but a pointer is wanted. Without this,
// an intermediate variable is needed.
func Ptr[T any](v T) *T {
return &v
}
// ByteArray is a type constraint for types that reduce down to a fixed-size
// array.
type ByteArray interface {
~[32]byte
}
// ByteSlice takes a byte array, and returns a slice. This is useful when a
// function returns an array, but a slice is wanted. Without this, an
// intermediate variable is needed.
func ByteSlice[T ByteArray](v T) []byte {
return v[:]
}
package lnutils
// Map takes an input slice, and applies the function f to each element,
// yielding a new slice.
func Map[T1, T2 any](s []T1, f func(T1) T2) []T2 {
r := make([]T2, len(s))
for i, v := range s {
r[i] = f(v)
}
return r
}
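// exampleMapUsage is an illustrative sketch: Map converts a slice of ints to
// their doubled values without mutating the input slice.
func exampleMapUsage() []int {
	return Map([]int{1, 2, 3}, func(i int) int {
		return i * 2
	}) // [2 4 6]
}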
package lnutils
import "sync"
// SyncMap wraps a sync.Map with type parameters such that it's easier to
// access the items stored in the map since no type assertion is needed. It
// also requires an explicit type definition when declaring and initializing
// the variables, which helps us understand what's stored in a given map.
type SyncMap[K comparable, V any] struct {
sync.Map
}
// Store puts an item in the map.
func (m *SyncMap[K, V]) Store(key K, value V) {
m.Map.Store(key, value)
}
// Load queries an item from the map using the specified key. If the item
// cannot be found, a zero value and false will be returned. If the stored
// item fails the type assertion, a zero value and false will be returned.
func (m *SyncMap[K, V]) Load(key K) (V, bool) {
result, ok := m.Map.Load(key)
if !ok {
return *new(V), false // nolint: gocritic
}
item, ok := result.(V)
return item, ok
}
// Delete removes an item from the map specified by the key.
func (m *SyncMap[K, V]) Delete(key K) {
m.Map.Delete(key)
}
// LoadAndDelete queries an item and deletes it from the map using the
// specified key.
func (m *SyncMap[K, V]) LoadAndDelete(key K) (V, bool) {
result, loaded := m.Map.LoadAndDelete(key)
if !loaded {
return *new(V), loaded // nolint: gocritic
}
item, ok := result.(V)
return item, ok
}
// Range iterates the map and applies the `visitor` function. If the `visitor`
// returns false, the iteration will be stopped.
func (m *SyncMap[K, V]) Range(visitor func(K, V) bool) {
m.Map.Range(func(k any, v any) bool {
return visitor(k.(K), v.(V))
})
}
// ForEach iterates the map and applies the `visitor` function. Unlike the
// `Range` method, the `visitor` function will be applied to all the items
// unless there's an error.
func (m *SyncMap[K, V]) ForEach(visitor func(K, V) error) {
// rangeVisitor wraps the `visitor` function and returns false if
// there's an error returned from the `visitor` function.
rangeVisitor := func(k K, v V) bool {
if err := visitor(k, v); err != nil {
// Break the iteration if there's an error.
return false
}
return true
}
m.Range(rangeVisitor)
}
// Len returns the number of items in the map.
func (m *SyncMap[K, V]) Len() int {
var count int
m.Range(func(_ K, _ V) bool {
count++
return true
})
return count
}
// LoadOrStore queries an item from the map using the specified key. If the
// item cannot be found, the `value` will be stored in the map and returned.
// If the stored item fails the type assertion, a zero value and false will
// be returned.
func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (V, bool) {
result, loaded := m.Map.LoadOrStore(key, value)
item, ok := result.(V)
if !ok {
return *new(V), false
}
return item, loaded
}
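// exampleSyncMapUsage is an illustrative sketch of typed access to a SyncMap
// without any type assertions at the call sites; the keys and values are
// hypothetical.
func exampleSyncMapUsage() int {
	var m SyncMap[string, int]

	m.Store("alice", 1)
	m.Store("bob", 2)

	// Load returns the zero value and false for a missing key.
	if _, ok := m.Load("carol"); !ok {
		m.Store("carol", 3)
	}

	// Range visits every entry until the visitor returns false.
	var sum int
	m.Range(func(_ string, v int) bool {
		sum += v
		return true
	})

	return sum // 6
}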
package lnwallet
import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
// CommitSortFunc is a function type alias for a function that sorts the
// commitment transaction outputs. The second parameter is a list of CLTV
// timeouts that must correspond to the number of transaction outputs, with the
// value of 0 for non-HTLC outputs. The HTLC indexes are needed to have a
// deterministic sort value for HTLCs that have the identical amount, CLTV
// timeout and payment hash (e.g. multiple MPP shards of the same payment, where
// the on-chain script would be identical).
type CommitSortFunc func(tx *wire.MsgTx, cltvs []uint32,
indexes []input.HtlcIndex) error
// DefaultCommitSort is the default commitment sort function that sorts the
// commitment transaction inputs and outputs according to BIP69. The second
// parameter is a list of CLTV timeouts that must correspond to the number of
// transaction outputs, with the value of 0 for non-HTLC outputs. The third
// parameter is unused for the default sort function.
func DefaultCommitSort(tx *wire.MsgTx, cltvs []uint32,
_ []input.HtlcIndex) error {
InPlaceCommitSort(tx, cltvs)
return nil
}
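// noopCommitSort is an illustrative sketch of the shape a custom
// CommitSortFunc can take for channel types whose output ordering is fixed
// elsewhere: it leaves the transaction untouched, which implicitly keeps the
// cltvs slice aligned with the existing output order. This is a hypothetical
// example, not part of the production sort logic.
func noopCommitSort(_ *wire.MsgTx, _ []uint32, _ []input.HtlcIndex) error {
	return nil
}

// A compile-time check that DefaultCommitSort satisfies CommitSortFunc.
var _ CommitSortFunc = DefaultCommitSort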
// CommitAuxLeaves stores two potential auxiliary leaves for the remote and
// local output that may be used to augment the final tapscript trees of the
// commitment transaction.
type CommitAuxLeaves struct {
// LocalAuxLeaf is the local party's auxiliary leaf.
LocalAuxLeaf input.AuxTapLeaf
// RemoteAuxLeaf is the remote party's auxiliary leaf.
RemoteAuxLeaf input.AuxTapLeaf
// OutgoingHtlcLeaves is the set of aux leaves for the outgoing HTLCs
// on this commitment transaction.
OutgoingHtlcLeaves input.HtlcAuxLeaves
// IncomingHtlcLeaves is the set of aux leaves for the incoming HTLCs
// on this commitment transaction.
IncomingHtlcLeaves input.HtlcAuxLeaves
}
// AuxChanState is a struct that holds certain fields of the
// channeldb.OpenChannel struct that are used by the aux components. The data
// is copied over to prevent accidental mutation of the original channel state.
type AuxChanState struct {
// ChanType denotes which type of channel this is.
ChanType channeldb.ChannelType
// FundingOutpoint is the outpoint of the final funding transaction.
// This value uniquely and globally identifies the channel within the
// target blockchain as specified by the chain hash parameter.
FundingOutpoint wire.OutPoint
// ShortChannelID encodes the exact location in the chain in which the
// channel was initially confirmed. This includes: the block height,
// transaction index, and the output within the target transaction.
//
// If IsZeroConf(), then this will be the "base" (very first) alias SCID
// and the confirmed SCID will be stored in ConfirmedScid.
ShortChannelID lnwire.ShortChannelID
// IsInitiator is a bool which indicates if we were the original
// initiator for the channel. This value may affect how higher levels
// negotiate fees, or close the channel.
IsInitiator bool
// Capacity is the total capacity of this channel.
Capacity btcutil.Amount
// LocalChanCfg is the channel configuration for the local node.
LocalChanCfg channeldb.ChannelConfig
// RemoteChanCfg is the channel configuration for the remote node.
RemoteChanCfg channeldb.ChannelConfig
// ThawHeight is the height when a frozen channel once again becomes a
// normal channel. If this is zero, then there are no restrictions on
// this channel. If the value is lower than 500,000, then it's
// interpreted as a relative height, or an absolute height otherwise.
ThawHeight uint32
// TapscriptRoot is an optional tapscript root used to derive the MuSig2
// funding output.
TapscriptRoot fn.Option[chainhash.Hash]
// CustomBlob is an optional blob that can be used to store information
// specific to a custom channel type. This information is only created
// at channel funding time, and afterwards is to be considered
// immutable.
CustomBlob fn.Option[tlv.Blob]
}
// NewAuxChanState creates a new AuxChanState from the given channel state.
func NewAuxChanState(chanState *channeldb.OpenChannel) AuxChanState {
return AuxChanState{
ChanType: chanState.ChanType,
FundingOutpoint: chanState.FundingOutpoint,
ShortChannelID: chanState.ShortChannelID,
IsInitiator: chanState.IsInitiator,
Capacity: chanState.Capacity,
LocalChanCfg: chanState.LocalChanCfg,
RemoteChanCfg: chanState.RemoteChanCfg,
ThawHeight: chanState.ThawHeight,
TapscriptRoot: chanState.TapscriptRoot,
CustomBlob: chanState.CustomBlob,
}
}
// CommitDiffAuxInput is the input required to compute the diff of the auxiliary
// leaves for a commitment transaction.
type CommitDiffAuxInput struct {
// ChannelState is the static channel information of the channel this
// commitment transaction relates to.
ChannelState AuxChanState
// PrevBlob is the blob of the previous commitment transaction.
PrevBlob tlv.Blob
// UnfilteredView is the unfiltered, original HTLC view of the channel.
// Unfiltered in this context means that the view contains all HTLCs,
// including the canceled ones.
UnfilteredView AuxHtlcView
// WhoseCommit denotes whose commitment transaction we are computing the
// diff for.
WhoseCommit lntypes.ChannelParty
// OurBalance is the balance of the local party.
OurBalance lnwire.MilliSatoshi
// TheirBalance is the balance of the remote party.
TheirBalance lnwire.MilliSatoshi
// KeyRing is the key ring that can be used to derive keys for the
// commitment transaction.
KeyRing CommitmentKeyRing
}
// CommitDiffAuxResult is the result of computing the diff of the auxiliary
// leaves for a commitment transaction.
type CommitDiffAuxResult struct {
// AuxLeaves are the auxiliary leaves for the new commitment
// transaction.
AuxLeaves fn.Option[CommitAuxLeaves]
// CommitSortFunc is an optional function that sorts the commitment
// transaction inputs and outputs.
CommitSortFunc fn.Option[CommitSortFunc]
}
// AuxLeafStore is used to optionally fetch auxiliary tapscript leaves for the
// commitment transaction given an opaque blob. This is also used to implement
// a state transition function for the blobs to allow them to be refreshed with
// each state.
type AuxLeafStore interface {
// FetchLeavesFromView attempts to fetch the auxiliary leaves that
// correspond to the passed aux blob, and pending original (unfiltered)
// HTLC view.
FetchLeavesFromView(
in CommitDiffAuxInput) fn.Result[CommitDiffAuxResult]
// FetchLeavesFromCommit attempts to fetch the auxiliary leaves that
// correspond to the passed aux blob, and an existing channel
// commitment.
FetchLeavesFromCommit(chanState AuxChanState,
commit channeldb.ChannelCommitment, keyRing CommitmentKeyRing,
whoseCommit lntypes.ChannelParty) fn.Result[CommitDiffAuxResult]
// FetchLeavesFromRevocation attempts to fetch the auxiliary leaves
// from a channel revocation that stores balance + blob information.
FetchLeavesFromRevocation(
r *channeldb.RevocationLog) fn.Result[CommitDiffAuxResult]
// ApplyHtlcView serves as the state transition function for the custom
// channel's blob. Given the old blob, and an HTLC view, then a new
// blob should be returned that reflects the pending updates.
ApplyHtlcView(in CommitDiffAuxInput) fn.Result[fn.Option[tlv.Blob]]
}
// auxLeavesFromView is used to derive the set of commit aux leaves (if any),
// that are needed to create a new commitment transaction using the original
// (unfiltered) htlc view.
func auxLeavesFromView(leafStore AuxLeafStore, chanState *channeldb.OpenChannel,
prevBlob fn.Option[tlv.Blob], originalView *HtlcView,
whoseCommit lntypes.ChannelParty, ourBalance,
theirBalance lnwire.MilliSatoshi,
keyRing CommitmentKeyRing) fn.Result[CommitDiffAuxResult] {
return fn.MapOptionZ(
prevBlob, func(blob tlv.Blob) fn.Result[CommitDiffAuxResult] {
return leafStore.FetchLeavesFromView(CommitDiffAuxInput{
ChannelState: NewAuxChanState(chanState),
PrevBlob: blob,
UnfilteredView: newAuxHtlcView(originalView),
WhoseCommit: whoseCommit,
OurBalance: ourBalance,
TheirBalance: theirBalance,
KeyRing: keyRing,
})
},
)
}
// updateAuxBlob is a helper function that attempts to update the aux blob
// given the prior and current state information.
func updateAuxBlob(leafStore AuxLeafStore, chanState *channeldb.OpenChannel,
prevBlob fn.Option[tlv.Blob], nextViewUnfiltered *HtlcView,
whoseCommit lntypes.ChannelParty, ourBalance,
theirBalance lnwire.MilliSatoshi,
keyRing CommitmentKeyRing) fn.Result[fn.Option[tlv.Blob]] {
return fn.MapOptionZ(
prevBlob, func(blob tlv.Blob) fn.Result[fn.Option[tlv.Blob]] {
return leafStore.ApplyHtlcView(CommitDiffAuxInput{
ChannelState: NewAuxChanState(chanState),
PrevBlob: blob,
UnfilteredView: newAuxHtlcView(
nextViewUnfiltered,
),
WhoseCommit: whoseCommit,
OurBalance: ourBalance,
TheirBalance: theirBalance,
KeyRing: keyRing,
})
},
)
}
package lnwallet
import (
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
// htlcCustomSigType is the TLV type that is used to encode the custom HTLC
// signatures within the custom data for an existing HTLC.
var htlcCustomSigType tlv.TlvType65543
// AuxHtlcView is a struct that contains a safe copy of an HTLC view that can
// be used by aux components.
type AuxHtlcView struct {
// NextHeight is the height of the commitment transaction that will be
// created using this view.
NextHeight uint64
// Updates is a Dual of the Local and Remote HTLCs.
Updates lntypes.Dual[[]AuxHtlcDescriptor]
// FeePerKw is the fee rate in sat/kw of the commitment transaction.
FeePerKw chainfee.SatPerKWeight
}
// newAuxHtlcView creates a new safe copy of the HTLC view that can be used by
// aux components.
//
// NOTE: This function should only be called while holding the channel's read
// lock, since the underlying local/remote payment descriptors are accessed
// directly.
func newAuxHtlcView(v *HtlcView) AuxHtlcView {
return AuxHtlcView{
NextHeight: v.NextHeight,
Updates: lntypes.Dual[[]AuxHtlcDescriptor]{
Local: fn.Map(v.Updates.Local, newAuxHtlcDescriptor),
Remote: fn.Map(v.Updates.Remote, newAuxHtlcDescriptor),
},
FeePerKw: v.FeePerKw,
}
}
// AuxHtlcDescriptor is a struct that contains the information needed to sign or
// verify an HTLC for custom channels.
type AuxHtlcDescriptor struct {
// ChanID is the ChannelID of the LightningChannel that this
// paymentDescriptor belongs to. We track this here so we can
// reconstruct the Messages that this paymentDescriptor is built from.
ChanID lnwire.ChannelID
// RHash is the payment hash for this HTLC. The HTLC can be settled iff
// the preimage to this hash is presented.
RHash PaymentHash
// Timeout is the absolute timeout in blocks, after which this HTLC
// expires.
Timeout uint32
// Amount is the HTLC amount in milli-satoshis.
Amount lnwire.MilliSatoshi
// HtlcIndex is the index within the main update log for this HTLC.
// Entries within the log of type Add will have this field populated,
// as other entries will point to the entry via this counter.
//
// NOTE: This field will only be populated if EntryType is Add.
HtlcIndex uint64
// ParentIndex is the HTLC index of the entry that this update settles
// or times out.
//
// NOTE: This field will only be populated if EntryType is Fail or
// Settle.
ParentIndex uint64
// EntryType denotes the exact type of the paymentDescriptor. In the
// case of a Timeout, or Settle type, then the Parent field will point
// into the log to the HTLC being modified.
EntryType updateType
// CustomRecords also stores the set of optional custom records that
// may have been attached to a sent HTLC.
CustomRecords lnwire.CustomRecords
// addCommitHeight[Remote|Local] encodes the height of the commitment
// which included this HTLC on either the remote or local commitment
// chain. This value is used to determine when an HTLC is fully
// "locked-in".
addCommitHeightRemote uint64
addCommitHeightLocal uint64
// removeCommitHeight[Remote|Local] encodes the height of the
// commitment which removed the parent pointer of this
// paymentDescriptor either due to a timeout or a settle. Once both
// these heights are below the tail of both chains, the log entries can
// safely be removed.
removeCommitHeightRemote uint64
removeCommitHeightLocal uint64
}
// AddHeight returns the height at which the HTLC was added to the commitment
// chain. The height is returned based on the chain the HTLC is being added to
// (local or remote chain).
func (a *AuxHtlcDescriptor) AddHeight(
whoseCommitChain lntypes.ChannelParty) uint64 {
if whoseCommitChain.IsRemote() {
return a.addCommitHeightRemote
}
return a.addCommitHeightLocal
}
// RemoveHeight returns the height at which the HTLC was removed from the
// commitment chain. The height is returned based on the chain the HTLC is being
// removed from (local or remote chain).
func (a *AuxHtlcDescriptor) RemoveHeight(
whoseCommitChain lntypes.ChannelParty) uint64 {
if whoseCommitChain.IsRemote() {
return a.removeCommitHeightRemote
}
return a.removeCommitHeightLocal
}
// newAuxHtlcDescriptor creates a new AuxHtlcDescriptor from a payment
// descriptor.
//
// NOTE: This function should only be called while holding the channel's read
// lock, since the underlying payment descriptors are accessed directly.
func newAuxHtlcDescriptor(p *paymentDescriptor) AuxHtlcDescriptor {
return AuxHtlcDescriptor{
ChanID: p.ChanID,
RHash: p.RHash,
Timeout: p.Timeout,
Amount: p.Amount,
HtlcIndex: p.HtlcIndex,
ParentIndex: p.ParentIndex,
EntryType: p.EntryType,
CustomRecords: p.CustomRecords.Copy(),
addCommitHeightRemote: p.addCommitHeights.Remote,
addCommitHeightLocal: p.addCommitHeights.Local,
removeCommitHeightRemote: p.removeCommitHeights.Remote,
removeCommitHeightLocal: p.removeCommitHeights.Local,
}
}
// BaseAuxJob is a struct that contains the common fields that are shared among
// the aux sign/verify jobs.
type BaseAuxJob struct {
// OutputIndex is the output index of the HTLC on the commitment
// transaction being signed.
//
// NOTE: If the output is dust from the PoV of the commitment chain,
// then this value will be -1.
OutputIndex int32
// KeyRing is the commitment key ring that contains the keys needed to
// generate the second level HTLC signatures.
KeyRing CommitmentKeyRing
// HTLC is the HTLC that is being signed or verified.
HTLC AuxHtlcDescriptor
// Incoming is a boolean that indicates if the HTLC is incoming or
// outgoing.
Incoming bool
// CommitBlob is the commitment transaction blob that contains the aux
// information for this channel.
CommitBlob fn.Option[tlv.Blob]
// HtlcLeaf is the aux tap leaf that corresponds to the HTLC being
// signed/verified.
HtlcLeaf input.AuxTapLeaf
}
// AuxSigJob is a struct that contains all the information needed to sign an
// HTLC for custom channels.
type AuxSigJob struct {
// SignDesc is the sign desc for this HTLC.
SignDesc input.SignDescriptor
BaseAuxJob
// Resp is a channel that will be used to send the result of the sign
// job. This channel MUST be buffered.
Resp chan AuxSigJobResp
// Cancel is a channel that is closed by the caller if they wish to
// abandon all pending sign jobs that are part of a single batch. This
// should never be closed by the validator.
Cancel <-chan struct{}
}
// NewAuxSigJob creates a new AuxSigJob.
func NewAuxSigJob(sigJob SignJob, keyRing CommitmentKeyRing, incoming bool,
htlc AuxHtlcDescriptor, commitBlob fn.Option[tlv.Blob],
htlcLeaf input.AuxTapLeaf, cancelChan <-chan struct{}) AuxSigJob {
return AuxSigJob{
SignDesc: sigJob.SignDesc,
BaseAuxJob: BaseAuxJob{
OutputIndex: sigJob.OutputIndex,
KeyRing: keyRing,
HTLC: htlc,
Incoming: incoming,
CommitBlob: commitBlob,
HtlcLeaf: htlcLeaf,
},
Resp: make(chan AuxSigJobResp, 1),
Cancel: cancelChan,
}
}
// AuxSigJobResp is a struct that contains the result of a sign job.
type AuxSigJobResp struct {
// SigBlob is the signature blob that was generated for the HTLC. This
// is an opaque TLV field that may contain the signature and other data.
SigBlob fn.Option[tlv.Blob]
// HtlcIndex is the index of the HTLC that was signed.
HtlcIndex uint64
// Err is the error that occurred when executing the specified
// signature job. In the case that no error occurred, this value will
// be nil.
Err error
}
// AuxVerifyJob is a struct that contains all the information needed to verify
// an HTLC for custom channels.
type AuxVerifyJob struct {
// SigBlob is the signature blob that was generated for the HTLC. This
// is an opaque TLV field that may contain the signature and other data.
SigBlob fn.Option[tlv.Blob]
BaseAuxJob
}
// NewAuxVerifyJob creates a new AuxVerifyJob.
func NewAuxVerifyJob(sig fn.Option[tlv.Blob], keyRing CommitmentKeyRing,
incoming bool, htlc AuxHtlcDescriptor, commitBlob fn.Option[tlv.Blob],
htlcLeaf input.AuxTapLeaf) AuxVerifyJob {
return AuxVerifyJob{
SigBlob: sig,
BaseAuxJob: BaseAuxJob{
KeyRing: keyRing,
HTLC: htlc,
Incoming: incoming,
CommitBlob: commitBlob,
HtlcLeaf: htlcLeaf,
},
}
}
// AuxSigner is an interface that is used to sign and verify HTLCs for custom
// channels. It is similar to the existing SigPool, but uses opaque blobs to
// shuffle around signature information and other metadata.
type AuxSigner interface {
// SubmitSecondLevelSigBatch takes a batch of aux sign jobs and
// processes them asynchronously.
SubmitSecondLevelSigBatch(chanState AuxChanState, commitTx *wire.MsgTx,
sigJob []AuxSigJob) error
// PackSigs takes a series of aux signatures and packs them into a
// single blob that can be sent alongside the CommitSig messages.
PackSigs([]fn.Option[tlv.Blob]) fn.Result[fn.Option[tlv.Blob]]
// UnpackSigs takes a packed blob of signatures and returns the
// original signatures for each HTLC, keyed by HTLC index.
UnpackSigs(fn.Option[tlv.Blob]) fn.Result[[]fn.Option[tlv.Blob]]
// VerifySecondLevelSigs attempts to synchronously verify a batch of aux
// sig jobs.
VerifySecondLevelSigs(chanState AuxChanState, commitTx *wire.MsgTx,
verifyJob []AuxVerifyJob) error
}
package chainfee
import (
"encoding/json"
"errors"
"fmt"
"io"
"math"
prand "math/rand"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/rpcclient"
"github.com/lightningnetwork/lnd/lnutils"
)
const (
// MaxBlockTarget is the highest number of block confirmations that
// a WebAPIEstimator will cache fees for. This number is chosen
// because it's the highest number of confs bitcoind will return a fee
// estimate for.
MaxBlockTarget uint32 = 1008
// minBlockTarget is the lowest number of block confirmations that
// a WebAPIEstimator will cache fees for. Requesting an estimate for
// less than this will result in an error.
minBlockTarget uint32 = 1
// WebAPIConnectionTimeout specifies the timeout value for connecting
// to the api source.
WebAPIConnectionTimeout = 5 * time.Second
// WebAPIResponseTimeout specifies the timeout value for receiving a
// fee response from the api source.
WebAPIResponseTimeout = 10 * time.Second
// economicalFeeMode is a mode that bitcoind uses to serve
// non-conservative fee estimates. These fee estimates are less
// resistant to shocks.
economicalFeeMode = "ECONOMICAL"
// filterCapConfTarget is the conf target that will be used to cap our
// minimum feerate if we used the median of our peers' feefilter
// values.
filterCapConfTarget = uint32(1)
)
var (
// errNoFeeRateFound is used when a given conf target cannot be found
// from the fee estimator.
errNoFeeRateFound = errors.New("no fee estimation for block target")
// errEmptyCache is used when the fee rate cache is empty.
errEmptyCache = errors.New("fee rate cache is empty")
)
// Estimator provides the ability to estimate on-chain transaction fees for
// various combinations of transaction sizes and desired confirmation time
// (measured by number of blocks).
type Estimator interface {
// EstimateFeePerKW takes in a target for the number of blocks until an
// initial confirmation and returns the estimated fee expressed in
// sat/kw.
EstimateFeePerKW(numBlocks uint32) (SatPerKWeight, error)
// Start signals the Estimator to start any processes or goroutines
// it needs to perform its duty.
Start() error
// Stop stops any spawned goroutines and cleans up the resources used
// by the fee estimator.
Stop() error
// RelayFeePerKW returns the minimum fee rate required for transactions
// to be relayed. This is also the basis for calculation of the dust
// limit.
RelayFeePerKW() SatPerKWeight
}
// StaticEstimator will return a static value for all fee calculation requests.
// It is designed to be replaced by a proper fee calculation implementation.
// The fees are not accessible directly, because changing them would not be
// thread safe.
type StaticEstimator struct {
// feePerKW is the static fee rate in sat/kw that will be returned by
// this fee estimator.
feePerKW SatPerKWeight
// relayFee is the minimum fee rate required for transactions to be
// relayed.
relayFee SatPerKWeight
}
// NewStaticEstimator returns a new static fee estimator instance.
func NewStaticEstimator(feePerKW, relayFee SatPerKWeight) *StaticEstimator {
return &StaticEstimator{
feePerKW: feePerKW,
relayFee: relayFee,
}
}
// EstimateFeePerKW will return a static value for fee calculations.
//
// NOTE: This method is part of the Estimator interface.
func (e StaticEstimator) EstimateFeePerKW(numBlocks uint32) (SatPerKWeight, error) {
return e.feePerKW, nil
}
// RelayFeePerKW returns the minimum fee rate required for transactions to be
// relayed.
//
// NOTE: This method is part of the Estimator interface.
func (e StaticEstimator) RelayFeePerKW() SatPerKWeight {
return e.relayFee
}
// Start signals the Estimator to start any processes or goroutines
// it needs to perform its duty.
//
// NOTE: This method is part of the Estimator interface.
func (e StaticEstimator) Start() error {
return nil
}
// Stop stops any spawned goroutines and cleans up the resources used
// by the fee estimator.
//
// NOTE: This method is part of the Estimator interface.
func (e StaticEstimator) Stop() error {
return nil
}
// A compile-time assertion to ensure that StaticEstimator implements the
// Estimator interface.
var _ Estimator = (*StaticEstimator)(nil)
// BtcdEstimator is an implementation of the Estimator interface backed
// by the RPC interface of an active btcd node. This implementation will proxy
// any fee estimation requests to btcd's RPC interface.
type BtcdEstimator struct {
// fallbackFeePerKW is the fall back fee rate in sat/kw that is returned
// if the fee estimator does not yet have enough data to actually
// produce fee estimates.
fallbackFeePerKW SatPerKWeight
// minFeeManager is used to query the current minimum fee, in sat/kw,
// that we should enforce. This will be used to determine fee rate for
// a transaction when the estimated fee rate is too low to allow the
// transaction to propagate through the network.
minFeeManager *minFeeManager
btcdConn *rpcclient.Client
// filterManager uses our peer's feefilter values to determine a
// suitable feerate to use that will allow successful transaction
// propagation.
filterManager *filterManager
}
// NewBtcdEstimator creates a new BtcdEstimator given a fully populated
// rpc config that is able to successfully connect and authenticate with the
// btcd node, and also a fall back fee rate. The fallback fee rate is used in
// the occasion that the estimator has insufficient data, or returns zero for a
// fee estimate.
func NewBtcdEstimator(rpcConfig rpcclient.ConnConfig,
fallBackFeeRate SatPerKWeight) (*BtcdEstimator, error) {
rpcConfig.DisableConnectOnNew = true
rpcConfig.DisableAutoReconnect = false
chainConn, err := rpcclient.New(&rpcConfig, nil)
if err != nil {
return nil, err
}
fetchCb := func() ([]SatPerKWeight, error) {
return fetchBtcdFilters(chainConn)
}
return &BtcdEstimator{
fallbackFeePerKW: fallBackFeeRate,
btcdConn: chainConn,
filterManager: newFilterManager(fetchCb),
}, nil
}
// Start signals the Estimator to start any processes or goroutines
// it needs to perform its duty.
//
// NOTE: This method is part of the Estimator interface.
func (b *BtcdEstimator) Start() error {
if err := b.btcdConn.Connect(20); err != nil {
return err
}
// Once the connection to the backend node has been established, we
// can initialise the minimum relay fee manager which queries the
// chain backend for the minimum relay fee on construction.
minRelayFeeManager, err := newMinFeeManager(
defaultUpdateInterval, b.fetchMinRelayFee,
)
if err != nil {
return err
}
b.minFeeManager = minRelayFeeManager
b.filterManager.Start()
return nil
}
// fetchMinRelayFee fetches the minimum relay fee (reported in sat/kb) from
// the btcd backend and returns it converted to sat/kw.
func (b *BtcdEstimator) fetchMinRelayFee() (SatPerKWeight, error) {
info, err := b.btcdConn.GetInfo()
if err != nil {
return 0, err
}
relayFee, err := btcutil.NewAmount(info.RelayFee)
if err != nil {
return 0, err
}
// The fee rate is expressed in sat/kb, so we'll manually convert it to
// our desired sat/kw rate.
return SatPerKVByte(relayFee).FeePerKWeight(), nil
}
// Stop stops any spawned goroutines and cleans up the resources used
// by the fee estimator.
//
// NOTE: This method is part of the Estimator interface.
func (b *BtcdEstimator) Stop() error {
b.filterManager.Stop()
b.btcdConn.Shutdown()
b.btcdConn.WaitForShutdown()
return nil
}
// EstimateFeePerKW takes in a target for the number of blocks until an initial
// confirmation and returns the estimated fee expressed in sat/kw.
//
// NOTE: This method is part of the Estimator interface.
func (b *BtcdEstimator) EstimateFeePerKW(numBlocks uint32) (SatPerKWeight, error) {
feeEstimate, err := b.fetchEstimate(numBlocks)
switch {
// If the estimator doesn't have enough data to return a proper
// value, or returns an error, then we'll fall back to the default
// fee rate.
case err != nil:
log.Errorf("unable to query estimator: %v", err)
fallthrough
case feeEstimate == 0:
return b.fallbackFeePerKW, nil
}
return feeEstimate, nil
}
// RelayFeePerKW returns the minimum fee rate required for transactions to be
// relayed.
//
// NOTE: This method is part of the Estimator interface.
func (b *BtcdEstimator) RelayFeePerKW() SatPerKWeight {
// Get a suitable minimum feerate to use. This may optionally use the
// median of our peers' feefilter values.
feeCapClosure := func() (SatPerKWeight, error) {
return b.fetchEstimateInner(filterCapConfTarget)
}
return chooseMinFee(
b.minFeeManager.fetchMinFee, b.filterManager.FetchMedianFilter,
feeCapClosure,
)
}
// fetchEstimate returns a fee estimate for a transaction to be confirmed in
// confTarget blocks. The estimate is returned in sat/kw.
func (b *BtcdEstimator) fetchEstimate(confTarget uint32) (SatPerKWeight, error) {
satPerKw, err := b.fetchEstimateInner(confTarget)
if err != nil {
return 0, err
}
// Finally, we'll enforce our fee floor by choosing the higher of the
// minimum relay fee and the feerate returned by the filterManager.
absoluteMinFee := b.RelayFeePerKW()
if satPerKw < absoluteMinFee {
log.Debugf("Estimated fee rate of %v sat/kw is too low, "+
"using fee floor of %v sat/kw instead", satPerKw,
absoluteMinFee)
satPerKw = absoluteMinFee
}
log.Debugf("Returning %v sat/kw for conf target of %v",
int64(satPerKw), confTarget)
return satPerKw, nil
}
func (b *BtcdEstimator) fetchEstimateInner(confTarget uint32) (SatPerKWeight,
error) {
// First, we'll fetch the estimate for our confirmation target.
btcPerKB, err := b.btcdConn.EstimateFee(int64(confTarget))
if err != nil {
return 0, err
}
// Next, we'll convert the returned value to satoshis, as it's
// currently returned in BTC.
satPerKB, err := btcutil.NewAmount(btcPerKB)
if err != nil {
return 0, err
}
// Since we use fee rates in sat/kw internally, we'll convert the
// estimated fee rate from its sat/kb representation to sat/kw.
return SatPerKVByte(satPerKB).FeePerKWeight(), nil
}
// A compile-time assertion to ensure that BtcdEstimator implements the
// Estimator interface.
var _ Estimator = (*BtcdEstimator)(nil)
// BitcoindEstimator is an implementation of the Estimator interface backed by
// the RPC interface of an active bitcoind node. This implementation will proxy
// any fee estimation requests to bitcoind's RPC interface.
type BitcoindEstimator struct {
// fallbackFeePerKW is the fallback fee rate in sat/kw that is returned
// if the fee estimator does not yet have enough data to actually
// produce fee estimates.
fallbackFeePerKW SatPerKWeight
// minFeeManager is used to keep track of the minimum fee, in sat/kw,
// that we should enforce. This will be used as the default fee rate
// for a transaction when the estimated fee rate is too low to allow
// the transaction to propagate through the network.
minFeeManager *minFeeManager
// feeMode is the estimate_mode to use when calling "estimatesmartfee".
// It can be either "ECONOMICAL" or "CONSERVATIVE", and it defaults
// to "CONSERVATIVE".
feeMode string
// TODO(ziggie): introduce an interface for the client to enhance
// testability of the estimator.
bitcoindConn *rpcclient.Client
// filterManager uses our peer's feefilter values to determine a
// suitable feerate to use that will allow successful transaction
// propagation.
filterManager *filterManager
}
// NewBitcoindEstimator creates a new BitcoindEstimator given a fully populated
// rpc config that is able to successfully connect and authenticate with the
// bitcoind node, and also a fall back fee rate. The fallback fee rate is used
// in the occasion that the estimator has insufficient data, or returns zero
// for a fee estimate.
func NewBitcoindEstimator(rpcConfig rpcclient.ConnConfig, feeMode string,
fallBackFeeRate SatPerKWeight) (*BitcoindEstimator, error) {
rpcConfig.DisableConnectOnNew = true
rpcConfig.DisableAutoReconnect = false
rpcConfig.DisableTLS = true
rpcConfig.HTTPPostMode = true
chainConn, err := rpcclient.New(&rpcConfig, nil)
if err != nil {
return nil, err
}
fetchCb := func() ([]SatPerKWeight, error) {
return fetchBitcoindFilters(chainConn)
}
return &BitcoindEstimator{
fallbackFeePerKW: fallBackFeeRate,
bitcoindConn: chainConn,
feeMode: feeMode,
filterManager: newFilterManager(fetchCb),
}, nil
}
// Start signals the Estimator to start any processes or goroutines
// it needs to perform its duty.
//
// NOTE: This method is part of the Estimator interface.
func (b *BitcoindEstimator) Start() error {
// Once the connection to the backend node has been established, we'll
// initialise the minimum relay fee manager which will query
// the backend node for its minimum mempool fee.
relayFeeManager, err := newMinFeeManager(
defaultUpdateInterval,
b.fetchMinMempoolFee,
)
if err != nil {
return err
}
b.minFeeManager = relayFeeManager
b.filterManager.Start()
return nil
}
// fetchMinMempoolFee is used to fetch the minimum fee that the backend node
// requires for a tx to enter its mempool. The returned fee will be the
// maximum of the minimum relay fee and the minimum mempool fee.
func (b *BitcoindEstimator) fetchMinMempoolFee() (SatPerKWeight, error) {
resp, err := b.bitcoindConn.RawRequest("getmempoolinfo", nil)
if err != nil {
return 0, err
}
// Parse the response to retrieve the min mempool fee in sat/KB. Note
// that mempoolminfee is the maximum of minrelaytxfee and the minimum
// mempool fee.
info := struct {
MempoolMinFee float64 `json:"mempoolminfee"`
}{}
if err := json.Unmarshal(resp, &info); err != nil {
return 0, err
}
minMempoolFee, err := btcutil.NewAmount(info.MempoolMinFee)
if err != nil {
return 0, err
}
// The fee rate is expressed in sat/kb, so we'll manually convert it to
// our desired sat/kw rate.
return SatPerKVByte(minMempoolFee).FeePerKWeight(), nil
}
// Stop stops any spawned goroutines and cleans up the resources used
// by the fee estimator.
//
// NOTE: This method is part of the Estimator interface.
func (b *BitcoindEstimator) Stop() error {
b.filterManager.Stop()
return nil
}
// EstimateFeePerKW takes in a target for the number of blocks until an initial
// confirmation and returns the estimated fee expressed in sat/kw.
//
// NOTE: This method is part of the Estimator interface.
func (b *BitcoindEstimator) EstimateFeePerKW(
numBlocks uint32) (SatPerKWeight, error) {
if numBlocks > MaxBlockTarget {
log.Debugf("conf target %d exceeds the max value, "+
"use %d instead.", numBlocks, MaxBlockTarget,
)
numBlocks = MaxBlockTarget
}
feeEstimate, err := b.fetchEstimate(numBlocks, b.feeMode)
switch {
// If the estimator doesn't have enough data to return a proper
// value, or returns an error, then we'll fall back to the default
// fee rate.
case err != nil:
log.Errorf("unable to query estimator: %v", err)
fallthrough
case feeEstimate == 0:
return b.fallbackFeePerKW, nil
}
return feeEstimate, nil
}
// RelayFeePerKW returns the minimum fee rate required for transactions to be
// relayed.
//
// NOTE: This method is part of the Estimator interface.
func (b *BitcoindEstimator) RelayFeePerKW() SatPerKWeight {
// Get a suitable minimum feerate to use. This may optionally use the
// median of our peers' feefilter values.
feeCapClosure := func() (SatPerKWeight, error) {
return b.fetchEstimateInner(
filterCapConfTarget, economicalFeeMode,
)
}
return chooseMinFee(
b.minFeeManager.fetchMinFee, b.filterManager.FetchMedianFilter,
feeCapClosure,
)
}
// fetchEstimate returns a fee estimate for a transaction to be confirmed in
// confTarget blocks. The estimate is returned in sat/kw.
func (b *BitcoindEstimator) fetchEstimate(confTarget uint32, feeMode string) (
SatPerKWeight, error) {
satPerKw, err := b.fetchEstimateInner(confTarget, feeMode)
if err != nil {
return 0, err
}
// Finally, we'll enforce our fee floor by choosing the higher of the
// minimum relay fee and the feerate returned by the filterManager.
absoluteMinFee := b.RelayFeePerKW()
if satPerKw < absoluteMinFee {
log.Debugf("Estimated fee rate of %v sat/kw is too low, "+
"using fee floor of %v sat/kw instead", satPerKw,
absoluteMinFee)
satPerKw = absoluteMinFee
}
log.Debugf("Returning %v sat/kw for conf target of %v",
int64(satPerKw), confTarget)
return satPerKw, nil
}
func (b *BitcoindEstimator) fetchEstimateInner(confTarget uint32,
feeMode string) (SatPerKWeight, error) {
// First, we'll send an "estimatesmartfee" command as a raw request,
// since it isn't supported by btcd but is available in bitcoind.
target, err := json.Marshal(uint64(confTarget))
if err != nil {
return 0, err
}
// The mode must be either ECONOMICAL or CONSERVATIVE.
mode, err := json.Marshal(feeMode)
if err != nil {
return 0, err
}
resp, err := b.bitcoindConn.RawRequest(
"estimatesmartfee", []json.RawMessage{target, mode},
)
if err != nil {
return 0, err
}
// Next, we'll parse the response to get the BTC per KB.
feeEstimate := struct {
FeeRate float64 `json:"feerate"`
}{}
err = json.Unmarshal(resp, &feeEstimate)
if err != nil {
return 0, err
}
// Next, we'll convert the returned value to satoshis, as it's currently
// returned in BTC.
satPerKB, err := btcutil.NewAmount(feeEstimate.FeeRate)
if err != nil {
return 0, err
}
// Bitcoind will not report any fee estimate if it does not have
// enough data available, hence the fee will remain zero. We return an
// error here to make sure that we do not use the min relay fee
// instead.
if satPerKB == 0 {
return 0, fmt.Errorf("fee estimation data not available yet")
}
// Since we use fee rates in sat/kw internally, we'll convert the
// estimated fee rate from its sat/kb representation to sat/kw.
return SatPerKVByte(satPerKB).FeePerKWeight(), nil
}
// chooseMinFee compares the minimum relay fee with the median of our peers'
// feefilter values and takes the higher of the two. It then compares the value
// against a maximum fee and caps it if the value is higher than the maximum
// fee. This function is only called if we have data for our peers' feefilter.
// The returned value will be used as the fee floor for calls to
// RelayFeePerKW.
func chooseMinFee(minRelayFeeFunc func() SatPerKWeight,
medianFilterFunc func() (SatPerKWeight, error),
feeCapFunc func() (SatPerKWeight, error)) SatPerKWeight {
minRelayFee := minRelayFeeFunc()
medianFilter, err := medianFilterFunc()
if err != nil {
// If we don't have feefilter data, we fallback to using our
// minimum relay fee.
return minRelayFee
}
feeCap, err := feeCapFunc()
if err != nil {
// If we encountered an error, don't use the medianFilter and
// instead fallback to using our minimum relay fee.
return minRelayFee
}
// If the median feefilter is higher than our minimum relay fee, use it
// instead.
if medianFilter > minRelayFee {
// Only apply the cap if the median filter was used. This is
// to prevent an adversary from taking up the majority of our
// outbound peer slots and forcing us to use a high median
// filter value.
if medianFilter > feeCap {
return feeCap
}
return medianFilter
}
return minRelayFee
}
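// exampleChooseMinFee is an illustrative sketch of the selection logic above
// using stub closures with assumed values: the median feefilter (300 sat/kw)
// beats the relay fee (250 sat/kw) but is clamped to the fee cap (280
// sat/kw).
func exampleChooseMinFee() SatPerKWeight {
	minRelayFee := func() SatPerKWeight { return 250 }
	medianFilter := func() (SatPerKWeight, error) { return 300, nil }
	feeCap := func() (SatPerKWeight, error) { return 280, nil }

	// Returns 280: the median exceeds the relay fee, so it is selected,
	// but it is then capped at the fee cap.
	return chooseMinFee(minRelayFee, medianFilter, feeCap)
}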
// A compile-time assertion to ensure that BitcoindEstimator implements the
// Estimator interface.
var _ Estimator = (*BitcoindEstimator)(nil)
// WebAPIFeeSource is an interface that allows the WebAPIEstimator to query an
// arbitrary HTTP-based fee estimator. Each new set/network will gain an
// implementation of this interface in order to allow the WebAPIEstimator to
// be fully generic in its logic.
type WebAPIFeeSource interface {
// GetFeeInfo will query the web API, parse the response into a
// WebAPIResponse which contains a map of confirmation targets to
// sat/kvb fees and the min relay feerate.
GetFeeInfo() (WebAPIResponse, error)
}
// SparseConfFeeSource is an implementation of the WebAPIFeeSource that utilizes
// a user-specified fee estimation API for Bitcoin. It expects the response
// to be in the JSON format: `fee_by_block_target: { ... }` where the value maps
// block targets to fee estimates (in sat per kilovbyte).
type SparseConfFeeSource struct {
// URL is the fee estimation API specified by the user.
URL string
}
// WebAPIResponse is the response returned by the fee estimation API.
type WebAPIResponse struct {
// FeeByBlockTarget is a map of confirmation targets to sat/kvb fees.
FeeByBlockTarget map[uint32]uint32 `json:"fee_by_block_target"`
// MinRelayFeerate is the minimum relay fee in sat/kvb.
MinRelayFeerate SatPerKVByte `json:"min_relay_feerate"`
}
// parseResponse attempts to parse the body of the response generated by the
// above query URL. Typically this will be JSON, but the specifics are left to
// the WebAPIFeeSource implementation.
func (s SparseConfFeeSource) parseResponse(r io.Reader) (
WebAPIResponse, error) {
resp := WebAPIResponse{
FeeByBlockTarget: make(map[uint32]uint32),
MinRelayFeerate: 0,
}
jsonReader := json.NewDecoder(r)
if err := jsonReader.Decode(&resp); err != nil {
return WebAPIResponse{}, err
}
if resp.MinRelayFeerate == 0 {
log.Errorf("No min relay fee rate available, using default %v",
FeePerKwFloor)
resp.MinRelayFeerate = FeePerKwFloor.FeePerKVByte()
}
return resp, nil
}
// GetFeeInfo will query the web API, parse the response and return a map of
// confirmation targets to sat/kvb fees and the min relay feerate in a parsed
// response.
func (s SparseConfFeeSource) GetFeeInfo() (WebAPIResponse, error) {
// Rather than use the default http.Client, we'll make a custom one
// which will allow us to control how long we'll wait to read the
// response from the service. This way, if the service is down or
// overloaded, we can exit early and use our default fee.
netTransport := &http.Transport{
Dial: (&net.Dialer{
Timeout: WebAPIConnectionTimeout,
}).Dial,
TLSHandshakeTimeout: WebAPIConnectionTimeout,
}
netClient := &http.Client{
Timeout: WebAPIResponseTimeout,
Transport: netTransport,
}
// With the client created, we'll query the API source to fetch the URL
// that we should use to query for the fee estimation.
targetURL := s.URL
resp, err := netClient.Get(targetURL)
if err != nil {
log.Errorf("unable to query web api for fee response: %v",
err)
return WebAPIResponse{}, err
}
defer resp.Body.Close()
// Once we've obtained the response, we'll instruct the WebAPIFeeSource
// to parse out the body to obtain our final result.
parsedResp, err := s.parseResponse(resp.Body)
if err != nil {
log.Errorf("unable to parse fee api response: %v", err)
return WebAPIResponse{}, err
}
return parsedResp, nil
}
// A compile-time assertion to ensure that SparseConfFeeSource implements the
// WebAPIFeeSource interface.
var _ WebAPIFeeSource = (*SparseConfFeeSource)(nil)
// WebAPIEstimator is an implementation of the Estimator interface that
// queries an HTTP-based fee estimation from an existing web API.
type WebAPIEstimator struct {
started atomic.Bool
stopped atomic.Bool
// apiSource is the backing web API source we'll use for our queries.
apiSource WebAPIFeeSource
// updateFeeTicker is the ticker responsible for updating the Estimator's
// fee estimates every time it fires.
updateFeeTicker *time.Ticker
// feeByBlockTarget is our cache for fees pulled from the API. When a
// fee estimate request comes in, we pull the estimate from this map
// rather than re-querying the API, to prevent an inadvertent DoS attack.
feesMtx sync.Mutex
feeByBlockTarget map[uint32]uint32
minRelayFeerate SatPerKVByte
// noCache determines whether the web estimator should cache fee
// estimates.
noCache bool
// minFeeUpdateTimeout represents the minimum interval in which the
// web estimator will request fresh fees from its API.
minFeeUpdateTimeout time.Duration
// maxFeeUpdateTimeout represents the maximum interval in which the
// web estimator will request fresh fees from its API.
maxFeeUpdateTimeout time.Duration
quit chan struct{}
wg sync.WaitGroup
}
// NewWebAPIEstimator creates a new WebAPIEstimator from a given fee source.
// Unless caching is disabled, the fees are refreshed on a random interval
// between the min and max fee update timeouts.
func NewWebAPIEstimator(api WebAPIFeeSource, noCache bool,
minFeeUpdateTimeout time.Duration,
maxFeeUpdateTimeout time.Duration) (*WebAPIEstimator, error) {
if minFeeUpdateTimeout == 0 || maxFeeUpdateTimeout == 0 {
return nil, fmt.Errorf("minFeeUpdateTimeout and " +
"maxFeeUpdateTimeout must be greater than 0")
}
if minFeeUpdateTimeout >= maxFeeUpdateTimeout {
return nil, fmt.Errorf("minFeeUpdateTimeout target of %v "+
"cannot be greater than maxFeeUpdateTimeout of %v",
minFeeUpdateTimeout, maxFeeUpdateTimeout)
}
return &WebAPIEstimator{
apiSource: api,
feeByBlockTarget: make(map[uint32]uint32),
noCache: noCache,
quit: make(chan struct{}),
minFeeUpdateTimeout: minFeeUpdateTimeout,
maxFeeUpdateTimeout: maxFeeUpdateTimeout,
}, nil
}
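// exampleWebAPIEstimator is an illustrative sketch of wiring a
// SparseConfFeeSource into a caching WebAPIEstimator; the URL is a
// hypothetical placeholder, not a real endpoint, and the timeouts are
// assumed values.
func exampleWebAPIEstimator() (*WebAPIEstimator, error) {
	source := SparseConfFeeSource{
		// Placeholder endpoint for demonstration only.
		URL: "https://fees.example.com/v1/fees",
	}

	// With caching enabled (noCache=false), estimates are refreshed on a
	// random interval between 5 and 20 minutes.
	return NewWebAPIEstimator(
		source, false, 5*time.Minute, 20*time.Minute,
	)
}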
// EstimateFeePerKW takes in a target for the number of blocks until an initial
// confirmation and returns the estimated fee expressed in sat/kw.
//
// NOTE: This method is part of the Estimator interface.
func (w *WebAPIEstimator) EstimateFeePerKW(numBlocks uint32) (
SatPerKWeight, error) {
// If the estimator hasn't been started yet, we'll return an error as
// we can't provide a fee estimate.
if !w.started.Load() {
return 0, fmt.Errorf("estimator not started")
}
if numBlocks > MaxBlockTarget {
numBlocks = MaxBlockTarget
} else if numBlocks < minBlockTarget {
return 0, fmt.Errorf("conf target of %v is too low, minimum "+
"accepted is %v", numBlocks, minBlockTarget)
}
// Get fee estimates now if we don't refresh periodically.
if w.noCache {
w.updateFeeEstimates()
}
feePerKb, err := w.getCachedFee(numBlocks)
// If the estimator returns an error, a zero fee rate will be
// returned. We log the error and continue, in which case the fee
// floor below is applied instead.
if err != nil {
log.Errorf("Unable to query estimator: %v", err)
}
// If the result is too low, then we'll clamp it to our current fee
// floor.
satPerKw := SatPerKVByte(feePerKb).FeePerKWeight()
if satPerKw < FeePerKwFloor {
satPerKw = FeePerKwFloor
}
log.Debugf("Web API returning %v sat/kw for conf target of %v",
int64(satPerKw), numBlocks)
return satPerKw, nil
}
// Start signals the Estimator to start any processes or goroutines it needs
// to perform its duty.
//
// NOTE: This method is part of the Estimator interface.
func (w *WebAPIEstimator) Start() error {
log.Infof("Starting Web API fee estimator...")
// Return an error if it's already been started.
if w.started.Load() {
return fmt.Errorf("web API fee estimator already started")
}
defer w.started.Store(true)
// During startup we'll query the API to initialize the fee map.
w.updateFeeEstimates()
// No update loop is needed when we don't cache.
if w.noCache {
return nil
}
feeUpdateTimeout := w.randomFeeUpdateTimeout()
log.Infof("Web API fee estimator using update timeout of %v",
feeUpdateTimeout)
w.updateFeeTicker = time.NewTicker(feeUpdateTimeout)
w.wg.Add(1)
go w.feeUpdateManager()
return nil
}
// Stop stops any spawned goroutines and cleans up the resources used by the
// fee estimator.
//
// NOTE: This method is part of the Estimator interface.
func (w *WebAPIEstimator) Stop() error {
log.Infof("Stopping web API fee estimator")
if w.stopped.Swap(true) {
return fmt.Errorf("web API fee estimator already stopped")
}
// Update loop is not running when we don't cache.
if w.noCache {
return nil
}
if w.updateFeeTicker != nil {
w.updateFeeTicker.Stop()
}
close(w.quit)
w.wg.Wait()
return nil
}
// RelayFeePerKW returns the minimum fee rate required for transactions to be
// relayed.
//
// NOTE: This method is part of the Estimator interface.
func (w *WebAPIEstimator) RelayFeePerKW() SatPerKWeight {
if !w.started.Load() {
log.Error("WebAPIEstimator not started")
}
// Get fee estimates now if we don't refresh periodically.
if w.noCache {
w.updateFeeEstimates()
}
log.Infof("Web API returning %v for min relay feerate",
w.minRelayFeerate)
return w.minRelayFeerate.FeePerKWeight()
}
// randomFeeUpdateTimeout returns a random timeout between minFeeUpdateTimeout
// and maxFeeUpdateTimeout that will be used to determine how often the Estimator
// should retrieve fresh fees from its API.
func (w *WebAPIEstimator) randomFeeUpdateTimeout() time.Duration {
lower := int64(w.minFeeUpdateTimeout)
upper := int64(w.maxFeeUpdateTimeout)
return time.Duration(
prand.Int63n(upper-lower) + lower, //nolint:gosec
).Round(time.Second)
}
// getCachedFee takes a conf target and returns the cached fee rate. When the
// fee rate cannot be found, it will search the cache by decrementing the conf
// target until a fee rate is found. If still not found, it will return the fee
// rate of the minimum conf target cached, in other words, the most expensive
// fee rate it knows of.
func (w *WebAPIEstimator) getCachedFee(numBlocks uint32) (uint32, error) {
w.feesMtx.Lock()
defer w.feesMtx.Unlock()
// If the cache is empty, return an error.
if len(w.feeByBlockTarget) == 0 {
return 0, fmt.Errorf("web API error: %w", errEmptyCache)
}
// Search the conf target from the cache. We expect a query to the web
// API has been made and the result has been cached at this point.
fee, ok := w.feeByBlockTarget[numBlocks]
// If the conf target can be found, exit early.
if ok {
return fee, nil
}
// The conf target cannot be found. We will first search the cache
// using a lower conf target. This is a conservative approach as the
// fee rate returned will be larger than what's requested.
for target := numBlocks; target >= minBlockTarget; target-- {
fee, ok := w.feeByBlockTarget[target]
if !ok {
continue
}
log.Warnf("Web API does not have a fee rate for target=%d, "+
"using the fee rate for target=%d instead",
numBlocks, target)
// Return the fee rate found, which will be more expensive than
// requested. We will not cache the fee rate here in the hope
// that the web API will later populate this value.
return fee, nil
}
// There are no lower conf targets cached, which is likely when the
// requested conf target is 1. We will search the cache using a higher
// conf target, which gives a fee rate that's cheaper than requested.
//
	// NOTE: we can only get here iff the requested conf target is smaller
	// than the minimum conf target cached, so we return the fee rate of
	// the minimum conf target found in the cache.
minTargetCached := uint32(math.MaxUint32)
for target := range w.feeByBlockTarget {
if target < minTargetCached {
minTargetCached = target
}
}
fee, ok = w.feeByBlockTarget[minTargetCached]
if !ok {
		// We should never get here, just a sanity check.
return 0, fmt.Errorf("web API error: %w, conf target: %d",
errNoFeeRateFound, numBlocks)
}
// Log an error instead of a warning as a cheaper fee rate may delay
// the confirmation for some important transactions.
log.Errorf("Web API does not have a fee rate for target=%d, "+
"using the fee rate for target=%d instead",
numBlocks, minTargetCached)
return fee, nil
}
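// Illustrative walk-through of the fallback logic above (not part of the
// package): suppose the cache only holds fee rates for conf targets 2 and 6,
// e.g. feeByBlockTarget = map[uint32]uint32{2: 5000, 6: 3000}. A query for
// target 4 scans downwards (4 -> 3 -> 2) and returns 5000 with a warning,
// which is more expensive than requested. A query for a target below the
// smallest cached target finds nothing on the way down and falls through to
// the second lookup, returning the fee rate of the minimum cached target
// (here 5000 for target 2), logged as an error since that rate may be
// cheaper than what the requested target actually demands.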
// updateFeeEstimates re-queries the API for fresh fees and caches them.
func (w *WebAPIEstimator) updateFeeEstimates() {
	// Query the WebAPIFeeSource, which will obtain the response and parse
	// out the body to obtain our final result.
resp, err := w.apiSource.GetFeeInfo()
if err != nil {
log.Errorf("unable to get fee response: %v", err)
return
}
log.Debugf("Received response from source: %s", lnutils.NewLogClosure(
func() string {
resp, _ := json.Marshal(resp)
return string(resp)
}))
w.feesMtx.Lock()
w.feeByBlockTarget = resp.FeeByBlockTarget
w.minRelayFeerate = resp.MinRelayFeerate
w.feesMtx.Unlock()
}
// feeUpdateManager updates the fee estimates whenever a new block comes in.
func (w *WebAPIEstimator) feeUpdateManager() {
defer w.wg.Done()
for {
select {
case <-w.updateFeeTicker.C:
w.updateFeeEstimates()
case <-w.quit:
return
}
}
}
// A compile-time assertion to ensure that WebAPIEstimator implements the
// Estimator interface.
var _ Estimator = (*WebAPIEstimator)(nil)
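// A minimal sketch of how a caller can drive any Estimator implementation
// through the interface (illustrative only; error handling is elided and the
// concrete constructor lives elsewhere in this package):
//
//	func fetchConfTargetFee(est Estimator) (SatPerKWeight, error) {
//		if err := est.Start(); err != nil {
//			return 0, err
//		}
//		defer func() { _ = est.Stop() }()
//
//		// Ask for the fee rate needed to confirm within 6 blocks.
//		return est.EstimateFeePerKW(6)
//	}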
package chainfee
import (
"encoding/json"
"fmt"
"sort"
"sync"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/rpcclient"
"github.com/lightningnetwork/lnd/fn/v2"
)
const (
// fetchFilterInterval is the interval between successive fetches of
// our peers' feefilters.
fetchFilterInterval = time.Minute * 5
// minNumFilters is the minimum number of feefilters we need during a
// polling interval. If we have fewer than this, we won't consider the
// data.
minNumFilters = 6
)
var (
// errNoData is an error that's returned if fetchMedianFilter is
// called and there is no data available.
errNoData = fmt.Errorf("no data available")
)
// filterManager is responsible for determining an acceptable minimum fee to
// use based on our peers' feefilter values.
type filterManager struct {
// median stores the median of our outbound peer's feefilter values.
median fn.Option[SatPerKWeight]
medianMtx sync.RWMutex
fetchFunc func() ([]SatPerKWeight, error)
wg sync.WaitGroup
quit chan struct{}
}
// newFilterManager constructs a filterManager with a callback that fetches the
// set of peers' feefilters.
func newFilterManager(cb func() ([]SatPerKWeight, error)) *filterManager {
f := &filterManager{
quit: make(chan struct{}),
}
f.fetchFunc = cb
return f
}
// Start starts the filterManager.
func (f *filterManager) Start() {
f.wg.Add(1)
go f.fetchPeerFilters()
}
// Stop stops the filterManager.
func (f *filterManager) Stop() {
close(f.quit)
f.wg.Wait()
}
// fetchPeerFilters fetches our peers' feefilter values and calculates the
// median.
func (f *filterManager) fetchPeerFilters() {
defer f.wg.Done()
filterTicker := time.NewTicker(fetchFilterInterval)
defer filterTicker.Stop()
for {
select {
case <-filterTicker.C:
filters, err := f.fetchFunc()
if err != nil {
log.Errorf("Encountered err while fetching "+
"fee filters: %v", err)
return
}
f.updateMedian(filters)
case <-f.quit:
return
}
}
}
// FetchMedianFilter fetches the median feefilter value.
func (f *filterManager) FetchMedianFilter() (SatPerKWeight, error) {
f.medianMtx.RLock()
defer f.medianMtx.RUnlock()
// If there is no median, return errNoData so the caller knows to
// ignore the output and continue.
return f.median.UnwrapOrErr(errNoData)
}
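// A minimal usage sketch (illustrative; the fetch callback is a stub here,
// while in production it queries the chain backend's peers). Note that the
// median is only populated after the first fetchFilterInterval tick, so
// FetchMedianFilter returns errNoData until then.
//
//	fm := newFilterManager(func() ([]SatPerKWeight, error) {
//		return []SatPerKWeight{253, 500, 1000}, nil
//	})
//	fm.Start()
//	defer fm.Stop()
//
//	median, err := fm.FetchMedianFilter()
//	if err == nil {
//		log.Debugf("median feefilter: %v", median)
//	}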
type bitcoindPeerInfoResp struct {
Inbound bool `json:"inbound"`
MinFeeFilter float64 `json:"minfeefilter"`
}
func fetchBitcoindFilters(client *rpcclient.Client) ([]SatPerKWeight, error) {
resp, err := client.RawRequest("getpeerinfo", nil)
if err != nil {
return nil, err
}
var peerResps []bitcoindPeerInfoResp
err = json.Unmarshal(resp, &peerResps)
if err != nil {
return nil, err
}
// We filter for outbound peers since it is harder for an attacker to
// be our outbound peer and therefore harder for them to manipulate us
// into broadcasting high-fee or low-fee transactions.
var outboundPeerFilters []SatPerKWeight
for _, peerResp := range peerResps {
if peerResp.Inbound {
continue
}
		// The value that bitcoind returns for the "minfeefilter"
		// field is denominated in BTC/KvB. We first convert this
		// fraction of a bitcoin to whole satoshis by multiplying with
		// the number of satoshis per bitcoin, and then convert the
		// resulting sat/KvB value to sat/KW.
		//
		// Convert the fee filter from fractions of a bitcoin to
		// whole satoshis per KvB.
filterKVByte := SatPerKVByte(
peerResp.MinFeeFilter * btcutil.SatoshiPerBitcoin,
)
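		// For example (illustrative): a minfeefilter of 0.00001
		// BTC/KvB times 1e8 sat/BTC yields 1000 sat/KvB, which
		// FeePerKWeight below turns into 1000 / 4 = 250 sat/kw.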
if !isWithinBounds(filterKVByte) {
continue
}
// Convert KvB to KW and add it to outboundPeerFilters.
outboundPeerFilters = append(
outboundPeerFilters, filterKVByte.FeePerKWeight(),
)
}
// Check that we have enough data to use. We don't return an error as
// that would stop the querying goroutine.
if len(outboundPeerFilters) < minNumFilters {
return nil, nil
}
return outboundPeerFilters, nil
}
func fetchBtcdFilters(client *rpcclient.Client) ([]SatPerKWeight, error) {
resp, err := client.GetPeerInfo()
if err != nil {
return nil, err
}
var outboundPeerFilters []SatPerKWeight
for _, peerResp := range resp {
// We filter for outbound peers since it is harder for an
// attacker to be our outbound peer and therefore harder for
// them to manipulate us into broadcasting high-fee or low-fee
// transactions.
if peerResp.Inbound {
continue
}
// The feefilter is already in units of sat/KvB.
filter := SatPerKVByte(peerResp.FeeFilter)
if !isWithinBounds(filter) {
continue
}
outboundPeerFilters = append(
outboundPeerFilters, filter.FeePerKWeight(),
)
}
// Check that we have enough data to use. We don't return an error as
// that would stop the querying goroutine.
if len(outboundPeerFilters) < minNumFilters {
return nil, nil
}
return outboundPeerFilters, nil
}
// updateMedian takes a slice of feefilter values, computes the median, and
// updates our stored median value.
func (f *filterManager) updateMedian(feeFilters []SatPerKWeight) {
// If there are no elements, don't update.
numElements := len(feeFilters)
if numElements == 0 {
return
}
f.medianMtx.Lock()
defer f.medianMtx.Unlock()
// Log the new median.
median := med(feeFilters)
f.median = fn.Some(median)
log.Debugf("filterManager updated moving median to: %v",
median.FeePerKVByte())
}
// isWithinBounds returns false if the filter is unusable and true if it is.
func isWithinBounds(filter SatPerKVByte) bool {
// Ignore values of 0 and MaxSatoshi. A value of 0 likely means that
// the peer hasn't sent over a feefilter and a value of MaxSatoshi
// means the peer is using bitcoind and is in IBD.
switch filter {
case 0:
return false
case btcutil.MaxSatoshi:
return false
}
return true
}
// med calculates the median of a slice.
// NOTE: Passing in an empty slice will panic!
func med(f []SatPerKWeight) SatPerKWeight {
// Copy the original slice so that sorting doesn't modify the original.
fCopy := make([]SatPerKWeight, len(f))
copy(fCopy, f)
sort.Slice(fCopy, func(i, j int) bool {
return fCopy[i] < fCopy[j]
})
var median SatPerKWeight
numElements := len(fCopy)
switch numElements % 2 {
case 0:
// There's an even number of elements, so we need to average.
middle := numElements / 2
upper := fCopy[middle]
lower := fCopy[middle-1]
median = (upper + lower) / 2
case 1:
median = fCopy[numElements/2]
}
return median
}
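// Worked examples for med (illustrative): an odd-length slice such as
// [250, 1000, 3000] yields the middle element, 1000, while an even-length
// slice such as [250, 500, 1000, 3000] averages the two middle elements,
// giving (500 + 1000) / 2 = 750.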
package chainfee
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
// log is a logger that is initialized with no output filters. This means the
// package will not perform any logging by default until the caller requests
// it.
var log btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("CFEE", nil))
}
// DisableLog disables all library log output. Logging output is disabled by
// default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info. This
// should be used in preference to SetLogWriter if the caller is also using
// btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
package chainfee
import (
"sync"
"time"
)
const defaultUpdateInterval = 10 * time.Minute
// minFeeManager is used to store and update the minimum fee that is required
// by a transaction to be accepted to the mempool. The minFeeManager ensures
// that the backend used to fetch the fee is not queried too regularly.
type minFeeManager struct {
mu sync.Mutex
minFeePerKW SatPerKWeight
lastUpdatedTime time.Time
minUpdateInterval time.Duration
fetchFeeFunc fetchFee
}
// fetchFee represents a function that can be used to fetch a fee.
type fetchFee func() (SatPerKWeight, error)
// newMinFeeManager creates a new minFeeManager and uses the
// given fetchMinFee function to set the minFeePerKW of the minFeeManager.
// This function requires the fetchMinFee function to succeed.
func newMinFeeManager(minUpdateInterval time.Duration,
fetchMinFee fetchFee) (*minFeeManager, error) {
minFee, err := fetchMinFee()
if err != nil {
return nil, err
}
// Ensure that the minimum fee we use is always clamped by our fee
// floor.
if minFee < FeePerKwFloor {
minFee = FeePerKwFloor
}
return &minFeeManager{
minFeePerKW: minFee,
lastUpdatedTime: time.Now(),
minUpdateInterval: minUpdateInterval,
fetchFeeFunc: fetchMinFee,
}, nil
}
// fetchMinFee returns the stored minFeePerKW if it has been updated recently
// or if the call to the chain backend fails. Otherwise, it sets the stored
// minFeePerKW to the fee returned from the backend and floors it based on
// our fee floor.
func (m *minFeeManager) fetchMinFee() SatPerKWeight {
m.mu.Lock()
defer m.mu.Unlock()
if time.Since(m.lastUpdatedTime) < m.minUpdateInterval {
return m.minFeePerKW
}
newMinFee, err := m.fetchFeeFunc()
if err != nil {
log.Errorf("Unable to fetch updated min fee from chain "+
"backend. Using last known min fee instead: %v", err)
return m.minFeePerKW
}
// By default, we'll use the backend node's minimum fee as the
// minimum fee rate we'll propose for transactions. However, if this
// happens to be lower than our fee floor, we'll enforce that instead.
m.minFeePerKW = newMinFee
if m.minFeePerKW < FeePerKwFloor {
m.minFeePerKW = FeePerKwFloor
}
m.lastUpdatedTime = time.Now()
log.Debugf("Using minimum fee rate of %v sat/kw",
int64(m.minFeePerKW))
return m.minFeePerKW
}
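// A minimal usage sketch (illustrative; the fetch function is a stub, while
// in production it queries the chain backend). The constructor performs an
// initial fetch, so it fails if the backend is unreachable; subsequent calls
// within minUpdateInterval are served from the cached value.
//
//	m, err := newMinFeeManager(
//		defaultUpdateInterval,
//		func() (SatPerKWeight, error) {
//			// Stub: would query the chain backend here.
//			return SatPerKWeight(300), nil
//		},
//	)
//	if err != nil {
//		// The initial fetch failed; the manager is unusable.
//		return err
//	}
//	minFee := m.fetchMinFee() // 300 sat/kw, above FeePerKwFloor (253).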
package chainfee
import (
"github.com/stretchr/testify/mock"
)
type mockFeeSource struct {
mock.Mock
}
// A compile-time assertion to ensure that mockFeeSource implements the
// WebAPIFeeSource interface.
var _ WebAPIFeeSource = (*mockFeeSource)(nil)
func (m *mockFeeSource) GetFeeInfo() (WebAPIResponse, error) {
args := m.Called()
return args.Get(0).(WebAPIResponse), args.Error(1)
}
// MockEstimator implements the `Estimator` interface and is used by
// other packages for mock testing.
type MockEstimator struct {
mock.Mock
}
// Compile time assertion that MockEstimator implements Estimator.
var _ Estimator = (*MockEstimator)(nil)
// EstimateFeePerKW takes in a target for the number of blocks until an initial
// confirmation and returns the estimated fee expressed in sat/kw.
func (m *MockEstimator) EstimateFeePerKW(
numBlocks uint32) (SatPerKWeight, error) {
args := m.Called(numBlocks)
if args.Get(0) == nil {
return 0, args.Error(1)
}
return args.Get(0).(SatPerKWeight), args.Error(1)
}
// Start signals the Estimator to start any processes or goroutines it needs to
// perform its duty.
func (m *MockEstimator) Start() error {
args := m.Called()
return args.Error(0)
}
// Stop stops any spawned goroutines and cleans up the resources used by the
// fee estimator.
func (m *MockEstimator) Stop() error {
args := m.Called()
return args.Error(0)
}
// RelayFeePerKW returns the minimum fee rate required for transactions to be
// relayed. This is also the basis for calculation of the dust limit.
func (m *MockEstimator) RelayFeePerKW() SatPerKWeight {
args := m.Called()
return args.Get(0).(SatPerKWeight)
}
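// A minimal sketch of driving MockEstimator in a test (illustrative; t is
// the surrounding *testing.T and the values are arbitrary):
//
//	est := &MockEstimator{}
//	est.On("EstimateFeePerKW", uint32(6)).Return(SatPerKWeight(12500), nil)
//
//	feeRate, err := est.EstimateFeePerKW(6)
//	// feeRate == 12500 sat/kw, err == nil.
//
//	est.AssertExpectations(t)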
package chainfee
import (
"fmt"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/lntypes"
)
const (
// FeePerKwFloor is the lowest fee rate in sat/kw that we should use for
// estimating transaction fees before signing.
FeePerKwFloor SatPerKWeight = 253
// AbsoluteFeePerKwFloor is the lowest fee rate in sat/kw of a
// transaction that we should ever _create_. This is the equivalent
// of 1 sat/byte in sat/kw.
AbsoluteFeePerKwFloor SatPerKWeight = 250
)
// SatPerVByte represents a fee rate in sat/vbyte.
type SatPerVByte btcutil.Amount
// FeePerKWeight converts the current fee rate from sat/vb to sat/kw.
func (s SatPerVByte) FeePerKWeight() SatPerKWeight {
return SatPerKWeight(s * 1000 / blockchain.WitnessScaleFactor)
}
// FeePerKVByte converts the current fee rate from sat/vb to sat/kvb.
func (s SatPerVByte) FeePerKVByte() SatPerKVByte {
return SatPerKVByte(s * 1000)
}
// String returns a human-readable string of the fee rate.
func (s SatPerVByte) String() string {
return fmt.Sprintf("%v sat/vb", int64(s))
}
// SatPerKVByte represents a fee rate in sat/kvb, i.e. satoshis per 1000
// virtual bytes.
type SatPerKVByte btcutil.Amount
// FeeForVSize calculates the fee resulting from this fee rate and the given
// vsize in vbytes.
func (s SatPerKVByte) FeeForVSize(vbytes lntypes.VByte) btcutil.Amount {
return btcutil.Amount(s) * btcutil.Amount(vbytes) / 1000
}
// FeePerKWeight converts the current fee rate from sat/kvb to sat/kw.
func (s SatPerKVByte) FeePerKWeight() SatPerKWeight {
return SatPerKWeight(s / blockchain.WitnessScaleFactor)
}
// String returns a human-readable string of the fee rate.
func (s SatPerKVByte) String() string {
return fmt.Sprintf("%v sat/kvb", int64(s))
}
// SatPerKWeight represents a fee rate in sat/kw.
type SatPerKWeight btcutil.Amount
// NewSatPerKWeight creates a new fee rate in sat/kw.
func NewSatPerKWeight(fee btcutil.Amount, wu lntypes.WeightUnit) SatPerKWeight {
return SatPerKWeight(fee.MulF64(1000 / float64(wu)))
}
// FeeForWeight calculates the fee resulting from this fee rate and the given
// weight in weight units (wu).
func (s SatPerKWeight) FeeForWeight(wu lntypes.WeightUnit) btcutil.Amount {
// The resulting fee is rounded down, as specified in BOLT#03.
return btcutil.Amount(s) * btcutil.Amount(wu) / 1000
}
// FeeForVByte calculates the fee resulting from this fee rate and the given
// size in vbytes (vb).
func (s SatPerKWeight) FeeForVByte(vb lntypes.VByte) btcutil.Amount {
return s.FeePerKVByte().FeeForVSize(vb)
}
// FeePerKVByte converts the current fee rate from sat/kw to sat/kvb.
func (s SatPerKWeight) FeePerKVByte() SatPerKVByte {
return SatPerKVByte(s * blockchain.WitnessScaleFactor)
}
// FeePerVByte converts the current fee rate from sat/kw to sat/vb.
func (s SatPerKWeight) FeePerVByte() SatPerVByte {
return SatPerVByte(s * blockchain.WitnessScaleFactor / 1000)
}
// String returns a human-readable string of the fee rate.
func (s SatPerKWeight) String() string {
return fmt.Sprintf("%v sat/kw", int64(s))
}
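// Worked conversions between the rate types above (illustrative; the witness
// scale factor is 4):
//
//	SatPerVByte(1).FeePerKWeight()  // 1 * 1000 / 4 = 250 sat/kw
//	SatPerVByte(1).FeePerKVByte()   // 1 * 1000 = 1000 sat/kvb
//	FeePerKwFloor.FeePerKVByte()    // 253 * 4 = 1012 sat/kvb
//	FeePerKwFloor.FeeForWeight(724) // 253 * 724 / 1000 = 183 sats,
//	                                // rounded down per BOLT#03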
package chanfunding
import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
)
// NewShimIntent creates a new ShimIntent. This is only used for testing.
func NewShimIntent(localAmt, remoteAmt btcutil.Amount,
localKey *keychain.KeyDescriptor, remoteKey *btcec.PublicKey,
chanPoint *wire.OutPoint, thawHeight uint32, musig2 bool) *ShimIntent {
return &ShimIntent{
localFundingAmt: localAmt,
remoteFundingAmt: remoteAmt,
localKey: localKey,
remoteKey: remoteKey,
chanPoint: chanPoint,
thawHeight: thawHeight,
musig2: musig2,
}
}
// ShimIntent is an intent created by the CannedAssembler which represents a
// funding output to be created that was constructed outside the wallet. This
// might be used when a hardware wallet, or a channel factory is the entity
// crafting the funding transaction, and not lnd.
type ShimIntent struct {
// localFundingAmt is the final amount we put into the funding output.
localFundingAmt btcutil.Amount
// remoteFundingAmt is the final amount the remote party put into the
// funding output.
remoteFundingAmt btcutil.Amount
// localKey is our multi-sig key.
localKey *keychain.KeyDescriptor
// remoteKey is the remote party's multi-sig key.
remoteKey *btcec.PublicKey
// chanPoint is the final channel point for the to be created channel.
chanPoint *wire.OutPoint
// thawHeight, if non-zero is the height where this channel will become
// a normal channel. Until this height, it's considered frozen, so it
// can only be cooperatively closed by the responding party.
thawHeight uint32
// musig2 determines if the funding output should use musig2 to
// generate an aggregate key to use as the taproot-native multi-sig
// output.
musig2 bool
// tapscriptRoot is the root of the tapscript tree that will be used to
// create the funding output. This field will only be utilized if the
// MuSig2 flag above is set to true.
//
// TODO(roasbeef): fold above into new chan type? sum type like thing,
// includes the tapscript root, etc
tapscriptRoot fn.Option[chainhash.Hash]
}
// FundingOutput returns the witness script, and the output that creates the
// funding output.
//
// NOTE: This method satisfies the chanfunding.Intent interface.
func (s *ShimIntent) FundingOutput() ([]byte, *wire.TxOut, error) {
if s.localKey == nil || s.remoteKey == nil {
return nil, nil, fmt.Errorf("unable to create witness " +
"script, no funding keys")
}
totalAmt := s.localFundingAmt + s.remoteFundingAmt
// If musig2 is active, then we'll return a single aggregated key
// rather than using the "existing" funding script.
if s.musig2 {
// Similar to the existing p2wsh script, we'll always ensure
// the keys are sorted before use.
return input.GenTaprootFundingScript(
s.localKey.PubKey, s.remoteKey, int64(totalAmt),
s.tapscriptRoot,
)
}
return input.GenFundingPkScript(
s.localKey.PubKey.SerializeCompressed(),
s.remoteKey.SerializeCompressed(),
int64(totalAmt),
)
}
// TaprootInternalKey may return the internal key for a MuSig2 funding output,
// but only if this is actually a MuSig2 channel.
func (s *ShimIntent) TaprootInternalKey() fn.Option[*btcec.PublicKey] {
if !s.musig2 {
return fn.None[*btcec.PublicKey]()
}
// Similar to the existing p2wsh script, we'll always ensure the keys
// are sorted before use. Since we're only interested in the internal
// key, we don't need to take into account any tapscript root.
//
// We ignore the error here as this is only called after FundingOutput
// is called.
combinedKey, _, _, _ := musig2.AggregateKeys(
[]*btcec.PublicKey{s.localKey.PubKey, s.remoteKey}, true,
)
return fn.Some(combinedKey.PreTweakedKey)
}
// Cancel allows the caller to cancel a funding Intent at any time. This will
// return any resources such as coins back to the eligible pool to be used in
// other channel fundings.
//
// NOTE: This method satisfies the chanfunding.Intent interface.
func (s *ShimIntent) Cancel() {
}
// LocalFundingAmt is the amount we put into the channel. This may differ from
// the local amount requested, as depending on coin selection, we may bleed
// some of that LocalAmt into fees to minimize change.
//
// NOTE: This method satisfies the chanfunding.Intent interface.
func (s *ShimIntent) LocalFundingAmt() btcutil.Amount {
return s.localFundingAmt
}
// RemoteFundingAmt is the amount the remote party put into the channel.
//
// NOTE: This method satisfies the chanfunding.Intent interface.
func (s *ShimIntent) RemoteFundingAmt() btcutil.Amount {
return s.remoteFundingAmt
}
// ChanPoint returns the final outpoint that will create the funding output
// described above.
//
// NOTE: This method satisfies the chanfunding.Intent interface.
func (s *ShimIntent) ChanPoint() (*wire.OutPoint, error) {
if s.chanPoint == nil {
return nil, fmt.Errorf("chan point unknown, funding output " +
"not constructed")
}
return s.chanPoint, nil
}
// ThawHeight returns the height where this channel goes back to being a normal
// channel.
func (s *ShimIntent) ThawHeight() uint32 {
return s.thawHeight
}
// Inputs returns all inputs to the final funding transaction that we
// know about. For the ShimIntent this will always be none, since it is funded
// externally.
func (s *ShimIntent) Inputs() []wire.OutPoint {
return nil
}
// Outputs returns all outputs of the final funding transaction that we
// know about. Since this is an externally funded channel, the channel output
// is the only known one.
func (s *ShimIntent) Outputs() []*wire.TxOut {
_, txOut, err := s.FundingOutput()
if err != nil {
log.Warnf("Unable to find funding output for shim intent: %v",
err)
// Failed finding funding output, return empty list of known
// outputs.
return nil
}
return []*wire.TxOut{txOut}
}
// FundingKeys couples our multi-sig key along with the remote party's key.
type FundingKeys struct {
// LocalKey is our multi-sig key.
LocalKey *keychain.KeyDescriptor
// RemoteKey is the multi-sig key of the remote party.
RemoteKey *btcec.PublicKey
}
// MultiSigKeys returns the committed multi-sig keys, but only if they've been
// specified/provided.
func (s *ShimIntent) MultiSigKeys() (*FundingKeys, error) {
if s.localKey == nil || s.remoteKey == nil {
return nil, fmt.Errorf("unknown funding keys")
}
return &FundingKeys{
LocalKey: s.localKey,
RemoteKey: s.remoteKey,
}, nil
}
// A compile-time check to ensure ShimIntent adheres to the Intent interface.
var _ Intent = (*ShimIntent)(nil)
// CannedAssembler is a type of chanfunding.Assembler wherein the funding
// transaction is constructed outside of lnd, and may already exist. This
// Assembler serves as a shim which gives the funding flow the only thing it
// actually needs to proceed: the channel point.
type CannedAssembler struct {
// fundingAmt is the total amount of coins in the funding output.
fundingAmt btcutil.Amount
// localKey is our multi-sig key.
localKey *keychain.KeyDescriptor
// remoteKey is the remote party's multi-sig key.
remoteKey *btcec.PublicKey
// chanPoint is the final channel point for the to be created channel.
chanPoint wire.OutPoint
	// initiator indicates if we're the initiator of the channel or not.
initiator bool
// thawHeight, if non-zero is the height where this channel will become
// a normal channel. Until this height, it's considered frozen, so it
// can only be cooperatively closed by the responding party.
thawHeight uint32
// musig2 determines if the funding output should use musig2 to
// generate an aggregate key to use as the taproot-native multi-sig
// output.
musig2 bool
}
// NewCannedAssembler creates a new CannedAssembler from the material required
// to construct a funding output and channel point.
//
// TODO(roasbeef): pass in chan type instead?
func NewCannedAssembler(thawHeight uint32, chanPoint wire.OutPoint,
fundingAmt btcutil.Amount, localKey *keychain.KeyDescriptor,
remoteKey *btcec.PublicKey, initiator, musig2 bool) *CannedAssembler {
return &CannedAssembler{
initiator: initiator,
localKey: localKey,
remoteKey: remoteKey,
fundingAmt: fundingAmt,
chanPoint: chanPoint,
thawHeight: thawHeight,
musig2: musig2,
}
}
// ProvisionChannel creates a new ShimIntent given the passed funding Request.
// The returned intent is immediately able to provide the channel point and
// funding output as they've already been created outside lnd.
//
// NOTE: This method satisfies the chanfunding.Assembler interface.
func (c *CannedAssembler) ProvisionChannel(req *Request) (Intent, error) {
// We'll exit out if SubtractFees is set as the funding transaction has
// already been assembled, so we don't influence coin selection.
if req.SubtractFees {
return nil, fmt.Errorf("SubtractFees ignored, funding " +
"transaction is frozen")
}
// We'll exit out if FundUpToMaxAmt or MinFundAmt is set as the funding
// transaction has already been assembled, so we don't influence coin
// selection.
if req.FundUpToMaxAmt != 0 || req.MinFundAmt != 0 {
return nil, fmt.Errorf("FundUpToMaxAmt and MinFundAmt " +
"ignored, funding transaction is frozen")
}
intent := &ShimIntent{
localKey: c.localKey,
remoteKey: c.remoteKey,
chanPoint: &c.chanPoint,
thawHeight: c.thawHeight,
musig2: c.musig2,
}
if c.initiator {
intent.localFundingAmt = c.fundingAmt
} else {
intent.remoteFundingAmt = c.fundingAmt
}
// A simple sanity check to ensure the provisioned request matches the
// re-made shim intent.
if req.LocalAmt+req.RemoteAmt != c.fundingAmt {
return nil, fmt.Errorf("intent doesn't match canned "+
"assembler: local_amt=%v, remote_amt=%v, funding_amt=%v",
req.LocalAmt, req.RemoteAmt, c.fundingAmt)
}
return intent, nil
}
// A compile-time assertion to ensure CannedAssembler meets the Assembler
// interface.
var _ Assembler = (*CannedAssembler)(nil)
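// A minimal usage sketch (illustrative; chanPoint, localKeyDesc and
// remoteKey are placeholders for values negotiated elsewhere in the funding
// flow, and the amounts are arbitrary):
//
//	assembler := NewCannedAssembler(
//		0, chanPoint, btcutil.Amount(100_000), localKeyDesc,
//		remoteKey, true, false,
//	)
//	intent, err := assembler.ProvisionChannel(&Request{
//		LocalAmt: 100_000,
//	})
//	// As the initiator, intent.LocalFundingAmt() == 100_000 and the
//	// channel point is immediately available via intent.ChanPoint().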
package chanfunding
import (
"errors"
"fmt"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcwallet/wallet"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
// ErrInsufficientFunds is a type matching the error interface which is
// returned when coin selection for a new funding transaction fails due to
// having an insufficient amount of confirmed funds.
type ErrInsufficientFunds struct {
	amountNeeded    btcutil.Amount
	amountAvailable btcutil.Amount
}
// Error returns a human-readable string describing the error.
func (e *ErrInsufficientFunds) Error() string {
return fmt.Sprintf("not enough witness outputs to create funding "+
"transaction, need %v only have %v available",
e.amountAvailable, e.amountSelected)
}
// errUnsupportedInput is a type matching the error interface, which is returned
// when trying to calculate the fee of a transaction that references an
// unsupported script in the outpoint of a transaction input.
type errUnsupportedInput struct {
PkScript []byte
}
// Error returns a human-readable string describing the error.
func (e *errUnsupportedInput) Error() string {
return fmt.Sprintf("unsupported address type: %x", e.PkScript)
}
// ChangeAddressType is an enum-like type that describes the type of change
// address that should be generated for a transaction.
type ChangeAddressType uint8
const (
// P2WKHChangeAddress indicates that the change output should be a
// P2WKH output.
P2WKHChangeAddress ChangeAddressType = 0
// P2TRChangeAddress indicates that the change output should be a
// P2TR output.
P2TRChangeAddress ChangeAddressType = 1
// ExistingChangeAddress indicates that the coin selection algorithm
// should assume an existing output will be used for any change, meaning
// that the change amount calculated will be added to an existing output
// and no weight for a new change output should be assumed. The caller
// must assert that the output value of the selected existing output
// already is above dust when using this change address type.
ExistingChangeAddress ChangeAddressType = 2
	// DefaultMaxFeeRatio is the default maximum ratio of fee to total
	// output value that is used to sanity check the fees of a
	// transaction.
DefaultMaxFeeRatio float64 = 0.2
)
// selectInputs selects a slice of inputs necessary to meet the specified
// selection amount. If input selection is unable to succeed due to
// insufficient funds, a non-nil error is returned. Additionally, the total
// amount of the selected coins is returned so that the caller can properly
// handle change and fees.
func selectInputs(amt btcutil.Amount, coins []wallet.Coin,
strategy wallet.CoinSelectionStrategy,
feeRate chainfee.SatPerKWeight) (btcutil.Amount, []wallet.Coin, error) {
// All coin selection code in the btcwallet library requires sat/KB.
feeSatPerKB := btcutil.Amount(feeRate.FeePerKVByte())
arrangedCoins, err := strategy.ArrangeCoins(coins, feeSatPerKB)
if err != nil {
return 0, nil, err
}
satSelected := btcutil.Amount(0)
for i, coin := range arrangedCoins {
satSelected += btcutil.Amount(coin.Value)
if satSelected >= amt {
return satSelected, arrangedCoins[:i+1], nil
}
}
return 0, nil, &ErrInsufficientFunds{amt, satSelected}
}
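// Worked example (illustrative): given coins worth 50_000 and 30_000 sats,
// arranged in that order by the strategy, a request for amt = 60_000 selects
// both coins and returns satSelected = 80_000, leaving the caller 20_000
// sats to split between fees and change. A request for amt = 90_000 exhausts
// the list and returns ErrInsufficientFunds.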
// calculateFees returns two fee estimates for the specified utxos and fee
// rate: one calculated using a change output and one without. The weight
// added to the estimator for the change output depends on the given change
// address type.
func calculateFees(utxos []wallet.Coin, feeRate chainfee.SatPerKWeight,
existingWeight input.TxWeightEstimator,
changeType ChangeAddressType) (btcutil.Amount, btcutil.Amount, error) {
weightEstimate := existingWeight
for _, utxo := range utxos {
switch {
case txscript.IsPayToWitnessPubKeyHash(utxo.PkScript):
weightEstimate.AddP2WKHInput()
case txscript.IsPayToScriptHash(utxo.PkScript):
weightEstimate.AddNestedP2WKHInput()
case txscript.IsPayToTaproot(utxo.PkScript):
weightEstimate.AddTaprootKeySpendInput(
txscript.SigHashDefault,
)
default:
return 0, 0, &errUnsupportedInput{utxo.PkScript}
}
}
// Estimate the fee required for a transaction without a change
// output.
totalWeight := weightEstimate.Weight()
requiredFeeNoChange := feeRate.FeeForWeight(totalWeight)
// Estimate the fee required for a transaction with a change output.
switch changeType {
case P2WKHChangeAddress:
weightEstimate.AddP2WKHOutput()
case P2TRChangeAddress:
weightEstimate.AddP2TROutput()
case ExistingChangeAddress:
// Don't add an extra output.
default:
return 0, 0, fmt.Errorf("unknown change address type: %v",
changeType)
}
// Now that we have added the change output, redo the fee
// estimate.
totalWeight = weightEstimate.Weight()
requiredFeeWithChange := feeRate.FeeForWeight(totalWeight)
return requiredFeeNoChange, requiredFeeWithChange, nil
}
// sanityCheckFee checks that the specified fee does not exceed the amount
// allowed by the provided ratio of the total output value.
func sanityCheckFee(totalOut, fee btcutil.Amount, maxFeeRatio float64) error {
// Sanity check the maxFeeRatio itself.
if maxFeeRatio <= 0.00 || maxFeeRatio > 1.00 {
return fmt.Errorf("maxFeeRatio must be between 0.00 and 1.00 "+
"got %.2f", maxFeeRatio)
}
maxFee := btcutil.Amount(float64(totalOut) * maxFeeRatio)
// Check that the fees do not exceed the max allowed value.
if fee > maxFee {
return fmt.Errorf("fee %v exceeds max fee (%v) on total "+
"output value %v with max fee ratio of %.2f", fee,
maxFee, totalOut, maxFeeRatio)
}
// All checks passed, we return nil to signal that the fees are valid.
return nil
}
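// Worked example (illustrative): with totalOut = 100_000 sats and
// maxFeeRatio = DefaultMaxFeeRatio (0.2), the maximum tolerated fee is
// 20_000 sats. A fee of 25_000 sats therefore fails the check, while a fee
// of 15_000 sats passes.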
// CoinSelect attempts to select a sufficient amount of coins, including a
// change output to fund amt satoshis, adhering to the specified fee rate. The
// specified fee rate should be expressed in sat/kw for coin selection to
// function properly.
func CoinSelect(feeRate chainfee.SatPerKWeight, amt, dustLimit btcutil.Amount,
coins []wallet.Coin, strategy wallet.CoinSelectionStrategy,
existingWeight input.TxWeightEstimator,
changeType ChangeAddressType, maxFeeRatio float64) ([]wallet.Coin,
btcutil.Amount, error) {
amtNeeded := amt
for {
// First perform an initial round of coin selection to estimate
// the required fee.
totalSat, selectedUtxos, err := selectInputs(
amtNeeded, coins, strategy, feeRate,
)
if err != nil {
return nil, 0, err
}
// Obtain fee estimates both with and without using a change
// output.
requiredFeeNoChange, requiredFeeWithChange, err := calculateFees(
selectedUtxos, feeRate, existingWeight, changeType,
)
if err != nil {
return nil, 0, err
}
changeAmount, newAmtNeeded, err := CalculateChangeAmount(
totalSat, amt, requiredFeeNoChange,
requiredFeeWithChange, dustLimit, changeType,
maxFeeRatio,
)
if err != nil {
return nil, 0, err
}
// Need another round, the selected coins aren't enough to pay
// for the fees.
if newAmtNeeded != 0 {
amtNeeded = newAmtNeeded
continue
}
// Coin selection was successful.
return selectedUtxos, changeAmount, nil
}
}
// CalculateChangeAmount calculates the change amount being left over when the
// given total amount of sats is provided as inputs for the required output
// amount. The calculation takes into account that we might not want to add a
// change output if the change amount is below the dust limit. The first amount
// returned is the change amount. If that is non-zero, change is left over and
// should be dealt with. The second amount, if non-zero, indicates that the
// total input amount was just not enough to pay for the required amount and
// fees and that more coins need to be selected.
func CalculateChangeAmount(totalInputAmt, requiredAmt, requiredFeeNoChange,
requiredFeeWithChange, dustLimit btcutil.Amount,
changeType ChangeAddressType, maxFeeRatio float64) (btcutil.Amount,
btcutil.Amount, error) {
// This is just a sanity check to make sure the function is used
// correctly.
if changeType == ExistingChangeAddress &&
requiredFeeNoChange != requiredFeeWithChange {
return 0, 0, fmt.Errorf("when using existing change address, " +
"the fees for with or without change must be the same")
}
// The difference between the selected amount and the amount
// requested will be used to pay fees, and generate a change
// output with the remaining.
overShootAmt := totalInputAmt - requiredAmt
var changeAmt btcutil.Amount
switch {
// If the excess amount isn't enough to pay for fees based on
// fee rate and estimated size without using a change output,
// then increase the requested coin amount by the estimate
// required fee without using change, performing another round
// of coin selection.
case overShootAmt < requiredFeeNoChange:
return 0, requiredAmt + requiredFeeNoChange, nil
// If sufficient funds were selected to cover the fee required
// to include a change output, the remainder will be our change
// amount.
case overShootAmt > requiredFeeWithChange:
changeAmt = overShootAmt - requiredFeeWithChange
// Otherwise we have selected enough to pay for a tx without a
// change output.
default:
changeAmt = 0
}
	// In case we would end up with a dust output if we created a change
	// output, we instead just let the dust amount go to fees. The
	// exception is when the change goes to an existing output, in which
	// case we can increase that output's value by any amount.
if changeAmt < dustLimit && changeType != ExistingChangeAddress {
changeAmt = 0
}
// Sanity check the resulting output values to make sure we
// don't burn a great part to fees.
totalOut := requiredAmt + changeAmt
err := sanityCheckFee(totalOut, totalInputAmt-totalOut, maxFeeRatio)
if err != nil {
return 0, 0, err
}
return changeAmt, 0, nil
}
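// Worked example (illustrative): say requiredAmt = 100_000,
// requiredFeeNoChange = 2_000, requiredFeeWithChange = 3_000 and
// dustLimit = 546.
//
//   - totalInputAmt = 101_000: the overshoot of 1_000 is below 2_000, so the
//     second return value asks for another round targeting 102_000 sats.
//   - totalInputAmt = 110_000: the overshoot of 10_000 exceeds 3_000, so the
//     change amount is 10_000 - 3_000 = 7_000 sats, well above dust.
//   - totalInputAmt = 102_500: the overshoot of 2_500 covers the fee without
//     change but not with one, so the change amount is 0 and the entire
//     2_500 sat overshoot is paid as fees.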
// CoinSelectSubtractFees attempts to select coins such that we'll spend up to
// amt in total after fees, adhering to the specified fee rate. The selected
// coins, the final output and change values are returned.
func CoinSelectSubtractFees(feeRate chainfee.SatPerKWeight, amt,
dustLimit btcutil.Amount, coins []wallet.Coin,
strategy wallet.CoinSelectionStrategy,
existingWeight input.TxWeightEstimator, changeType ChangeAddressType,
maxFeeRatio float64) ([]wallet.Coin, btcutil.Amount, btcutil.Amount,
error) {
// First perform an initial round of coin selection to estimate
// the required fee.
totalSat, selectedUtxos, err := selectInputs(
amt, coins, strategy, feeRate,
)
if err != nil {
return nil, 0, 0, err
}
// Obtain fee estimates both with and without using a change
// output.
requiredFeeNoChange, requiredFeeWithChange, err := calculateFees(
selectedUtxos, feeRate, existingWeight, changeType,
)
if err != nil {
return nil, 0, 0, err
}
// For a transaction without a change output, we'll let everything go
// to our multi-sig output after subtracting fees.
outputAmt := totalSat - requiredFeeNoChange
changeAmt := btcutil.Amount(0)
// If the output is too small after subtracting the fee, the coin
// selection cannot be performed with an amount this small.
if outputAmt < dustLimit {
return nil, 0, 0, fmt.Errorf("output amount(%v) after "+
"subtracting fees(%v) below dust limit(%v)", outputAmt,
requiredFeeNoChange, dustLimit)
}
// For a transaction with a change output, everything we don't spend
// will go to change.
newOutput := amt - requiredFeeWithChange
newChange := totalSat - amt
// If adding a change output leads to both outputs being above
// the dust limit, we'll add the change output. Otherwise we'll
// go with the no change tx we originally found.
if newChange >= dustLimit && newOutput >= dustLimit {
outputAmt = newOutput
changeAmt = newChange
}
// Sanity check the resulting output values to make sure we
// don't burn a great part to fees.
totalOut := outputAmt + changeAmt
err = sanityCheckFee(totalOut, totalSat-totalOut, maxFeeRatio)
if err != nil {
return nil, 0, 0, err
}
return selectedUtxos, outputAmt, changeAmt, nil
}
// CoinSelectUpToAmount attempts to select coins such that we'll select up to
// maxAmount exclusive of fees and optional reserve if sufficient funds are
// available. If insufficient funds are available this method selects all
// available coins.
func CoinSelectUpToAmount(feeRate chainfee.SatPerKWeight, minAmount, maxAmount,
reserved, dustLimit btcutil.Amount, coins []wallet.Coin,
strategy wallet.CoinSelectionStrategy,
existingWeight input.TxWeightEstimator,
changeType ChangeAddressType, maxFeeRatio float64) ([]wallet.Coin,
btcutil.Amount, btcutil.Amount, error) {
var (
		// selectSubtractFee tracks whether our initial coin selection
		// was unsuccessful and we have to start a new round that
		// subtracts fees from the total balance.
selectSubtractFee = false
outputAmount = maxAmount
)
// Get total balance from coins which we need for reserve considerations
// and fee sanity checks.
var totalBalance btcutil.Amount
for _, coin := range coins {
totalBalance += btcutil.Amount(coin.Value)
}
// First we try to select coins to create an output of the specified
// maxAmount with or without a change output that covers the miner fee.
selected, changeAmt, err := CoinSelect(
feeRate, maxAmount, dustLimit, coins, strategy, existingWeight,
changeType, maxFeeRatio,
)
var errInsufficientFunds *ErrInsufficientFunds
switch {
case err == nil:
// If the coin selection succeeds we check if our total balance
// covers the selected set of coins including fees plus an
// optional anchor reserve.
// First we sum up the value of all selected coins.
var sumSelected btcutil.Amount
for _, coin := range selected {
sumSelected += btcutil.Amount(coin.Value)
}
// We then subtract the change amount from the value of all
// selected coins to obtain the actual amount that is selected.
sumSelected -= changeAmt
// Next we check if our total balance can cover for the selected
// output plus the optional anchor reserve.
if totalBalance-sumSelected < reserved {
// If our local balance is insufficient to cover for the
// reserve we try to select an output amount that uses
// our total balance minus reserve and fees.
selectSubtractFee = true
}
case errors.As(err, &errInsufficientFunds):
// If the initial coin selection fails due to insufficient funds
// we select our total available balance minus fees.
selectSubtractFee = true
default:
return nil, 0, 0, err
}
// If we determined that our local balance is insufficient we check
// our total balance minus fees and optional reserve.
if selectSubtractFee {
selected, outputAmount, changeAmt, err = CoinSelectSubtractFees(
feeRate, totalBalance-reserved, dustLimit, coins,
strategy, existingWeight, changeType, maxFeeRatio,
)
if err != nil {
return nil, 0, 0, err
}
}
// Sanity check the resulting output values to make sure we don't burn a
// great part to fees.
totalOut := outputAmount + changeAmt
sum := func(coins []wallet.Coin) btcutil.Amount {
var sum btcutil.Amount
for _, coin := range coins {
sum += btcutil.Amount(coin.Value)
}
return sum
}
err = sanityCheckFee(totalOut, sum(selected)-totalOut, maxFeeRatio)
if err != nil {
return nil, 0, 0, err
}
// In case the selected amount is lower than minimum funding amount we
// must return an error. The minimum funding amount is determined
// upstream and denotes either the minimum viable channel size or an
// amount sufficient to cover for the initial remote balance.
if outputAmount < minAmount {
return nil, 0, 0, fmt.Errorf("available funds(%v) below the "+
"minimum amount(%v)", outputAmount, minAmount)
}
return selected, outputAmount, changeAmt, nil
}
package chanfunding
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("CHFD", nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
package chanfunding
import (
"errors"
"fmt"
"sync"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/psbt"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
)
// PsbtState is a type for the state of the PSBT intent state machine.
type PsbtState uint8
const (
// PsbtShimRegistered denotes a channel funding process has started with
// a PSBT shim attached. This is the default state for a PsbtIntent. We
// don't use iota here because the values have to be in sync with the
// RPC constants.
PsbtShimRegistered PsbtState = 1
// PsbtOutputKnown denotes that the local and remote peer have
// negotiated the multisig keys to be used as the channel funding output
// and therefore the PSBT funding process can now start.
PsbtOutputKnown PsbtState = 2
// PsbtVerified denotes that a potential PSBT has been presented to the
// intent and passed all checks. The verified PSBT can be given to a/the
// signer(s).
PsbtVerified PsbtState = 3
// PsbtFinalized denotes that a fully signed PSBT has been given to the
// intent that looks identical to the previously verified transaction
// but has all witness data added and is therefore completely signed.
PsbtFinalized PsbtState = 4
// PsbtFundingTxCompiled denotes that the PSBT processed by this intent
// has been successfully converted into a protocol transaction. It is
// not yet completely certain that the resulting transaction will be
// published because the commitment transactions between the channel
// peers first need to be counter signed. But the job of the intent is
// hereby completed.
PsbtFundingTxCompiled PsbtState = 5
// PsbtInitiatorCanceled denotes that the user has canceled the intent.
PsbtInitiatorCanceled PsbtState = 6
// PsbtResponderCanceled denotes that the remote peer has canceled the
// funding, likely due to a timeout.
PsbtResponderCanceled PsbtState = 7
)
// String returns a string representation of the PsbtState.
func (s PsbtState) String() string {
switch s {
case PsbtShimRegistered:
return "shim_registered"
case PsbtOutputKnown:
return "output_known"
case PsbtVerified:
return "verified"
case PsbtFinalized:
return "finalized"
case PsbtFundingTxCompiled:
return "funding_tx_compiled"
case PsbtInitiatorCanceled:
return "user_canceled"
case PsbtResponderCanceled:
return "remote_canceled"
default:
return fmt.Sprintf("<unknown(%d)>", s)
}
}
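// For orientation, the happy-path transitions through the states above look
// as follows (sketch; the two canceled states are terminal and can be
// entered from any point in the flow):
//
//	PsbtShimRegistered --BindKeys--> PsbtOutputKnown
//	    --Verify--> PsbtVerified
//	    --Finalize/FinalizeRawTX--> PsbtFinalized
//	    --CompileFundingTx--> PsbtFundingTxCompiled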
var (
// ErrRemoteCanceled is the error that is returned to the user if the
// funding flow was canceled by the remote peer.
ErrRemoteCanceled = errors.New("remote canceled funding, possibly " +
"timed out")
// ErrUserCanceled is the error that is returned through the PsbtReady
// channel if the user canceled the funding flow.
ErrUserCanceled = errors.New("user canceled funding")
)
// PsbtIntent is an intent created by the PsbtAssembler which represents a
// funding output to be created by a PSBT. This might be used when a hardware
// wallet, or a channel factory is the entity crafting the funding transaction,
// and not lnd.
type PsbtIntent struct {
// ShimIntent is the wrapped basic intent that contains common fields
// we also use in the PSBT funding case.
ShimIntent
// State is the current state the intent state machine is in.
State PsbtState
// BasePsbt is the user-supplied base PSBT the channel output should be
// added to. If this is nil we will create a new, empty PSBT as the base
// for the funding transaction.
BasePsbt *psbt.Packet
// PendingPsbt is the parsed version of the current PSBT. This can be
// in two stages: If the user has not yet provided any PSBT, this is
// nil. Once the user sends us an unsigned funded PSBT, we verify that
// we have a valid transaction that sends to the channel output PK
// script and has an input large enough to pay for it. We keep this
// verified but not yet signed version around until the fully signed
// transaction is submitted by the user. At that point we make sure the
// inputs and outputs haven't changed to what was previously verified.
// Only witness data should be added after the verification process.
PendingPsbt *psbt.Packet
// FinalTX is the final, signed and ready to be published wire format
// transaction. This is only set after the PsbtFinalize step was
// completed successfully.
FinalTX *wire.MsgTx
	// PsbtReady is an error channel the funding manager will listen on
	// for a signal about the PSBT being ready to continue the funding
	// flow. In the normal, happy flow, this channel is only ever closed.
	// If a non-nil error is sent through the channel, the funding flow
	// will be canceled.
//
// NOTE: This channel must always be buffered.
PsbtReady chan error
// shouldPublish specifies if the intent assumes its assembler should
// publish the transaction once the channel funding has completed. If
// this is set to false then the finalize step can be skipped.
shouldPublish bool
// signalPsbtReady is a Once guard to make sure the PsbtReady channel is
// only closed exactly once.
signalPsbtReady sync.Once
// netParams are the network parameters used to encode the P2WSH funding
// address.
netParams *chaincfg.Params
}
// BindKeys sets both the remote and local node's keys that will be used for the
// channel funding multisig output.
func (i *PsbtIntent) BindKeys(localKey *keychain.KeyDescriptor,
remoteKey *btcec.PublicKey) {
i.localKey = localKey
i.remoteKey = remoteKey
i.State = PsbtOutputKnown
}
// BindTapscriptRoot takes an optional tapscript root and binds it to the
// underlying funding intent. This only applies to musig2 channels, and will be
// used to make the musig2 funding output.
func (i *PsbtIntent) BindTapscriptRoot(root fn.Option[chainhash.Hash]) {
i.tapscriptRoot = root
}
// FundingParams returns the parameters that are necessary to start funding the
// channel output this intent was created for. It returns the P2WSH funding
// address, the exact funding amount and a PSBT packet that contains exactly one
// output that encodes the previous two parameters.
func (i *PsbtIntent) FundingParams() (btcutil.Address, int64, *psbt.Packet,
error) {
if i.State != PsbtOutputKnown {
return nil, 0, nil, fmt.Errorf("invalid state, got %v "+
"expected %v", i.State, PsbtOutputKnown)
}
// The funding output needs to be known already at this point, which
// means we need to have the local and remote multisig keys bound
// already.
_, out, err := i.FundingOutput()
if err != nil {
return nil, 0, nil, fmt.Errorf("unable to create funding "+
"output: %v", err)
}
script, err := txscript.ParsePkScript(out.PkScript)
if err != nil {
return nil, 0, nil, fmt.Errorf("unable to parse funding "+
"output script: %w", err)
}
// Encode the address in the human-readable bech32 format.
addr, err := script.Address(i.netParams)
if err != nil {
return nil, 0, nil, fmt.Errorf("unable to encode address: %w",
err)
}
// We'll also encode the address/amount in a machine-readable raw PSBT
// format. If the user supplied a base PSBT, we'll add the output to
// that one, otherwise we'll create a new one.
packet := i.BasePsbt
if packet == nil {
packet, err = psbt.New(nil, nil, 2, 0, nil)
if err != nil {
return nil, 0, nil, fmt.Errorf("unable to create "+
"PSBT: %w", err)
}
}
packet.UnsignedTx.TxOut = append(packet.UnsignedTx.TxOut, out)
var pOut psbt.POutput
// If this is a MuSig2 channel, we also need to communicate the internal
// key to the caller. Otherwise, they cannot verify the construction of
// the P2TR output script.
pOut.TaprootInternalKey = fn.MapOptionZ(
i.TaprootInternalKey(), schnorr.SerializePubKey,
)
packet.Outputs = append(packet.Outputs, pOut)
return addr, out.Value, packet, nil
}
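// A minimal sketch of the initiator-side PSBT flow up to this point
// (illustrative; the intent is normally obtained from the funding manager,
// and localKeyDesc/remoteKey are placeholders for the negotiated keys):
//
//	intent.BindKeys(localKeyDesc, remoteKey)
//	addr, amt, packet, err := intent.FundingParams()
//	if err != nil {
//		return err
//	}
//	// Hand addr/amt (or the packet itself) to the external wallet, which
//	// funds the PSBT. The funded packet then goes through intent.Verify
//	// and, once fully signed, intent.Finalize.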
// Verify makes sure the PSBT that is given to the intent has an output that
// sends to the channel funding multisig address with the correct amount. A
// simple check that at least a single input has been specified is performed.
func (i *PsbtIntent) Verify(packet *psbt.Packet, skipFinalize bool) error {
if packet == nil {
return fmt.Errorf("PSBT is nil")
}
if i.State != PsbtOutputKnown {
return fmt.Errorf("invalid state. got %v expected %v", i.State,
PsbtOutputKnown)
}
// Try to locate the channel funding multisig output.
_, expectedOutput, err := i.FundingOutput()
if err != nil {
return fmt.Errorf("funding output cannot be created: %w", err)
}
outputFound := false
outputSum := int64(0)
for _, out := range packet.UnsignedTx.TxOut {
outputSum += out.Value
if psbt.TxOutsEqual(out, expectedOutput) {
outputFound = true
}
}
if !outputFound {
return fmt.Errorf("funding output not found in PSBT")
}
// At least one input needs to be specified and it must be large enough
// to pay for all outputs. We don't want to dive into fee estimation
// here so we just assume that if the input amount exceeds the output
// amount, the chosen fee is sufficient.
if len(packet.UnsignedTx.TxIn) == 0 {
return fmt.Errorf("PSBT has no inputs")
}
sum, err := psbt.SumUtxoInputValues(packet)
if err != nil {
return fmt.Errorf("error determining input sum: %w", err)
}
if sum <= outputSum {
return fmt.Errorf("input amount sum must be larger than " +
"output amount sum")
}
// To avoid possible malleability, all inputs to a funding transaction
// must be SegWit spends.
err = verifyAllInputsSegWit(packet.UnsignedTx.TxIn, packet.Inputs)
if err != nil {
return fmt.Errorf("cannot use TX for channel funding, "+
"not all inputs are SegWit spends, risk of "+
"malleability: %v", err)
}
// In case we aren't going to publish any transaction, we now have
// everything we need and can skip the Finalize step.
i.PendingPsbt = packet
if !i.shouldPublish && skipFinalize {
i.FinalTX = packet.UnsignedTx
i.State = PsbtFinalized
		// Signal the funding manager that it can now continue with its
		// funding flow as the PSBT is now complete.
i.signalPsbtReady.Do(func() {
close(i.PsbtReady)
})
return nil
}
i.State = PsbtVerified
return nil
}
// Finalize makes sure the final PSBT that is given to the intent is fully valid
// and signed but still contains the same UTXOs and outputs as the pending
// transaction we previously verified. If everything checks out, the funding
// manager is informed that the channel can now be opened and the funding
// transaction be broadcast.
func (i *PsbtIntent) Finalize(packet *psbt.Packet) error {
if packet == nil {
return fmt.Errorf("PSBT is nil")
}
if i.State != PsbtVerified {
return fmt.Errorf("invalid state. got %v expected %v", i.State,
PsbtVerified)
}
// Make sure the PSBT itself thinks it's finalized and ready to be
// broadcast.
err := psbt.MaybeFinalizeAll(packet)
if err != nil {
return fmt.Errorf("error finalizing PSBT: %w", err)
}
rawTx, err := psbt.Extract(packet)
if err != nil {
return fmt.Errorf("unable to extract funding TX: %w", err)
}
return i.FinalizeRawTX(rawTx)
}
// FinalizeRawTX makes sure the final raw transaction that is given to the
// intent is fully valid and signed but still contains the same UTXOs and
// outputs as the pending transaction we previously verified. If everything
// checks out, the funding manager is informed that the channel can now be
// opened and the funding transaction be broadcast.
func (i *PsbtIntent) FinalizeRawTX(rawTx *wire.MsgTx) error {
if rawTx == nil {
return fmt.Errorf("raw transaction is nil")
}
if i.State != PsbtVerified {
return fmt.Errorf("invalid state. got %v expected %v", i.State,
PsbtVerified)
}
// Do a basic check that this is still the same TX that we verified in
// the previous step. This is to protect the user from unwanted
// modifications. We only check the outputs and previous outpoints of
// the inputs of the wire transaction because the fields in the PSBT
// part are allowed to change.
if i.PendingPsbt == nil {
return fmt.Errorf("PSBT was not verified first")
}
err := psbt.VerifyOutputsEqual(
rawTx.TxOut, i.PendingPsbt.UnsignedTx.TxOut,
)
if err != nil {
return fmt.Errorf("outputs differ from verified PSBT: %w", err)
}
err = psbt.VerifyInputPrevOutpointsEqual(
rawTx.TxIn, i.PendingPsbt.UnsignedTx.TxIn,
)
if err != nil {
return fmt.Errorf("inputs differ from verified PSBT: %w", err)
}
// We also check that we have a signed TX. This is only necessary if the
// FinalizeRawTX is called directly with a wire format TX instead of
// extracting the TX from a PSBT.
err = verifyInputsSigned(rawTx.TxIn)
if err != nil {
return fmt.Errorf("inputs not signed: %w", err)
}
// As far as we can tell, this TX is ok to be used as a funding
// transaction.
i.State = PsbtFinalized
i.FinalTX = rawTx
// Signal the funding manager that it can now finally continue with its
// funding flow as the PSBT is now ready to be converted into a real
// transaction and be published.
i.signalPsbtReady.Do(func() {
close(i.PsbtReady)
})
return nil
}
// CompileFundingTx finalizes the previously verified PSBT and returns the
// extracted binary serialized transaction from it. It also prepares the channel
// point for which this funding intent was initiated for.
func (i *PsbtIntent) CompileFundingTx() (*wire.MsgTx, error) {
if i.State != PsbtFinalized {
return nil, fmt.Errorf("invalid state. got %v expected %v",
i.State, PsbtFinalized)
}
// Identify our funding outpoint now that we know everything's ready.
_, txOut, err := i.FundingOutput()
if err != nil {
return nil, fmt.Errorf("cannot get funding output: %w", err)
}
ok, idx := input.FindScriptOutputIndex(i.FinalTX, txOut.PkScript)
if !ok {
return nil, fmt.Errorf("funding output not found in PSBT")
}
i.chanPoint = &wire.OutPoint{
Hash: i.FinalTX.TxHash(),
Index: idx,
}
i.State = PsbtFundingTxCompiled
return i.FinalTX, nil
}
// RemoteCanceled informs the listener of the PSBT ready channel that the
// funding has been canceled by the remote peer and that we can no longer
// continue with it.
func (i *PsbtIntent) RemoteCanceled() {
log.Debugf("PSBT funding intent canceled by remote, state=%v", i.State)
i.signalPsbtReady.Do(func() {
i.PsbtReady <- ErrRemoteCanceled
i.State = PsbtResponderCanceled
})
i.ShimIntent.Cancel()
}
// Cancel allows the caller to cancel a funding Intent at any time. This will
// make sure the channel funding flow with the remote peer is failed and any
// reservations are canceled.
//
// NOTE: Part of the chanfunding.Intent interface.
func (i *PsbtIntent) Cancel() {
log.Debugf("PSBT funding intent canceled, state=%v", i.State)
i.signalPsbtReady.Do(func() {
i.PsbtReady <- ErrUserCanceled
i.State = PsbtInitiatorCanceled
})
i.ShimIntent.Cancel()
}
// Inputs returns all inputs to the final funding transaction that we know
// about. These are only known after the PSBT has been verified.
func (i *PsbtIntent) Inputs() []wire.OutPoint {
var inputs []wire.OutPoint
switch i.State {
// We return the inputs to the pending psbt.
case PsbtVerified:
for _, in := range i.PendingPsbt.UnsignedTx.TxIn {
inputs = append(inputs, in.PreviousOutPoint)
}
// We return the inputs to the final funding tx.
case PsbtFinalized, PsbtFundingTxCompiled:
for _, in := range i.FinalTX.TxIn {
inputs = append(inputs, in.PreviousOutPoint)
}
// In all other states we cannot know the inputs to the funding tx, and
// return an empty list.
default:
}
return inputs
}
// Outputs returns all outputs of the final funding transaction that we
// know about. These are only known after the PSBT has been verified.
func (i *PsbtIntent) Outputs() []*wire.TxOut {
switch i.State {
// We return the outputs of the pending psbt.
case PsbtVerified:
return i.PendingPsbt.UnsignedTx.TxOut
// We return the outputs of the final funding tx.
case PsbtFinalized, PsbtFundingTxCompiled:
return i.FinalTX.TxOut
// In all other states we cannot know the final outputs, and return an
// empty list.
default:
return nil
}
}
// ShouldPublishFundingTX returns true if the intent assumes that its assembler
// should publish the funding TX once the funding negotiation is complete.
func (i *PsbtIntent) ShouldPublishFundingTX() bool {
return i.shouldPublish
}
// PsbtAssembler is a type of chanfunding.Assembler wherein the funding
// transaction is constructed outside of lnd by using partially signed bitcoin
// transactions (PSBT).
type PsbtAssembler struct {
// fundingAmt is the total amount of coins in the funding output.
fundingAmt btcutil.Amount
// basePsbt is the user-supplied base PSBT the channel output should be
// added to.
basePsbt *psbt.Packet
// netParams are the network parameters used to encode the P2WSH funding
// address.
netParams *chaincfg.Params
// shouldPublish specifies if the assembler should publish the
// transaction once the channel funding has completed.
shouldPublish bool
}
// NewPsbtAssembler creates a new PsbtAssembler from the material required
// to construct a funding output and channel point. An optional base PSBT can
// be supplied which will be used to add the channel output to instead of
// creating a new one.
func NewPsbtAssembler(fundingAmt btcutil.Amount, basePsbt *psbt.Packet,
netParams *chaincfg.Params, shouldPublish bool) *PsbtAssembler {
return &PsbtAssembler{
fundingAmt: fundingAmt,
basePsbt: basePsbt,
netParams: netParams,
shouldPublish: shouldPublish,
}
}
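// Example construction (a sketch; the amount and network are illustrative):
//
//	assembler := NewPsbtAssembler(
//		btcutil.Amount(1_000_000), nil, &chaincfg.TestNet3Params,
//		true,
//	)
//
// Passing a nil base PSBT means a fresh packet is created for the channel
// output rather than adding it to an existing one.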
// ProvisionChannel creates a new ShimIntent given the passed funding Request.
// The returned intent is immediately able to provide the channel point and
// funding output as they've already been created outside lnd.
//
// NOTE: This method satisfies the chanfunding.Assembler interface.
func (p *PsbtAssembler) ProvisionChannel(req *Request) (Intent, error) {
// We'll exit out if SubtractFees is set as the funding transaction will
// be assembled externally, so we don't influence coin selection.
if req.SubtractFees {
return nil, fmt.Errorf("SubtractFees not supported for PSBT")
}
// We'll exit out if FundUpToMaxAmt or MinFundAmt is set as the funding
// transaction will be assembled externally, so we don't influence coin
// selection.
if req.FundUpToMaxAmt != 0 || req.MinFundAmt != 0 {
return nil, fmt.Errorf("FundUpToMaxAmt and MinFundAmt not " +
"supported for PSBT")
}
intent := &PsbtIntent{
ShimIntent: ShimIntent{
localFundingAmt: p.fundingAmt,
musig2: req.Musig2,
tapscriptRoot: req.TapscriptRoot,
},
State: PsbtShimRegistered,
BasePsbt: p.basePsbt,
PsbtReady: make(chan error, 1),
shouldPublish: p.shouldPublish,
netParams: p.netParams,
}
// A simple sanity check to ensure the provisioned request matches the
// re-made shim intent.
if req.LocalAmt+req.RemoteAmt != p.fundingAmt {
return nil, fmt.Errorf("intent doesn't match PSBT "+
"assembler: local_amt=%v, remote_amt=%v, funding_amt=%v",
req.LocalAmt, req.RemoteAmt, p.fundingAmt)
}
return intent, nil
}
// ShouldPublishFundingTx is a method of the assembler that signals if the
// funding transaction should be published after the channel negotiations are
// completed with the remote peer.
//
// NOTE: This method is a part of the ConditionalPublishAssembler interface.
func (p *PsbtAssembler) ShouldPublishFundingTx() bool {
return p.shouldPublish
}
// A compile-time assertion to ensure PsbtAssembler meets the
// ConditionalPublishAssembler interface.
var _ ConditionalPublishAssembler = (*PsbtAssembler)(nil)
// verifyInputsSigned verifies that the given list of inputs is non-empty and
// that all the inputs either contain a script signature or a witness stack.
func verifyInputsSigned(ins []*wire.TxIn) error {
if len(ins) == 0 {
return fmt.Errorf("no inputs in transaction")
}
for idx, in := range ins {
if len(in.SignatureScript) == 0 && len(in.Witness) == 0 {
return fmt.Errorf("input %d has no signature data "+
"attached", idx)
}
}
return nil
}
// verifyAllInputsSegWit makes sure all inputs to a transaction are SegWit
// spends. This is a bit tricky because the PSBT spec doesn't require the
// WitnessUtxo field to be set. Therefore if only a NonWitnessUtxo is given, we
// need to look at it and make sure it's either a witness pkScript or a nested
// SegWit spend.
func verifyAllInputsSegWit(txIns []*wire.TxIn, ins []psbt.PInput) error {
for idx, in := range ins {
switch {
// The optimal case is that the witness UTXO is set explicitly.
case in.WitnessUtxo != nil:
// Only the non-witness UTXO field is set, so we need to inspect
// it to make sure it's not a P2PKH or bare P2SH spend.
case in.NonWitnessUtxo != nil:
utxo := in.NonWitnessUtxo
txIn := txIns[idx]
txOut := utxo.TxOut[txIn.PreviousOutPoint.Index]
if !txscript.IsWitnessProgram(txOut.PkScript) &&
!txscript.IsWitnessProgram(in.RedeemScript) {
return fmt.Errorf("input %d is non-SegWit "+
"spend or missing redeem script", idx)
}
// This should've already been caught by a previous check but we
// keep it in for completeness' sake.
default:
return fmt.Errorf("input %d has no UTXO information",
idx)
}
}
return nil
}
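// For illustration, a nested SegWit (P2SH-P2WPKH) input passes the check
// above even when only NonWitnessUtxo is set, because its redeem script is
// itself a v0 witness program (sketch; pubKeyHash and prevTx are
// hypothetical):
//
//	redeemScript, _ := txscript.NewScriptBuilder().
//		AddOp(txscript.OP_0).AddData(pubKeyHash).Script()
//	pIn := psbt.PInput{
//		NonWitnessUtxo: prevTx,
//		RedeemScript:   redeemScript,
//	}
//	// txscript.IsWitnessProgram(pIn.RedeemScript) returns true, so this
//	// input is accepted as a (nested) SegWit spend.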
package chanfunding
import (
"fmt"
"math"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/txsort"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
)
const (
// DefaultReservationTimeout is the default time we wait until we remove
// an unfinished (zombiestate) open channel flow from memory.
DefaultReservationTimeout = 10 * time.Minute
// DefaultLockDuration is the default duration used to lock outputs.
DefaultLockDuration = 10 * time.Minute
)
var (
// LndInternalLockID is the binary representation of the SHA256 hash of
// the string "lnd-internal-lock-id" and is used for UTXO lock leases to
// identify that we ourselves are locking an UTXO, for example when
// giving out a funded PSBT. The ID corresponds to the hex value of
// ede19a92ed321a4705f8a1cccc1d4f6182545d4bb4fae08bd5937831b7e38f98.
LndInternalLockID = wtxmgr.LockID{
0xed, 0xe1, 0x9a, 0x92, 0xed, 0x32, 0x1a, 0x47,
0x05, 0xf8, 0xa1, 0xcc, 0xcc, 0x1d, 0x4f, 0x61,
0x82, 0x54, 0x5d, 0x4b, 0xb4, 0xfa, 0xe0, 0x8b,
0xd5, 0x93, 0x78, 0x31, 0xb7, 0xe3, 0x8f, 0x98,
}
)
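// The lock ID above can be re-derived directly from its documented preimage
// (a quick sanity sketch using the standard library):
//
//	sum := sha256.Sum256([]byte("lnd-internal-lock-id"))
//	var id wtxmgr.LockID
//	copy(id[:], sum[:])
//	// id now equals LndInternalLockID.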
// FullIntent is an intent that is fully backed by the internal wallet. This
// intent differs from the ShimIntent, in that the funding transaction will be
// constructed internally, and will consist of only inputs we wholly control.
// This Intent implements a basic state machine that must be executed in order
// before CompileFundingTx can be called.
//
// Steps to final channel provisioning:
// 1. Call BindKeys to notify the intent which keys to use when constructing
// the multi-sig output.
// 2. Call CompileFundingTx afterwards to obtain the funding transaction.
//
// If either of these steps fail, then the Cancel method MUST be called.
type FullIntent struct {
ShimIntent
// InputCoins are the set of coins selected as inputs to this funding
// transaction.
InputCoins []wallet.Coin
// ChangeOutputs are the set of outputs that the Assembler will use as
// change from the main funding transaction.
ChangeOutputs []*wire.TxOut
// coinLeaser is the Assembler's instance of the OutputLeaser
// interface.
coinLeaser OutputLeaser
// coinSource is the Assembler's instance of the CoinSource interface.
coinSource CoinSource
// signer is the Assembler's instance of the Signer interface.
signer input.Signer
}
// BindKeys is a method unique to the FullIntent variant. This allows the
// caller to decide precisely which keys are used in the final funding
// transaction. This is kept out of the main Assembler as these may not
// necessarily be under full control of the wallet. Only after this method has
// been executed will CompileFundingTx succeed.
func (f *FullIntent) BindKeys(localKey *keychain.KeyDescriptor,
remoteKey *btcec.PublicKey) {
f.localKey = localKey
f.remoteKey = remoteKey
}
// CompileFundingTx is to be called after BindKeys on the sub-intent has been
// called. This method will construct the final funding transaction, and fully
// sign all inputs that are known by the backing CoinSource. After this method
// returns, the Intent is assumed to be complete, as the output can be created
// at any point.
func (f *FullIntent) CompileFundingTx(extraInputs []*wire.TxIn,
extraOutputs []*wire.TxOut) (*wire.MsgTx, error) {
// Create a blank, fresh transaction. Soon to be a complete funding
// transaction which will allow opening a lightning channel.
fundingTx := wire.NewMsgTx(2)
// Add all multi-party inputs and outputs to the transaction.
for _, coin := range f.InputCoins {
fundingTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: coin.OutPoint,
})
}
for _, theirInput := range extraInputs {
fundingTx.AddTxIn(theirInput)
}
for _, ourChangeOutput := range f.ChangeOutputs {
fundingTx.AddTxOut(ourChangeOutput)
}
for _, theirChangeOutput := range extraOutputs {
fundingTx.AddTxOut(theirChangeOutput)
}
_, fundingOutput, err := f.FundingOutput()
if err != nil {
return nil, err
}
// Sort the transaction. Since both sides agree to a canonical ordering,
// by sorting we no longer need to send the entire transaction. Only
// signatures will be exchanged.
fundingTx.AddTxOut(fundingOutput)
txsort.InPlaceSort(fundingTx)
// Now that the funding tx has been fully assembled, we'll locate the
// index of the funding output so we can create our final channel
// point.
_, multiSigIndex := input.FindScriptOutputIndex(
fundingTx, fundingOutput.PkScript,
)
// Next, sign all inputs that are ours, collecting the signatures in
// order of the inputs.
prevOutFetcher := NewSegWitV0DualFundingPrevOutputFetcher(
f.coinSource, extraInputs,
)
sigHashes := txscript.NewTxSigHashes(fundingTx, prevOutFetcher)
for i, txIn := range fundingTx.TxIn {
signDesc := input.SignDescriptor{
SigHashes: sigHashes,
PrevOutputFetcher: prevOutFetcher,
}
// We can only sign this input if it's ours, so we'll ask the
// coin source if it can map this outpoint into a coin we own.
// If not, then we'll continue as it isn't our input.
info, err := f.coinSource.CoinFromOutPoint(
txIn.PreviousOutPoint,
)
if err != nil {
continue
}
// Now that we know the input is ours, we'll populate the
// signDesc with the per input unique information.
signDesc.Output = &wire.TxOut{
Value: info.Value,
PkScript: info.PkScript,
}
signDesc.InputIndex = i
// We support spending a p2tr input when it's one of ours, but
// not as part of the remote's inputs.
signDesc.HashType = txscript.SigHashAll
if txscript.IsPayToTaproot(info.PkScript) {
signDesc.HashType = txscript.SigHashDefault
}
// Finally, we'll sign the input as is, and populate the input
// with the witness and sigScript (if needed).
inputScript, err := f.signer.ComputeInputScript(
fundingTx, &signDesc,
)
if err != nil {
return nil, err
}
txIn.SignatureScript = inputScript.SigScript
txIn.Witness = inputScript.Witness
}
// Finally, we'll populate the chanPoint now that we've fully
// constructed the funding transaction.
f.chanPoint = &wire.OutPoint{
Hash: fundingTx.TxHash(),
Index: multiSigIndex,
}
return fundingTx, nil
}
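// A condensed sketch of the call order described in the FullIntent docs
// (localKeyDesc and remoteKey are hypothetical):
//
//	intent.BindKeys(localKeyDesc, remoteKey)
//	fundingTx, err := intent.CompileFundingTx(nil, nil)
//	if err != nil {
//		// Per the contract above, Cancel MUST be called on failure so
//		// the leased inputs are released.
//		intent.Cancel()
//		return err
//	}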
// Inputs returns all inputs to the final funding transaction that we
// know about. Since this funding transaction is created all from our wallet,
// it will be all inputs.
func (f *FullIntent) Inputs() []wire.OutPoint {
var ins []wire.OutPoint
for _, coin := range f.InputCoins {
ins = append(ins, coin.OutPoint)
}
return ins
}
// Outputs returns all outputs of the final funding transaction that we
// know about. This will be the funding output and the change outputs going
// back to our wallet.
func (f *FullIntent) Outputs() []*wire.TxOut {
outs := f.ShimIntent.Outputs()
outs = append(outs, f.ChangeOutputs...)
return outs
}
// Cancel allows the caller to cancel a funding Intent at any time. This will
// return any resources such as coins back to the eligible pool to be used in
// other channel fundings.
//
// NOTE: Part of the chanfunding.Intent interface.
func (f *FullIntent) Cancel() {
for _, coin := range f.InputCoins {
err := f.coinLeaser.ReleaseOutput(
LndInternalLockID, coin.OutPoint,
)
if err != nil {
log.Warnf("Failed to release UTXO %s (%v))",
coin.OutPoint, err)
}
}
f.ShimIntent.Cancel()
}
// A compile-time check to ensure FullIntent meets the Intent interface.
var _ Intent = (*FullIntent)(nil)
// WalletConfig is the main config of the WalletAssembler.
type WalletConfig struct {
// CoinSource is what the WalletAssembler uses to list/locate coins.
CoinSource CoinSource
// CoinSelectionStrategy is the strategy that is used for selecting
// coins when funding a transaction.
CoinSelectionStrategy wallet.CoinSelectionStrategy
// CoinSelectionLocker allows the WalletAssembler to gain exclusive
// access to the current set of coins returned by the CoinSource.
CoinSelectLocker CoinSelectionLocker
// CoinLeaser is what the WalletAssembler uses to lease coins that may
// be used as inputs for a new funding transaction.
CoinLeaser OutputLeaser
// Signer allows the WalletAssembler to sign inputs on any potential
// funding transactions.
Signer input.Signer
// DustLimit is the current dust limit. We'll use this to ensure that
// we don't make dust outputs on the funding transaction.
DustLimit btcutil.Amount
}
// WalletAssembler is an instance of the Assembler interface that is backed by
// a full wallet. This variant of the Assembler interface will produce the
// entirety of the funding transaction within the wallet. This implements the
// typical funding flow that is initiated either on the p2p level or using the
// CLI.
type WalletAssembler struct {
cfg WalletConfig
}
// NewWalletAssembler creates a new instance of the WalletAssembler from a
// fully populated wallet config.
func NewWalletAssembler(cfg WalletConfig) *WalletAssembler {
return &WalletAssembler{
cfg: cfg,
}
}
// ProvisionChannel is the main entry point to begin a funding workflow given a
// fully populated request. The internal WalletAssembler will perform coin
// selection in a goroutine safe manner, returning an Intent that will allow
// the caller to finalize the funding process.
//
// NOTE: To cancel the funding flow, the Cancel() method on the returned
// Intent MUST be called.
//
// NOTE: This is a part of the chanfunding.Assembler interface.
func (w *WalletAssembler) ProvisionChannel(r *Request) (Intent, error) {
var intent Intent
// We hold the coin select mutex while querying for outputs, and
// performing coin selection in order to avoid inadvertent double
// spends across funding transactions.
err := w.cfg.CoinSelectLocker.WithCoinSelectLock(func() error {
log.Infof("Performing funding tx coin selection using %v "+
"sat/kw as fee rate", int64(r.FeeRate))
var (
// allCoins refers to the entirety of coins in our
// wallet that are available for funding a channel.
allCoins []wallet.Coin
// manuallySelectedCoins refers to the client-side
// selected coins that should be considered available
// for funding a channel.
manuallySelectedCoins []wallet.Coin
err error
)
// Convert manually selected outpoints to coins.
manuallySelectedCoins, err = outpointsToCoins(
r.Outpoints, w.cfg.CoinSource.CoinFromOutPoint,
)
if err != nil {
return err
}
// Find all unlocked unspent witness outputs that satisfy the
// minimum number of confirmations required. Coin selection in
// this function currently ignores the configured coin selection
// strategy.
allCoins, err = w.cfg.CoinSource.ListCoins(
r.MinConfs, math.MaxInt32,
)
if err != nil {
return err
}
// Ensure that all manually selected coins remain unspent.
unspent := make(map[wire.OutPoint]struct{})
for _, coin := range allCoins {
unspent[coin.OutPoint] = struct{}{}
}
for _, coin := range manuallySelectedCoins {
if _, ok := unspent[coin.OutPoint]; !ok {
return fmt.Errorf("outpoint already spent or "+
"locked by another subsystem: %v",
coin.OutPoint)
}
}
// The coin selection algorithm needs to know what
// inputs/outputs are already present in the funding
// transaction and what a change output would look like. Since
// a channel funding is always either a P2WSH or P2TR output,
// we can use just P2WSH here (both of these output types have
// the same length). And we currently don't support specifying a
// change output type, so we always use P2TR.
var fundingOutputWeight input.TxWeightEstimator
fundingOutputWeight.AddP2WSHOutput()
changeType := P2TRChangeAddress
var (
coins []wallet.Coin
selectedCoins []wallet.Coin
localContributionAmt btcutil.Amount
changeAmt btcutil.Amount
)
// If outputs were specified manually then we'll take the
// corresponding coins as basis for coin selection. Otherwise,
// all available coins from our wallet are used.
coins = allCoins
if len(manuallySelectedCoins) > 0 {
coins = manuallySelectedCoins
}
// Perform coin selection over our available, unlocked unspent
// outputs in order to find enough coins to meet the funding
// amount requirements.
switch {
// If there's no funding amount at all (receiving an inbound
// single funder request), then we don't need to perform any
// coin selection at all.
case r.LocalAmt == 0 && r.FundUpToMaxAmt == 0:
break
// The local funding amount cannot be used in combination with
// the funding up to some maximum amount. If that is the case
// we return an error.
case r.LocalAmt != 0 && r.FundUpToMaxAmt != 0:
return fmt.Errorf("cannot use a local funding amount " +
"and fundmax parameters")
// We cannot use the subtract fees flag while using the funding
// up to some maximum amount. If that is the case we return an
// error.
case r.SubtractFees && r.FundUpToMaxAmt != 0:
return fmt.Errorf("cannot subtract fees from local " +
"amount while using fundmax parameters")
// In case this request uses funding up to some maximum amount,
// we will call the specialized coin selection function for
// that.
case r.FundUpToMaxAmt != 0 && r.MinFundAmt != 0:
// We need to ensure that manually selected coins, which
// are spent entirely on the channel funding, leave
// enough funds in the wallet to cover for a reserve.
reserve := r.WalletReserve
if len(manuallySelectedCoins) > 0 {
sumCoins := func(
coins []wallet.Coin) btcutil.Amount {
var sum btcutil.Amount
for _, coin := range coins {
sum += btcutil.Amount(
coin.Value,
)
}
return sum
}
sumManual := sumCoins(manuallySelectedCoins)
sumAll := sumCoins(allCoins)
// If sufficient reserve funds are available we
// don't have to provide for it during coin
// selection. The manually selected coins can be
// spent entirely on the channel funding. If
// the excess of coins covers the reserve only
// partially, then we have to provide for the
// rest during coin selection.
excess := sumAll - sumManual
if excess >= reserve {
reserve = 0
} else {
reserve -= excess
}
}
selectedCoins, localContributionAmt, changeAmt,
err = CoinSelectUpToAmount(
r.FeeRate, r.MinFundAmt, r.FundUpToMaxAmt,
reserve, w.cfg.DustLimit, coins,
w.cfg.CoinSelectionStrategy,
fundingOutputWeight, changeType,
DefaultMaxFeeRatio,
)
if err != nil {
return err
}
// Now that the actual channel capacity is determined,
// we can check for local contribution constraints.
//
// Ensure that the remote channel reserve does not
// exceed 20% of the channel capacity.
if r.RemoteChanReserve >= localContributionAmt/5 {
return fmt.Errorf("remote channel reserve " +
"must be less than 20%% of the " +
"channel capacity")
}
// Ensure that the initial remote balance does not
// exceed our local contribution as that would leave a
// negative balance on our side.
if r.PushAmt >= localContributionAmt {
return fmt.Errorf("amount pushed to remote " +
"peer for initial state must be " +
"below the local funding amount")
}
// In case this request wants the fees subtracted from the local
// amount, we'll call the specialized method for that. This
// ensures that we won't deduct more than the specified balance
// from our wallet.
case r.SubtractFees:
dustLimit := w.cfg.DustLimit
selectedCoins, localContributionAmt, changeAmt,
err = CoinSelectSubtractFees(
r.FeeRate, r.LocalAmt, dustLimit, coins,
w.cfg.CoinSelectionStrategy,
fundingOutputWeight, changeType,
DefaultMaxFeeRatio,
)
if err != nil {
return err
}
// Otherwise do a normal coin selection where we target a given
// funding amount.
default:
dustLimit := w.cfg.DustLimit
localContributionAmt = r.LocalAmt
selectedCoins, changeAmt, err = CoinSelect(
r.FeeRate, r.LocalAmt, dustLimit, coins,
w.cfg.CoinSelectionStrategy,
fundingOutputWeight, changeType,
DefaultMaxFeeRatio,
)
if err != nil {
return err
}
}
// Sanity check: The addition of the outputs should not lead to the
// creation of dust.
if changeAmt != 0 && changeAmt < w.cfg.DustLimit {
return fmt.Errorf("change amount(%v) after coin "+
"select is below dust limit(%v)", changeAmt,
w.cfg.DustLimit)
}
// Record any change output(s) generated as a result of the
// coin selection.
var changeOutput *wire.TxOut
if changeAmt != 0 {
changeAddr, err := r.ChangeAddr()
if err != nil {
return err
}
changeScript, err := txscript.PayToAddrScript(changeAddr)
if err != nil {
return err
}
changeOutput = &wire.TxOut{
Value: int64(changeAmt),
PkScript: changeScript,
}
}
// Lock the selected coins. These coins are now "reserved";
// this prevents concurrent funding requests from referring
// to, and thus double-spending, the same set of coins.
for _, coin := range selectedCoins {
outpoint := coin.OutPoint
_, err = w.cfg.CoinLeaser.LeaseOutput(
LndInternalLockID, outpoint,
DefaultReservationTimeout,
)
if err != nil {
return err
}
}
newIntent := &FullIntent{
ShimIntent: ShimIntent{
localFundingAmt: localContributionAmt,
remoteFundingAmt: r.RemoteAmt,
musig2: r.Musig2,
tapscriptRoot: r.TapscriptRoot,
},
InputCoins: selectedCoins,
coinLeaser: w.cfg.CoinLeaser,
coinSource: w.cfg.CoinSource,
signer: w.cfg.Signer,
}
if changeOutput != nil {
newIntent.ChangeOutputs = []*wire.TxOut{changeOutput}
}
intent = newIntent
return nil
})
if err != nil {
return nil, err
}
return intent, nil
}
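// A minimal usage sketch (hypothetical values; the Request fields shown
// mirror those consumed above):
//
//	assembler := NewWalletAssembler(walletCfg)
//	intent, err := assembler.ProvisionChannel(&Request{
//		LocalAmt:   btcutil.Amount(500_000),
//		MinConfs:   1,
//		FeeRate:    feeRate,
//		ChangeAddr: changeAddrFunc,
//	})
//	if err != nil {
//		return err
//	}
//	// If the funding flow is aborted, intent.Cancel() MUST be called so
//	// the leased coins are released before the reservation timeout.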
// outpointsToCoins maps outpoints to coins in our wallet iff these coins
// exist, and returns an error otherwise.
func outpointsToCoins(outpoints []wire.OutPoint,
coinFromOutPoint func(wire.OutPoint) (*wallet.Coin, error)) (
[]wallet.Coin, error) {
var selectedCoins []wallet.Coin
for _, outpoint := range outpoints {
coin, err := coinFromOutPoint(
outpoint,
)
if err != nil {
return nil, err
}
selectedCoins = append(
selectedCoins, *coin,
)
}
return selectedCoins, nil
}
// FundingTxAvailable is an empty method that an assembler can implement to
// signal to callers that it's able to provide the funding transaction for the
// channel via the intent it returns.
//
// NOTE: This method is a part of the FundingTxAssembler interface.
func (w *WalletAssembler) FundingTxAvailable() {}
// A compile-time assertion to ensure the WalletAssembler meets the
// FundingTxAssembler interface.
var _ FundingTxAssembler = (*WalletAssembler)(nil)
// SegWitV0DualFundingPrevOutputFetcher is a txscript.PrevOutputFetcher that
// knows about local and remote funding inputs.
//
// TODO(guggero): Support dual funding with p2tr inputs, currently only segwit
// v0 inputs are supported.
type SegWitV0DualFundingPrevOutputFetcher struct {
local CoinSource
remote *txscript.MultiPrevOutFetcher
}
var _ txscript.PrevOutputFetcher = (*SegWitV0DualFundingPrevOutputFetcher)(nil)
// NewSegWitV0DualFundingPrevOutputFetcher creates a new
// txscript.PrevOutputFetcher from the given local and remote inputs.
//
// NOTE: Since the actual pkScript and amounts aren't passed in, this will just
// make sure that nothing will panic when creating a SegWit v0 sighash. But this
// code will NOT WORK for transactions that spend any _remote_ Taproot inputs!
// So basically dual-funding won't work with Taproot inputs unless the UTXO info
// is exchanged between the peers.
func NewSegWitV0DualFundingPrevOutputFetcher(localSource CoinSource,
remoteInputs []*wire.TxIn) txscript.PrevOutputFetcher {
remote := txscript.NewMultiPrevOutFetcher(nil)
for _, inp := range remoteInputs {
// We add an empty output to prevent the sighash calculation
// from panicking. But this will always detect the inputs as
// SegWit v0!
remote.AddPrevOut(inp.PreviousOutPoint, &wire.TxOut{})
}
return &SegWitV0DualFundingPrevOutputFetcher{
local: localSource,
remote: remote,
}
}
// FetchPrevOutput attempts to fetch the previous output referenced by the
// passed outpoint.
//
// NOTE: This is a part of the txscript.PrevOutputFetcher interface.
func (d *SegWitV0DualFundingPrevOutputFetcher) FetchPrevOutput(
op wire.OutPoint) *wire.TxOut {
// Try the local source first. This will return nil if our internal
// wallet doesn't know the outpoint.
coin, err := d.local.CoinFromOutPoint(op)
if err == nil && coin != nil {
return &coin.TxOut
}
// Fall back to the remote fetcher.
return d.remote.FetchPrevOutput(op)
}
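// This mirrors how CompileFundingTx wires the fetcher into the sighash
// calculation (sketch with hypothetical variables):
//
//	fetcher := NewSegWitV0DualFundingPrevOutputFetcher(
//		coinSource, remoteInputs,
//	)
//	sigHashes := txscript.NewTxSigHashes(fundingTx, fetcher)
//	// sigHashes can now back SegWit v0 signing; remote Taproot inputs
//	// would require real UTXO data, per the note above.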
package lnwallet
import (
"bytes"
"cmp"
"context"
"crypto/sha256"
"errors"
"fmt"
"slices"
"sync"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/txsort"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/mempool"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog/v2"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// ErrChanClosing is returned when a caller attempts to close a channel
// that has already been closed or is in the process of being closed.
ErrChanClosing = fmt.Errorf("channel is being closed, operation disallowed")
// ErrNoWindow is returned when revocation window is exhausted.
ErrNoWindow = fmt.Errorf("unable to sign new commitment, the current" +
" revocation window is exhausted")
// ErrMaxWeightCost is returned when the cost/weight (see segwit)
// exceeds the widely used maximum allowed policy weight limit. In this
// case the commitment transaction can't be propagated through the
// network.
ErrMaxWeightCost = fmt.Errorf("commitment transaction exceeds max " +
"available cost")
// ErrMaxHTLCNumber is returned when a proposed HTLC would exceed the
// maximum number of allowed HTLCs if committed in a state transition.
ErrMaxHTLCNumber = fmt.Errorf("commitment transaction exceeds max " +
"htlc number")
// ErrMaxPendingAmount is returned when a proposed HTLC would exceed
// the overall maximum pending value of all HTLCs if committed in a
// state transition.
ErrMaxPendingAmount = fmt.Errorf("commitment transaction exceeds max " +
"overall pending htlc value")
// ErrBelowChanReserve is returned when a proposed HTLC would cause
// one of the peer's funds to dip below the channel reserve limit.
ErrBelowChanReserve = fmt.Errorf("commitment transaction dips peer " +
"below chan reserve")
// ErrBelowMinHTLC is returned when a proposed HTLC has a value that
// is below the minimum HTLC value constraint for either us or our
// peer depending on which flags are set.
ErrBelowMinHTLC = fmt.Errorf("proposed HTLC value is below minimum " +
"allowed HTLC value")
// ErrFeeBufferNotInitiator is returned when the FeeBuffer is enforced
// although the channel was not initiated (opened) locally.
ErrFeeBufferNotInitiator = fmt.Errorf("unable to enforce FeeBuffer, " +
"not initiator of the channel")
// ErrInvalidHTLCAmt signals that a proposed HTLC has a value that is
// not positive.
ErrInvalidHTLCAmt = fmt.Errorf("proposed HTLC value must be positive")
// ErrCannotSyncCommitChains is returned if, upon receiving a ChanSync
// message, the state machine deems that it is unable to properly
// synchronize states with the remote peer. In this case we should fail
// the channel, but we won't automatically force close.
ErrCannotSyncCommitChains = fmt.Errorf("unable to sync commit chains")
// ErrInvalidLastCommitSecret is returned in the case that the
// commitment secret sent by the remote party in their
// ChannelReestablish message doesn't match the last secret we sent.
ErrInvalidLastCommitSecret = fmt.Errorf("commit secret is incorrect")
// ErrInvalidLocalUnrevokedCommitPoint is returned in the case that the
// commitment point sent by the remote party in their
// ChannelReestablish message doesn't match the last unrevoked commit
// point they sent us.
ErrInvalidLocalUnrevokedCommitPoint = fmt.Errorf("unrevoked commit " +
"point is invalid")
// ErrCommitSyncRemoteDataLoss is returned in the case that we receive
// a ChannelReestablish message from the remote that advertises a
// NextLocalCommitHeight that is lower than what they have already
// ACKed, or a RemoteCommitTailHeight that is lower than our revoked
// height. In this case we should force close the channel such that
// both parties can retrieve their funds.
ErrCommitSyncRemoteDataLoss = fmt.Errorf("possible remote commitment " +
"state data loss")
// ErrNoRevocationLogFound is returned when both the returned logs are
// nil from querying the revocation log bucket. In theory this should
// never happen as the query will return `ErrLogEntryNotFound`, yet
// we'd still perform a sanity check to make sure at least one of the
// logs is non-nil.
ErrNoRevocationLogFound = errors.New("no revocation log found")
// ErrOutputIndexOutOfRange is returned when an output index is greater
// than or equal to the length of a given transaction's outputs.
ErrOutputIndexOutOfRange = errors.New("output index is out of range")
// ErrRevLogDataMissing is returned when a certain wanted optional field
// in a revocation log entry is missing.
ErrRevLogDataMissing = errors.New("revocation log data missing")
// ErrForceCloseLocalDataLoss is returned in the case a user (or
// another sub-system) attempts to force close when we've detected that
// we've likely lost data ourselves.
ErrForceCloseLocalDataLoss = errors.New("cannot force close " +
"channel with local data loss")
// errNoNonce is returned when a nonce is required, but none is found.
errNoNonce = errors.New("no nonce found")
// errNoPartialSig is returned when a partial signature is required,
// but none is found.
errNoPartialSig = errors.New("no partial signature found")
// errQuit is returned when a quit signal was received, interrupting the
// current operation.
errQuit = errors.New("received quit signal")
)
// ErrCommitSyncLocalDataLoss is returned in the case that we receive a valid
// commit secret within the ChannelReestablish message from the remote node AND
// they advertise a RemoteCommitTailHeight higher than our current known
// height. This means we have lost some critical data, and must fail the
// channel and MUST NOT force close it. Instead we should wait for the remote
// to force close it, such that we can attempt to sweep our funds. The
// commitment point needed to sweep the remote's force close is encapsulated.
type ErrCommitSyncLocalDataLoss struct {
// ChannelPoint is the identifier for the channel that experienced data
// loss.
ChannelPoint wire.OutPoint
// CommitPoint is the last unrevoked commit point, sent to us by the
// remote when we determined we had lost state.
CommitPoint *btcec.PublicKey
}
// Error returns a string representation of the local data loss error.
func (e *ErrCommitSyncLocalDataLoss) Error() string {
return fmt.Sprintf("ChannelPoint(%v) with CommitPoint(%x) had "+
"possible local commitment state data loss", e.ChannelPoint,
e.CommitPoint.SerializeCompressed())
}
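// A caller can detect this condition with errors.As (sketch):
//
//	var dataLoss *ErrCommitSyncLocalDataLoss
//	if errors.As(err, &dataLoss) {
//		// Do not force close. Wait for the remote to force close and
//		// sweep our funds using dataLoss.CommitPoint.
//	}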
// PaymentHash represents the sha256 of a random value. This hash is used to
// uniquely track incoming/outgoing payments within this channel, as well as
// payments requested by the wallet/daemon.
type PaymentHash [32]byte
// commitment represents a commitment to a new state within an active channel.
// New commitments can be initiated by either side. Commitments are ordered
// into a commitment chain, with one existing for both parties. Each side can
// independently extend the other side's commitment chain, up to a certain
// "revocation window", which once reached, disallows new commitments until
// the local node receives the revocation for the remote node's chain tail.
type commitment struct {
// height represents the commitment height of this commitment, or the
// update number of this commitment.
height uint64
// whoseCommit indicates whether this is the local or remote node's
// version of the commitment.
whoseCommit lntypes.ChannelParty
// messageIndices are the indexes into the HTLC log up to which this
// commitment transaction includes updates. These indexes allow both
// sides to independently, and concurrently, create new commitments.
// Each new commitment sent to the remote party includes an index in
// the shared log which details which of their updates we're including
// in this new commitment.
messageIndices lntypes.Dual[uint64]
// [our|their]HtlcIndex are the current running counters for the HTLCs
// offered by either party. This value is incremented each time a party
// offers a new HTLC. The log update methods that consume HTLCs will
// reference these counters, rather than the running cumulative message
// counters.
ourHtlcIndex uint64
theirHtlcIndex uint64
// txn is the commitment transaction generated by including any HTLC
// updates whose index are below the two indexes listed above. If this
// commitment is being added to the remote chain, then this txn is
// their version of the commitment transactions. If the local commit
// chain is being modified, the opposite is true.
txn *wire.MsgTx
// sig is a signature for the above commitment transaction.
sig []byte
// [our|their]Balance represents the settled balances at this point
// within the commitment chain. This balance is computed by properly
// evaluating all the add/remove/settle log entries before the listed
// indexes.
//
// NOTE: This is the balance *after* subtracting any commitment fee,
// AND anchor output values.
ourBalance lnwire.MilliSatoshi
theirBalance lnwire.MilliSatoshi
// fee is the amount that will be paid as fees for this commitment
// transaction. The fee is recorded here so that it can be added back
// and recalculated for each new update to the channel state.
fee btcutil.Amount
// feePerKw is the fee per kw used to calculate this commitment
// transaction's fee.
feePerKw chainfee.SatPerKWeight
// dustLimit is the limit on the commitment transaction such that no
// output values should be below this amount.
dustLimit btcutil.Amount
// outgoingHTLCs is a slice of all the outgoing HTLC's (from our PoV)
// on this commitment transaction.
outgoingHTLCs []paymentDescriptor
// incomingHTLCs is a slice of all the incoming HTLC's (from our PoV)
// on this commitment transaction.
incomingHTLCs []paymentDescriptor
// customBlob stores opaque bytes that may be used by custom channels
// to store extra data for a given commitment state.
customBlob fn.Option[tlv.Blob]
// [outgoing|incoming]HTLCIndex is an index that maps an output index
// on the commitment transaction to the payment descriptor that
// represents the HTLC output.
//
// NOTE: These fields are only populated if this commitment state
// belongs to the local node. These maps are used when validating any
// HTLC signatures which are part of the local commitment state. We use
// this map in order to locate the details needed to validate an HTLC
// signature while iterating over the outputs in the local commitment
// view.
outgoingHTLCIndex map[int32]*paymentDescriptor
incomingHTLCIndex map[int32]*paymentDescriptor
}
// locateOutputIndex is a small helper function to locate the output index of a
// particular HTLC within the current commitment transaction. The duplicate map
// passed in is to be retained for each output within the commitment
// transaction. This ensures that we don't assign multiple HTLCs to the same
// index within the commitment transaction.
func locateOutputIndex(p *paymentDescriptor, tx *wire.MsgTx,
whoseCommit lntypes.ChannelParty, dups map[PaymentHash][]int32,
cltvs []uint32) (int32, error) {
// If this is their commitment transaction, we'll be trying to locate
// their pkScripts, otherwise we'll be looking for ours. This is
// required as the commitment states are asymmetric in order to ascribe
// blame in the case of a contract breach.
pkScript := p.theirPkScript
if whoseCommit.IsLocal() {
pkScript = p.ourPkScript
}
for i, txOut := range tx.TxOut {
cltv := cltvs[i]
if bytes.Equal(txOut.PkScript, pkScript) &&
txOut.Value == int64(p.Amount.ToSatoshis()) &&
cltv == p.Timeout {
// If this payment hash and index has already been
// found, then we'll continue in order to avoid any
// duplicate indexes.
if fn.Elem(int32(i), dups[p.RHash]) {
continue
}
idx := int32(i)
dups[p.RHash] = append(dups[p.RHash], idx)
return idx, nil
}
}
return 0, fmt.Errorf("unable to find htlc: script=%x, value=%v, "+
"cltv=%v", pkScript, p.Amount, p.Timeout)
}
// populateHtlcIndexes modifies the set of HTLCs locked-into the target view
// to have full indexing information populated. This information is required as
// we need to keep track of the indexes of each HTLC in order to properly write
// the current state to disk, and also to locate the paymentDescriptor
// corresponding to HTLC outputs in the commitment transaction.
func (c *commitment) populateHtlcIndexes(chanType channeldb.ChannelType,
cltvs []uint32) error {
// First, we'll set up some state to allow us to locate the output
// index of all the HTLCs within the commitment transaction. We
// must keep this index so we can validate the HTLC signatures sent to
// us.
dups := make(map[PaymentHash][]int32)
c.outgoingHTLCIndex = make(map[int32]*paymentDescriptor)
c.incomingHTLCIndex = make(map[int32]*paymentDescriptor)
// populateIndex is a helper function that populates the necessary
// indexes within the commitment view for a particular HTLC.
populateIndex := func(htlc *paymentDescriptor, incoming bool) error {
isDust := HtlcIsDust(
chanType, incoming, c.whoseCommit, c.feePerKw,
htlc.Amount.ToSatoshis(), c.dustLimit,
)
var err error
switch {
// If this is our commitment transaction, and this is a dust
// output then we mark it as such using a -1 index.
case c.whoseCommit.IsLocal() && isDust:
htlc.localOutputIndex = -1
// If this is the commitment transaction of the remote party,
// and this is a dust output then we mark it as such using a -1
// index.
case c.whoseCommit.IsRemote() && isDust:
htlc.remoteOutputIndex = -1
// If this is our commitment transaction, then we'll need to
// locate the output and the index so we can verify the HTLC
// signatures.
case c.whoseCommit.IsLocal():
htlc.localOutputIndex, err = locateOutputIndex(
htlc, c.txn, c.whoseCommit, dups, cltvs,
)
if err != nil {
return err
}
// As this is our commitment transaction, we need to
// keep track of the locations of each output on the
// transaction so we can verify any HTLC signatures
// sent to us after we construct the HTLC view.
if incoming {
c.incomingHTLCIndex[htlc.localOutputIndex] = htlc
} else {
c.outgoingHTLCIndex[htlc.localOutputIndex] = htlc
}
// Otherwise, this is the remote party's commitment
// transaction and we only need to populate the remote output
// index within the HTLC index.
case c.whoseCommit.IsRemote():
htlc.remoteOutputIndex, err = locateOutputIndex(
htlc, c.txn, c.whoseCommit, dups, cltvs,
)
if err != nil {
return err
}
default:
return fmt.Errorf("invalid commitment configuration")
}
return nil
}
// Finally, we'll need to locate the index within the commitment
// transaction of all the HTLC outputs. This index will be required
// later when we write the commitment state to disk, and also when
// generating signatures for each of the HTLC transactions.
for i := 0; i < len(c.outgoingHTLCs); i++ {
htlc := &c.outgoingHTLCs[i]
if err := populateIndex(htlc, false); err != nil {
return err
}
}
for i := 0; i < len(c.incomingHTLCs); i++ {
htlc := &c.incomingHTLCs[i]
if err := populateIndex(htlc, true); err != nil {
return err
}
}
return nil
}
// toDiskCommit converts the target commitment into a format suitable to be
// written to disk after an accepted state transition.
func (c *commitment) toDiskCommit(
whoseCommit lntypes.ChannelParty) *channeldb.ChannelCommitment {
numHtlcs := len(c.outgoingHTLCs) + len(c.incomingHTLCs)
commit := &channeldb.ChannelCommitment{
CommitHeight: c.height,
LocalLogIndex: c.messageIndices.Local,
LocalHtlcIndex: c.ourHtlcIndex,
RemoteLogIndex: c.messageIndices.Remote,
RemoteHtlcIndex: c.theirHtlcIndex,
LocalBalance: c.ourBalance,
RemoteBalance: c.theirBalance,
CommitFee: c.fee,
FeePerKw: btcutil.Amount(c.feePerKw),
CommitTx: c.txn,
CommitSig: c.sig,
Htlcs: make([]channeldb.HTLC, 0, numHtlcs),
CustomBlob: c.customBlob,
}
for _, htlc := range c.outgoingHTLCs {
outputIndex := htlc.localOutputIndex
if whoseCommit.IsRemote() {
outputIndex = htlc.remoteOutputIndex
}
h := channeldb.HTLC{
RHash: htlc.RHash,
Amt: htlc.Amount,
RefundTimeout: htlc.Timeout,
OutputIndex: outputIndex,
HtlcIndex: htlc.HtlcIndex,
LogIndex: htlc.LogIndex,
Incoming: false,
OnionBlob: htlc.OnionBlob,
BlindingPoint: htlc.BlindingPoint,
CustomRecords: htlc.CustomRecords.Copy(),
}
if whoseCommit.IsLocal() && htlc.sig != nil {
h.Signature = htlc.sig.Serialize()
}
commit.Htlcs = append(commit.Htlcs, h)
}
for _, htlc := range c.incomingHTLCs {
outputIndex := htlc.localOutputIndex
if whoseCommit.IsRemote() {
outputIndex = htlc.remoteOutputIndex
}
h := channeldb.HTLC{
RHash: htlc.RHash,
Amt: htlc.Amount,
RefundTimeout: htlc.Timeout,
OutputIndex: outputIndex,
HtlcIndex: htlc.HtlcIndex,
LogIndex: htlc.LogIndex,
Incoming: true,
OnionBlob: htlc.OnionBlob,
BlindingPoint: htlc.BlindingPoint,
CustomRecords: htlc.CustomRecords.Copy(),
}
if whoseCommit.IsLocal() && htlc.sig != nil {
h.Signature = htlc.sig.Serialize()
}
commit.Htlcs = append(commit.Htlcs, h)
}
return commit
}
// diskHtlcToPayDesc converts an HTLC previously written to disk within a
// commitment state to the form required to manipulate in memory within the
// commitment struct and updateLog. This function is used when we need to
// restore commitment state written to disk back into memory when we
// restart a channel session.
func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight,
htlc *channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing],
whoseCommit lntypes.ChannelParty,
auxLeaf input.AuxTapLeaf) (paymentDescriptor, error) {
// The proper pkScripts for this paymentDescriptor must be
// generated so we can easily locate them within the commitment
// transaction in the future.
var (
ourP2WSH, theirP2WSH []byte
ourWitnessScript, theirWitnessScript []byte
pd paymentDescriptor
chanType = lc.channelState.ChanType
)
// If the output is dust from either the local or remote node's
// perspective, then we don't need to generate the scripts, as we only
// generate them in order to locate the outputs within the commitment
// transaction. Dust is instead marked with a special output index in
// the on-disk state snapshot.
isDustLocal := HtlcIsDust(
chanType, htlc.Incoming, lntypes.Local, feeRate,
htlc.Amt.ToSatoshis(), lc.channelState.LocalChanCfg.DustLimit,
)
localCommitKeys := commitKeys.GetForParty(lntypes.Local)
if !isDustLocal && localCommitKeys != nil {
scriptInfo, err := genHtlcScript(
chanType, htlc.Incoming, lntypes.Local,
htlc.RefundTimeout, htlc.RHash, localCommitKeys,
auxLeaf,
)
if err != nil {
return pd, err
}
ourP2WSH = scriptInfo.PkScript()
ourWitnessScript = scriptInfo.WitnessScriptToSign()
}
isDustRemote := HtlcIsDust(
chanType, htlc.Incoming, lntypes.Remote, feeRate,
htlc.Amt.ToSatoshis(), lc.channelState.RemoteChanCfg.DustLimit,
)
remoteCommitKeys := commitKeys.GetForParty(lntypes.Remote)
if !isDustRemote && remoteCommitKeys != nil {
scriptInfo, err := genHtlcScript(
chanType, htlc.Incoming, lntypes.Remote,
htlc.RefundTimeout, htlc.RHash, remoteCommitKeys,
auxLeaf,
)
if err != nil {
return pd, err
}
theirP2WSH = scriptInfo.PkScript()
theirWitnessScript = scriptInfo.WitnessScriptToSign()
}
// Reconstruct the proper local/remote output indexes from the HTLC's
// persisted output index depending on whose commitment we are
// generating.
var (
localOutputIndex int32
remoteOutputIndex int32
)
if whoseCommit.IsLocal() {
localOutputIndex = htlc.OutputIndex
} else {
remoteOutputIndex = htlc.OutputIndex
}
// With the scripts reconstructed (depending on if this is our commit
// vs theirs or a pending commit for the remote party), we can now
// re-create the original payment descriptor.
return paymentDescriptor{
ChanID: lc.ChannelID(),
RHash: htlc.RHash,
Timeout: htlc.RefundTimeout,
Amount: htlc.Amt,
EntryType: Add,
HtlcIndex: htlc.HtlcIndex,
LogIndex: htlc.LogIndex,
OnionBlob: htlc.OnionBlob,
localOutputIndex: localOutputIndex,
remoteOutputIndex: remoteOutputIndex,
ourPkScript: ourP2WSH,
ourWitnessScript: ourWitnessScript,
theirPkScript: theirP2WSH,
theirWitnessScript: theirWitnessScript,
BlindingPoint: htlc.BlindingPoint,
CustomRecords: htlc.CustomRecords.Copy(),
}, nil
}
// extractPayDescs will convert all HTLCs present within a disk commit state
// to a set of incoming and outgoing payment descriptors. Once reconstructed,
// these payment descriptors can be re-inserted into the in-memory updateLog
// for each side.
func (lc *LightningChannel) extractPayDescs(feeRate chainfee.SatPerKWeight,
htlcs []channeldb.HTLC, commitKeys lntypes.Dual[*CommitmentKeyRing],
whoseCommit lntypes.ChannelParty,
auxLeaves fn.Option[CommitAuxLeaves]) ([]paymentDescriptor,
[]paymentDescriptor, error) {
var (
incomingHtlcs []paymentDescriptor
outgoingHtlcs []paymentDescriptor
)
// For each included HTLC within this commitment state, we'll convert
// the disk format into our in memory paymentDescriptor format,
// partitioning based on if we offered or received the HTLC.
for _, htlc := range htlcs {
// TODO(roasbeef): set isForwarded to false for all? need to
// persist state w.r.t to if forwarded or not, or can
// inadvertently trigger replays
htlc := htlc
auxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
if htlc.Incoming {
leaves = l.IncomingHtlcLeaves
}
return leaves[htlc.HtlcIndex].AuxTapLeaf
},
)(auxLeaves)
payDesc, err := lc.diskHtlcToPayDesc(
feeRate, &htlc, commitKeys, whoseCommit, auxLeaf,
)
if err != nil {
return incomingHtlcs, outgoingHtlcs, err
}
if htlc.Incoming {
incomingHtlcs = append(incomingHtlcs, payDesc)
} else {
outgoingHtlcs = append(outgoingHtlcs, payDesc)
}
}
return incomingHtlcs, outgoingHtlcs, nil
}
// diskCommitToMemCommit converts the on-disk commitment format to our
// in-memory commitment format which is needed in order to properly resume
// channel operations after a restart.
func (lc *LightningChannel) diskCommitToMemCommit(
whoseCommit lntypes.ChannelParty,
diskCommit *channeldb.ChannelCommitment, localCommitPoint,
remoteCommitPoint *btcec.PublicKey) (*commitment, error) {
// First, we'll need to re-derive the commitment key ring for each
// party used within this particular state. If this is a pending commit
// (we extended but weren't able to complete the commitment dance
// before shutdown), then the localCommitPoint won't be set as we
// haven't yet received a responding commitment from the remote party.
var commitKeys lntypes.Dual[*CommitmentKeyRing]
if localCommitPoint != nil {
commitKeys.SetForParty(lntypes.Local, DeriveCommitmentKeys(
localCommitPoint, lntypes.Local,
lc.channelState.ChanType,
&lc.channelState.LocalChanCfg,
&lc.channelState.RemoteChanCfg,
))
}
if remoteCommitPoint != nil {
commitKeys.SetForParty(lntypes.Remote, DeriveCommitmentKeys(
remoteCommitPoint, lntypes.Remote,
lc.channelState.ChanType,
&lc.channelState.LocalChanCfg,
&lc.channelState.RemoteChanCfg,
))
}
auxResult, err := fn.MapOptionZ(
lc.leafStore,
func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] {
return s.FetchLeavesFromCommit(
NewAuxChanState(lc.channelState), *diskCommit,
*commitKeys.GetForParty(whoseCommit),
whoseCommit,
)
},
).Unpack()
if err != nil {
return nil, fmt.Errorf("unable to fetch aux leaves: %w", err)
}
// With the key rings re-created, we'll now convert all the on-disk
// HTLCs into paymentDescriptors so we can re-insert them into our
// update log.
incomingHtlcs, outgoingHtlcs, err := lc.extractPayDescs(
chainfee.SatPerKWeight(diskCommit.FeePerKw),
diskCommit.Htlcs, commitKeys, whoseCommit, auxResult.AuxLeaves,
)
if err != nil {
return nil, err
}
messageIndices := lntypes.Dual[uint64]{
Local: diskCommit.LocalLogIndex,
Remote: diskCommit.RemoteLogIndex,
}
// With the necessary items generated, we'll now re-construct the
// commitment state as it was originally present in memory.
commit := &commitment{
height: diskCommit.CommitHeight,
whoseCommit: whoseCommit,
ourBalance: diskCommit.LocalBalance,
theirBalance: diskCommit.RemoteBalance,
messageIndices: messageIndices,
ourHtlcIndex: diskCommit.LocalHtlcIndex,
theirHtlcIndex: diskCommit.RemoteHtlcIndex,
txn: diskCommit.CommitTx,
sig: diskCommit.CommitSig,
fee: diskCommit.CommitFee,
feePerKw: chainfee.SatPerKWeight(diskCommit.FeePerKw),
incomingHTLCs: incomingHtlcs,
outgoingHTLCs: outgoingHtlcs,
customBlob: diskCommit.CustomBlob,
}
if whoseCommit.IsLocal() {
commit.dustLimit = lc.channelState.LocalChanCfg.DustLimit
} else {
commit.dustLimit = lc.channelState.RemoteChanCfg.DustLimit
}
return commit, nil
}
// LightningChannel implements the state machine which corresponds to the
// current commitment protocol wire spec. The state machine implemented allows
// for asynchronous, fully desynchronized, batched+pipelined updates to
// commitment transactions allowing for a high degree of non-blocking
// bi-directional payment throughput.
//
// In order to allow updates to be fully non-blocking, either side is able to
// create multiple new commitment states up to a pre-determined window size.
// This window size is encoded within InitialRevocationWindow. Before the start
// of a session, both sides should send out revocation messages with nil
// preimages in order to populate their revocation window for the remote party.
//
// The state machine has four main methods:
// - .SignNextCommitment()
// - Called once when one wishes to sign the next commitment, either
// initiating a new state update, or responding to a received commitment.
// - .ReceiveNewCommitment()
// - Called upon receipt of a new commitment from the remote party. If the
// new commitment is valid, then a revocation should immediately be
// generated and sent.
// - .RevokeCurrentCommitment()
// - Revokes the current commitment. Should be called directly after
// receiving a new commitment.
// - .ReceiveRevocation()
// - Processes a revocation from the remote party. If successful, creates a
// new de facto broadcastable state.
//
// See the individual comments within the above methods for further details.
type LightningChannel struct {
// Signer is the main signer instances that will be responsible for
// signing any HTLC and commitment transaction generated by the state
// machine.
Signer input.Signer
// leafStore is used to retrieve extra tapscript leaves for special
// custom channel types.
leafStore fn.Option[AuxLeafStore]
// signDesc is the primary sign descriptor that is capable of signing
// the commitment transaction that spends the multi-sig output.
signDesc *input.SignDescriptor
isClosed bool
// sigPool is a pool of workers that are capable of signing and
// validating signatures in parallel. This is utilized as an
// optimization to avoid serially signing or validating the HTLC
// signatures, of which there may be hundreds.
sigPool *SigPool
// auxSigner is a special signer used to obtain opaque signatures for
// custom channel variants.
auxSigner fn.Option[AuxSigner]
// auxResolver is an optional component that can be used to modify the
// way contracts are resolved.
auxResolver fn.Option[AuxContractResolver]
// Capacity is the total capacity of this channel.
Capacity btcutil.Amount
// currentHeight is the current height of our local commitment chain.
// This is also the same as the number of updates to the channel we've
// accepted.
currentHeight uint64
// commitChains is a Dual of the local and remote node's commitment
// chains. Any new commitments we initiate are added to the Remote chain's
// tip. The Local portion of this field is our local commitment chain.
// Any new commitments received are added to the tip of this chain.
// The tail (or lowest height) in this chain is our current accepted
// state, which we are able to broadcast safely.
commitChains lntypes.Dual[*commitmentChain]
channelState *channeldb.OpenChannel
commitBuilder *CommitmentBuilder
// updateLogs is a Dual of the local and remote (mostly) append-only
// logs storing all the HTLC updates to this channel. Each log is
// walked backwards as HTLC updates are applied in order to
// re-construct a commitment transaction from a commitment. The logs
// are compacted once a revocation is received.
updateLogs lntypes.Dual[*updateLog]
// log is a channel-specific logging instance.
log btclog.Logger
// taprootNonceProducer is used to generate a shachain tree for the
// purpose of generating verification nonces for taproot channels.
taprootNonceProducer shachain.Producer
// musigSessions holds the current musig2 pair session for the channel.
musigSessions *MusigPairSession
// pendingVerificationNonce is the initial verification nonce generated
// for musig2 channels when the state machine is initiated. Once we know
// the verification nonce of the remote party, then we can start to use
// the channel as normal.
pendingVerificationNonce *musig2.Nonces
// fundingOutput is the funding output (script+value).
fundingOutput wire.TxOut
// opts is the set of options that channel was initialized with.
opts *channelOpts
sync.RWMutex
}
// ChannelOpt is a functional option that lets callers modify how a new channel
// is created.
type ChannelOpt func(*channelOpts)
// channelOpts is the set of options used to create a new channel.
type channelOpts struct {
localNonce *musig2.Nonces
remoteNonce *musig2.Nonces
leafStore fn.Option[AuxLeafStore]
auxSigner fn.Option[AuxSigner]
auxResolver fn.Option[AuxContractResolver]
skipNonceInit bool
}
// WithLocalMusigNonces is used to bind an existing verification/local nonce to
// a new channel.
func WithLocalMusigNonces(nonce *musig2.Nonces) ChannelOpt {
return func(o *channelOpts) {
o.localNonce = nonce
}
}
// WithRemoteMusigNonces is used to bind the remote party's local/verification
// nonce to a new channel.
func WithRemoteMusigNonces(nonces *musig2.Nonces) ChannelOpt {
return func(o *channelOpts) {
o.remoteNonce = nonces
}
}
// WithSkipNonceInit is used to modify the way nonces are handled during
// channel initialization for taproot channels. If this option is specified,
// then when we receive the channel reestablish message from the remote party,
// we won't modify our nonce state. This is needed if we create a channel, get
// a channel ready message, then also get the channel reestablish message
// after that.
func WithSkipNonceInit() ChannelOpt {
return func(o *channelOpts) {
o.skipNonceInit = true
}
}
// WithLeafStore is used to specify a custom leaf store for the channel.
func WithLeafStore(store AuxLeafStore) ChannelOpt {
return func(o *channelOpts) {
o.leafStore = fn.Some[AuxLeafStore](store)
}
}
// WithAuxSigner is used to specify a custom aux signer for the channel.
func WithAuxSigner(signer AuxSigner) ChannelOpt {
return func(o *channelOpts) {
o.auxSigner = fn.Some[AuxSigner](signer)
}
}
// WithAuxResolver is used to specify a custom aux contract resolver for the
// channel.
func WithAuxResolver(resolver AuxContractResolver) ChannelOpt {
return func(o *channelOpts) {
o.auxResolver = fn.Some[AuxContractResolver](resolver)
}
}
// defaultChannelOpts returns the set of default options for a new channel.
func defaultChannelOpts() *channelOpts {
return &channelOpts{}
}
// NewLightningChannel creates a new, active payment channel given an
// implementation of the chain notifier, channel database, and the current
// settled channel state. Throughout state transitions, the channel will
// automatically persist pertinent state to the database in an efficient
// manner.
func NewLightningChannel(signer input.Signer,
state *channeldb.OpenChannel,
sigPool *SigPool, chanOpts ...ChannelOpt) (*LightningChannel, error) {
opts := defaultChannelOpts()
for _, optFunc := range chanOpts {
optFunc(opts)
}
localCommit := state.LocalCommitment
remoteCommit := state.RemoteCommitment
// First, initialize the update logs with their current counter values
// from the local and remote commitments.
localUpdateLog := newUpdateLog(
remoteCommit.LocalLogIndex, remoteCommit.LocalHtlcIndex,
)
remoteUpdateLog := newUpdateLog(
localCommit.RemoteLogIndex, localCommit.RemoteHtlcIndex,
)
updateLogs := lntypes.Dual[*updateLog]{
Local: localUpdateLog,
Remote: remoteUpdateLog,
}
logPrefix := fmt.Sprintf("ChannelPoint(%v):", state.FundingOutpoint)
taprootNonceProducer, err := channeldb.DeriveMusig2Shachain(
state.RevocationProducer,
)
if err != nil {
return nil, fmt.Errorf("unable to derive shachain: %w", err)
}
commitChains := lntypes.Dual[*commitmentChain]{
Local: newCommitmentChain(),
Remote: newCommitmentChain(),
}
lc := &LightningChannel{
Signer: signer,
leafStore: opts.leafStore,
auxSigner: opts.auxSigner,
auxResolver: opts.auxResolver,
sigPool: sigPool,
currentHeight: localCommit.CommitHeight,
commitChains: commitChains,
channelState: state,
commitBuilder: NewCommitmentBuilder(
state, opts.leafStore,
),
updateLogs: updateLogs,
Capacity: state.Capacity,
taprootNonceProducer: taprootNonceProducer,
log: walletLog.WithPrefix(logPrefix),
opts: opts,
}
switch {
// At this point, we may already have nonces that were passed in, so
// we'll check that now as this lets us skip some steps later.
case state.ChanType.IsTaproot() && opts.localNonce != nil:
lc.pendingVerificationNonce = opts.localNonce
// Otherwise, we'll generate the nonces here ourselves. This ensures
// we'll be able to process the chan sync message from the remote
// party.
case state.ChanType.IsTaproot() && opts.localNonce == nil:
_, err := lc.GenMusigNonces()
if err != nil {
return nil, err
}
}
if lc.pendingVerificationNonce != nil && opts.remoteNonce != nil {
err := lc.InitRemoteMusigNonces(opts.remoteNonce)
if err != nil {
return nil, err
}
}
// With the main channel struct reconstructed, we'll now restore the
// commitment state in memory and also the update logs themselves.
err = lc.restoreCommitState(&localCommit, &remoteCommit)
if err != nil {
return nil, err
}
// Create the sign descriptor which we'll be using very frequently to
// request a signature for the 2-of-2 multi-sig from the signer in
// order to complete channel state transitions.
if err := lc.createSignDesc(); err != nil {
return nil, err
}
return lc, nil
}
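// exampleNewChannelWithOpts is an illustrative sketch and is not part of the
// original file: it shows how the functional options above compose with
// NewLightningChannel for a taproot channel. The signer, state, sigPool and
// nonce values are assumed to be supplied by the caller.
func exampleNewChannelWithOpts(signer input.Signer,
state *channeldb.OpenChannel, sigPool *SigPool,
localNonces, remoteNonces *musig2.Nonces) (*LightningChannel, error) {
// Each ChannelOpt mutates the shared channelOpts value before the
// channel is constructed, so options may be passed in any order.
return NewLightningChannel(
signer, state, sigPool,
WithLocalMusigNonces(localNonces),
WithRemoteMusigNonces(remoteNonces),
)
}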
// createSignDesc derives the SignDescriptor for commitment transactions from
// other fields on the LightningChannel.
func (lc *LightningChannel) createSignDesc() error {
var (
fundingPkScript, multiSigScript []byte
err error
)
chanState := lc.channelState
localKey := chanState.LocalChanCfg.MultiSigKey.PubKey
remoteKey := chanState.RemoteChanCfg.MultiSigKey.PubKey
if chanState.ChanType.IsTaproot() {
fundingPkScript, _, err = input.GenTaprootFundingScript(
localKey, remoteKey, int64(lc.channelState.Capacity),
chanState.TapscriptRoot,
)
if err != nil {
return err
}
} else {
multiSigScript, err = input.GenMultiSigScript(
localKey.SerializeCompressed(),
remoteKey.SerializeCompressed(),
)
if err != nil {
return err
}
fundingPkScript, err = input.WitnessScriptHash(multiSigScript)
if err != nil {
return err
}
}
lc.fundingOutput = wire.TxOut{
PkScript: fundingPkScript,
Value: int64(lc.channelState.Capacity),
}
lc.signDesc = &input.SignDescriptor{
KeyDesc: lc.channelState.LocalChanCfg.MultiSigKey,
WitnessScript: multiSigScript,
Output: &lc.fundingOutput,
HashType: txscript.SigHashAll,
InputIndex: 0,
}
return nil
}
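// exampleSignCommitment is an illustrative sketch and is not part of the
// original file: it shows how the sign descriptor derived above is typically
// consumed for a segwit v0 channel. The hash-cache helper used here is an
// assumption; taproot channels instead use musig2 sessions rather than this
// plain signing path.
func exampleSignCommitment(lc *LightningChannel,
commitTx *wire.MsgTx) (input.Signature, error) {
// Complete a copy of the cached descriptor with the
// transaction-specific hash cache, then ask the signer for a
// signature over the funding output.
signDesc := *lc.signDesc
signDesc.SigHashes = input.NewTxSigHashesV0Only(commitTx)
return lc.Signer.SignOutputRaw(commitTx, &signDesc)
}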
// ResetState resets the state of the channel back to the default state. This
// ensures that any active goroutines which need to act based on on-chain
// events do so properly.
func (lc *LightningChannel) ResetState() {
lc.Lock()
lc.isClosed = false
lc.Unlock()
}
// logUpdateToPayDesc converts a LogUpdate into a matching paymentDescriptor
// entry that can be re-inserted into the update log. This method is used when
// we extended a state to the remote party, but the connection was obstructed
// before we could finish the commitment dance. In this case, we need to
// re-insert the original entries back into the update log so we can resume as
// if nothing happened.
func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate,
remoteUpdateLog *updateLog, commitHeight uint64,
feeRate chainfee.SatPerKWeight, remoteCommitKeys *CommitmentKeyRing,
remoteDustLimit btcutil.Amount,
auxLeaves fn.Option[CommitAuxLeaves]) (*paymentDescriptor, error) {
// Depending on the type of update message we'll map that to a distinct
// paymentDescriptor instance.
var pd *paymentDescriptor
switch wireMsg := logUpdate.UpdateMsg.(type) {
// For offered HTLC's, we'll map that to a paymentDescriptor with the
// type Add, ensuring we restore the necessary fields. From the PoV of
// the commitment chain, this HTLC was included in the remote chain,
// but not the local chain.
case *lnwire.UpdateAddHTLC:
// First, we'll map all the relevant fields in the
// UpdateAddHTLC message to their corresponding fields in the
// paymentDescriptor struct. We also set addCommitHeightRemote
// as we've included this HTLC in our local commitment chain
// for the remote party.
pd = &paymentDescriptor{
ChanID: wireMsg.ChanID,
RHash: wireMsg.PaymentHash,
Timeout: wireMsg.Expiry,
Amount: wireMsg.Amount,
EntryType: Add,
HtlcIndex: wireMsg.ID,
LogIndex: logUpdate.LogIndex,
OnionBlob: wireMsg.OnionBlob,
BlindingPoint: wireMsg.BlindingPoint,
CustomRecords: wireMsg.CustomRecords.Copy(),
addCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
}
isDustRemote := HtlcIsDust(
lc.channelState.ChanType, false, lntypes.Remote,
feeRate, wireMsg.Amount.ToSatoshis(), remoteDustLimit,
)
if !isDustRemote {
auxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
return leaves[pd.HtlcIndex].AuxTapLeaf
},
)(auxLeaves)
scriptInfo, err := genHtlcScript(
lc.channelState.ChanType, false, lntypes.Remote,
wireMsg.Expiry, wireMsg.PaymentHash,
remoteCommitKeys, auxLeaf,
)
if err != nil {
return nil, err
}
pd.theirPkScript = scriptInfo.PkScript()
pd.theirWitnessScript = scriptInfo.WitnessScriptToSign()
}
// For HTLCs that we've settled, we'll fetch the original offered HTLC
// from the remote party's update log so we can retrieve the same
// paymentDescriptor that SettleHTLC would produce.
case *lnwire.UpdateFulfillHTLC:
ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID)
pd = &paymentDescriptor{
ChanID: wireMsg.ChanID,
Amount: ogHTLC.Amount,
RHash: ogHTLC.RHash,
RPreimage: wireMsg.PaymentPreimage,
LogIndex: logUpdate.LogIndex,
ParentIndex: ogHTLC.HtlcIndex,
EntryType: Settle,
removeCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
}
// If we sent a failure for a prior incoming HTLC, then we'll consult
// the update log of the remote party so we can retrieve the
// information of the original HTLC we're failing. We also set the
// removal height for the remote commitment.
case *lnwire.UpdateFailHTLC:
ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID)
pd = &paymentDescriptor{
ChanID: wireMsg.ChanID,
Amount: ogHTLC.Amount,
RHash: ogHTLC.RHash,
ParentIndex: ogHTLC.HtlcIndex,
LogIndex: logUpdate.LogIndex,
EntryType: Fail,
FailReason: wireMsg.Reason[:],
removeCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
}
// HTLC failures due to malformed onion blobs are treated exactly the
// same way as regular HTLC failures.
case *lnwire.UpdateFailMalformedHTLC:
ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID)
// TODO(roasbeef): err if nil?
pd = &paymentDescriptor{
ChanID: wireMsg.ChanID,
Amount: ogHTLC.Amount,
RHash: ogHTLC.RHash,
ParentIndex: ogHTLC.HtlcIndex,
LogIndex: logUpdate.LogIndex,
EntryType: MalformedFail,
FailCode: wireMsg.FailureCode,
ShaOnionBlob: wireMsg.ShaOnionBlob,
removeCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
}
// For fee updates we'll create a FeeUpdate type to add to the log. We
// reuse the amount field to hold the fee rate. Since the amount field
// is denominated in msat we won't lose precision when storing the
// sat/kw denominated feerate. Note that we set both the add and remove
// height to the same value, as we consider the fee update locked in by
// adding and removing it at the same height.
case *lnwire.UpdateFee:
pd = &paymentDescriptor{
ChanID: wireMsg.ChanID,
LogIndex: logUpdate.LogIndex,
Amount: lnwire.NewMSatFromSatoshis(
btcutil.Amount(wireMsg.FeePerKw),
),
EntryType: FeeUpdate,
addCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
removeCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
}
}
return pd, nil
}
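// To summarize the mapping above: UpdateAddHTLC -> Add, UpdateFulfillHTLC ->
// Settle, UpdateFailHTLC -> Fail, UpdateFailMalformedHTLC -> MalformedFail,
// and UpdateFee -> FeeUpdate. Any other message type falls through the switch
// and yields a nil descriptor.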
// localLogUpdateToPayDesc converts a LogUpdate into a matching
// paymentDescriptor entry that can be re-inserted into the local update log.
// This method is used when we send an update+sig, receive a revocation, but
// drop right before the counterparty can sign for the update we just sent. In
// this case, we need to re-insert the original entries back into the update
// log so we'll be expecting the peer to sign them. The height of the remote
// commitment is expected to be provided and we restore all log update entries
// with this height, even though the real height may be lower. Given the way
// these fields are used elsewhere, this doesn't change anything.
func (lc *LightningChannel) localLogUpdateToPayDesc(logUpdate *channeldb.LogUpdate,
remoteUpdateLog *updateLog, commitHeight uint64) (*paymentDescriptor,
error) {
// Since Add updates aren't saved to disk under this key, the update will
// never be an Add.
switch wireMsg := logUpdate.UpdateMsg.(type) {
// For HTLCs that we settled, we'll fetch the original offered HTLC from
// the remote update log so we can retrieve the same paymentDescriptor
// that ReceiveHTLCSettle would produce.
case *lnwire.UpdateFulfillHTLC:
ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID)
return &paymentDescriptor{
ChanID: wireMsg.ChanID,
Amount: ogHTLC.Amount,
RHash: ogHTLC.RHash,
RPreimage: wireMsg.PaymentPreimage,
LogIndex: logUpdate.LogIndex,
ParentIndex: ogHTLC.HtlcIndex,
EntryType: Settle,
removeCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
}, nil
// If we sent a failure for a prior incoming HTLC, then we'll consult the
// remote update log so we can retrieve the information of the original
// HTLC we're failing.
case *lnwire.UpdateFailHTLC:
ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID)
return &paymentDescriptor{
ChanID: wireMsg.ChanID,
Amount: ogHTLC.Amount,
RHash: ogHTLC.RHash,
ParentIndex: ogHTLC.HtlcIndex,
LogIndex: logUpdate.LogIndex,
EntryType: Fail,
FailReason: wireMsg.Reason[:],
removeCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
}, nil
// HTLC failures due to malformed onion blobs are treated exactly the
// same way as regular HTLC failures.
case *lnwire.UpdateFailMalformedHTLC:
ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID)
return &paymentDescriptor{
ChanID: wireMsg.ChanID,
Amount: ogHTLC.Amount,
RHash: ogHTLC.RHash,
ParentIndex: ogHTLC.HtlcIndex,
LogIndex: logUpdate.LogIndex,
EntryType: MalformedFail,
FailCode: wireMsg.FailureCode,
ShaOnionBlob: wireMsg.ShaOnionBlob,
removeCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
}, nil
case *lnwire.UpdateFee:
return &paymentDescriptor{
ChanID: wireMsg.ChanID,
LogIndex: logUpdate.LogIndex,
Amount: lnwire.NewMSatFromSatoshis(
btcutil.Amount(wireMsg.FeePerKw),
),
EntryType: FeeUpdate,
addCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
removeCommitHeights: lntypes.Dual[uint64]{
Remote: commitHeight,
},
}, nil
default:
return nil, fmt.Errorf("unknown message type: %T", wireMsg)
}
}
// remoteLogUpdateToPayDesc converts a LogUpdate into a matching
// paymentDescriptor entry that can be re-inserted into the update log. This
// method is used when we revoked a local commitment, but the connection was
// obstructed before we could sign a remote commitment that contains these
// updates. In this case, we need to re-insert the original entries back into
// the update log so we can resume as if nothing happened. The height of the
// latest local commitment is also expected to be provided. We are restoring all
// log update entries with this height, even though the real commitment height
// may be lower. Given the way these fields are used elsewhere, this doesn't
// change anything.
func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpdate,
localUpdateLog *updateLog, commitHeight uint64) (*paymentDescriptor,
error) {
switch wireMsg := logUpdate.UpdateMsg.(type) {
case *lnwire.UpdateAddHTLC:
pd := &paymentDescriptor{
ChanID: wireMsg.ChanID,
RHash: wireMsg.PaymentHash,
Timeout: wireMsg.Expiry,
Amount: wireMsg.Amount,
EntryType: Add,
HtlcIndex: wireMsg.ID,
LogIndex: logUpdate.LogIndex,
OnionBlob: wireMsg.OnionBlob,
BlindingPoint: wireMsg.BlindingPoint,
CustomRecords: wireMsg.CustomRecords.Copy(),
addCommitHeights: lntypes.Dual[uint64]{
Local: commitHeight,
},
}
// We don't need to generate an htlc script yet. This will be
// done once we sign our remote commitment.
return pd, nil
// For HTLCs that the remote party settled, we'll fetch the original
// offered HTLC from the local update log so we can retrieve the same
// paymentDescriptor that ReceiveHTLCSettle would produce.
case *lnwire.UpdateFulfillHTLC:
ogHTLC := localUpdateLog.lookupHtlc(wireMsg.ID)
return &paymentDescriptor{
ChanID: wireMsg.ChanID,
Amount: ogHTLC.Amount,
RHash: ogHTLC.RHash,
RPreimage: wireMsg.PaymentPreimage,
LogIndex: logUpdate.LogIndex,
ParentIndex: ogHTLC.HtlcIndex,
EntryType: Settle,
removeCommitHeights: lntypes.Dual[uint64]{
Local: commitHeight,
},
}, nil
// If we received a failure for a prior outgoing HTLC, then we'll
// consult the local update log so we can retrieve the information of
// the original HTLC we're failing.
case *lnwire.UpdateFailHTLC:
ogHTLC := localUpdateLog.lookupHtlc(wireMsg.ID)
return &paymentDescriptor{
ChanID: wireMsg.ChanID,
Amount: ogHTLC.Amount,
RHash: ogHTLC.RHash,
ParentIndex: ogHTLC.HtlcIndex,
LogIndex: logUpdate.LogIndex,
EntryType: Fail,
FailReason: wireMsg.Reason[:],
removeCommitHeights: lntypes.Dual[uint64]{
Local: commitHeight,
},
}, nil
// HTLC failures due to malformed onion blobs are treated exactly the
// same way as regular HTLC failures.
case *lnwire.UpdateFailMalformedHTLC:
ogHTLC := localUpdateLog.lookupHtlc(wireMsg.ID)
return &paymentDescriptor{
ChanID: wireMsg.ChanID,
Amount: ogHTLC.Amount,
RHash: ogHTLC.RHash,
ParentIndex: ogHTLC.HtlcIndex,
LogIndex: logUpdate.LogIndex,
EntryType: MalformedFail,
FailCode: wireMsg.FailureCode,
ShaOnionBlob: wireMsg.ShaOnionBlob,
removeCommitHeights: lntypes.Dual[uint64]{
Local: commitHeight,
},
}, nil
// For fee updates we'll create a FeeUpdate type to add to the log. We
// reuse the amount field to hold the fee rate. Since the amount field
// is denominated in msat we won't lose precision when storing the
// sat/kw denominated feerate. Note that we set both the add and remove
// height to the same value, as we consider the fee update locked in by
// adding and removing it at the same height.
case *lnwire.UpdateFee:
return &paymentDescriptor{
ChanID: wireMsg.ChanID,
LogIndex: logUpdate.LogIndex,
Amount: lnwire.NewMSatFromSatoshis(
btcutil.Amount(wireMsg.FeePerKw),
),
EntryType: FeeUpdate,
addCommitHeights: lntypes.Dual[uint64]{
Local: commitHeight,
},
removeCommitHeights: lntypes.Dual[uint64]{
Local: commitHeight,
},
}, nil
default:
return nil, errors.New("unknown message type")
}
}
// restoreCommitState will restore the local commitment chain and updateLog
// state to a consistent in-memory representation of the passed disk commitment.
// This method is to be used upon reconnection to our channel counterparty.
// Once the connection has been established, we'll prepare our in-memory state
// to re-sync states with the remote party, and also verify/extend new proposed
// commitment states.
func (lc *LightningChannel) restoreCommitState(
localCommitState, remoteCommitState *channeldb.ChannelCommitment) error {
// In order to reconstruct the pkScripts on each of the pending HTLC
// outputs (if any) we'll need to regenerate the current revocation for
// this current un-revoked state as well as retrieve the current
// revocation for the remote party.
ourRevPreImage, err := lc.channelState.RevocationProducer.AtIndex(
lc.currentHeight,
)
if err != nil {
return err
}
localCommitPoint := input.ComputeCommitmentPoint(ourRevPreImage[:])
remoteCommitPoint := lc.channelState.RemoteCurrentRevocation
// With the revocation state reconstructed, we can now convert the disk
// commitment into our in-memory commitment format, inserting it into
// the local commitment chain.
localCommit, err := lc.diskCommitToMemCommit(
lntypes.Local, localCommitState, localCommitPoint,
remoteCommitPoint,
)
if err != nil {
return err
}
lc.commitChains.Local.addCommitment(localCommit)
lc.log.Tracef("starting local commitment: %v",
lnutils.SpewLogClosure(lc.commitChains.Local.tail()))
// We'll also do the same for the remote commitment chain.
remoteCommit, err := lc.diskCommitToMemCommit(
lntypes.Remote, remoteCommitState, localCommitPoint,
remoteCommitPoint,
)
if err != nil {
return err
}
lc.commitChains.Remote.addCommitment(remoteCommit)
lc.log.Tracef("starting remote commitment: %v",
lnutils.SpewLogClosure(lc.commitChains.Remote.tail()))
var (
pendingRemoteCommit *commitment
pendingRemoteCommitDiff *channeldb.CommitDiff
pendingRemoteKeyChain *CommitmentKeyRing
)
// Next, we'll check to see if we have an un-acked commitment state we
// extended to the remote party but which was never ACK'd.
pendingRemoteCommitDiff, err = lc.channelState.RemoteCommitChainTip()
if err != nil && err != channeldb.ErrNoPendingCommit {
return err
}
if pendingRemoteCommitDiff != nil {
// If we have a pending remote commitment, then we'll also
// reconstruct the original commitment for that state,
// inserting it into the remote party's commitment chain. We
// don't pass our commit point as we don't have the
// corresponding state for the local commitment chain.
pendingCommitPoint := lc.channelState.RemoteNextRevocation
pendingRemoteCommit, err = lc.diskCommitToMemCommit(
lntypes.Remote, &pendingRemoteCommitDiff.Commitment,
nil, pendingCommitPoint,
)
if err != nil {
return err
}
lc.commitChains.Remote.addCommitment(pendingRemoteCommit)
lc.log.Debugf("pending remote commitment: %v",
lnutils.SpewLogClosure(lc.commitChains.Remote.tip()))
// We'll also re-create the set of commitment keys needed to
// fully re-derive the state.
pendingRemoteKeyChain = DeriveCommitmentKeys(
pendingCommitPoint, lntypes.Remote,
lc.channelState.ChanType,
&lc.channelState.LocalChanCfg,
&lc.channelState.RemoteChanCfg,
)
}
// Fetch remote updates that we have acked but not yet signed for.
unsignedAckedUpdates, err := lc.channelState.UnsignedAckedUpdates()
if err != nil {
return err
}
// Fetch the local updates the peer still needs to sign for.
remoteUnsignedLocalUpdates, err := lc.channelState.RemoteUnsignedLocalUpdates()
if err != nil {
return err
}
// Finally, with the commitment states restored, we'll now restore the
// state logs based on the current local+remote commit, and any pending
// remote commit that exists.
err = lc.restoreStateLogs(
localCommit, remoteCommit, pendingRemoteCommit,
pendingRemoteCommitDiff, pendingRemoteKeyChain,
unsignedAckedUpdates, remoteUnsignedLocalUpdates,
)
if err != nil {
return err
}
return nil
}
// restoreStateLogs runs through the current locked-in HTLCs from the point of
// view of the channel and inserts corresponding log entries (both local and
// remote) for each HTLC read from disk. This method is required to sync the
// in-memory state of the state machine with that read from persistent storage.
func (lc *LightningChannel) restoreStateLogs(
localCommitment, remoteCommitment, pendingRemoteCommit *commitment,
pendingRemoteCommitDiff *channeldb.CommitDiff,
pendingRemoteKeys *CommitmentKeyRing,
unsignedAckedUpdates,
remoteUnsignedLocalUpdates []channeldb.LogUpdate) error {
// We make a map of incoming HTLCs to the height of the remote
// commitment they were first added, and outgoing HTLCs to the height
// of the local commit they were first added. This will be used when we
// restore the update logs below.
incomingRemoteAddHeights := make(map[uint64]uint64)
outgoingLocalAddHeights := make(map[uint64]uint64)
// We start by setting the height of the incoming HTLCs on the pending
// remote commitment. We set these heights first since if there are
// duplicates, these will be overwritten by the lower height of the
// remoteCommitment below.
if pendingRemoteCommit != nil {
for _, r := range pendingRemoteCommit.incomingHTLCs {
incomingRemoteAddHeights[r.HtlcIndex] =
pendingRemoteCommit.height
}
}
// Now set the remote commit height of all incoming HTLCs found on the
// remote commitment.
for _, r := range remoteCommitment.incomingHTLCs {
incomingRemoteAddHeights[r.HtlcIndex] = remoteCommitment.height
}
// And finally we can do the same for the outgoing HTLCs.
for _, l := range localCommitment.outgoingHTLCs {
outgoingLocalAddHeights[l.HtlcIndex] = localCommitment.height
}
// If we have any unsigned acked updates to sign for, then the add is no
// longer on our local commitment, but is still on the remote's commitment.
// <---fail---
// <---sig----
// ----rev--->
// To ensure proper channel operation, we restore the add's
// addCommitHeights.Local field to the height of our local commitment.
for _, logUpdate := range unsignedAckedUpdates {
var htlcIdx uint64
switch wireMsg := logUpdate.UpdateMsg.(type) {
case *lnwire.UpdateFulfillHTLC:
htlcIdx = wireMsg.ID
case *lnwire.UpdateFailHTLC:
htlcIdx = wireMsg.ID
case *lnwire.UpdateFailMalformedHTLC:
htlcIdx = wireMsg.ID
default:
continue
}
// The htlcIdx is stored in the map with the local commitment
// height so the related add's addCommitHeights.Local field can
// be restored.
outgoingLocalAddHeights[htlcIdx] = localCommitment.height
}
// If there are local updates that the peer needs to sign for, then the
// corresponding add is no longer on the remote commitment, but is still on
// our local commitment.
// ----fail--->
// ----sig---->
// <---rev-----
// To ensure proper channel operation, we restore the add's
// addCommitHeights.Remote field to the height of the remote commitment.
for _, logUpdate := range remoteUnsignedLocalUpdates {
var htlcIdx uint64
switch wireMsg := logUpdate.UpdateMsg.(type) {
case *lnwire.UpdateFulfillHTLC:
htlcIdx = wireMsg.ID
case *lnwire.UpdateFailHTLC:
htlcIdx = wireMsg.ID
case *lnwire.UpdateFailMalformedHTLC:
htlcIdx = wireMsg.ID
default:
continue
}
// The htlcIdx is stored in the map with the remote commitment
// height so the related add's addCommitHeights.Remote field can
// be restored.
incomingRemoteAddHeights[htlcIdx] = remoteCommitment.height
}
// For each incoming HTLC within the local commitment, we add it to the
// remote update log. Since HTLCs are added first to the receiver's
// commitment, we don't have to restore outgoing HTLCs, as they will be
// restored from the remote commitment below.
for i := range localCommitment.incomingHTLCs {
htlc := localCommitment.incomingHTLCs[i]
// We'll need to set the add height of the HTLC. Since it is on
// this local commit, we can use its height as local add
// height. As remote add height we consult the incoming HTLC
// map we created earlier. Note that if this HTLC is not in
// incomingRemoteAddHeights, the remote add height will be set
// to zero, which indicates that it is not added yet.
htlc.addCommitHeights.Local = localCommitment.height
htlc.addCommitHeights.Remote =
incomingRemoteAddHeights[htlc.HtlcIndex]
// Restore the htlc back to the remote log.
lc.updateLogs.Remote.restoreHtlc(&htlc)
}
// Similarly, we'll do the same for the outgoing HTLCs within the
// remote commitment, adding them to the local update log.
for i := range remoteCommitment.outgoingHTLCs {
htlc := remoteCommitment.outgoingHTLCs[i]
// As for the incoming HTLCs, we'll use the current remote
// commit height as remote add height, and consult the map
// created above for the local add height.
htlc.addCommitHeights.Remote = remoteCommitment.height
htlc.addCommitHeights.Local =
outgoingLocalAddHeights[htlc.HtlcIndex]
// Restore the htlc back to the local log.
lc.updateLogs.Local.restoreHtlc(&htlc)
}
// If we have a dangling (un-acked) commit for the remote party, then we
// restore the updates leading up to this commit.
if pendingRemoteCommit != nil {
err := lc.restorePendingLocalUpdates(
pendingRemoteCommitDiff, pendingRemoteKeys,
)
if err != nil {
return err
}
}
// Restore unsigned acked remote log updates so that we can include them
// in our next signature.
err := lc.restorePendingRemoteUpdates(
unsignedAckedUpdates, localCommitment.height,
pendingRemoteCommit,
)
if err != nil {
return err
}
// Restore unsigned acked local log updates so we expect the peer to
// sign for them.
return lc.restorePeerLocalUpdates(
remoteUnsignedLocalUpdates, remoteCommitment.height,
)
}
// restorePendingRemoteUpdates restores the acked remote log updates that we
// haven't yet signed for.
func (lc *LightningChannel) restorePendingRemoteUpdates(
unsignedAckedUpdates []channeldb.LogUpdate,
localCommitmentHeight uint64,
pendingRemoteCommit *commitment) error {
lc.log.Debugf("Restoring %v dangling remote updates pending our sig",
len(unsignedAckedUpdates))
for _, logUpdate := range unsignedAckedUpdates {
logUpdate := logUpdate
payDesc, err := lc.remoteLogUpdateToPayDesc(
&logUpdate, lc.updateLogs.Local, localCommitmentHeight,
)
if err != nil {
return err
}
logIdx := payDesc.LogIndex
// Sanity check that we are not restoring a remote log update
// that we haven't received a sig for.
if logIdx >= lc.updateLogs.Remote.logIndex {
return fmt.Errorf("attempted to restore an "+
"unsigned remote update: log_index=%v",
logIdx)
}
// We previously restored Adds along with all the other updates,
// but this Add restoration was a no-op as every single one of
// these Adds was already restored since they're all incoming
// htlcs on the local commitment.
if payDesc.EntryType == Add {
continue
}
var (
height uint64
heightSet bool
)
// If we have a pending commitment for them, and this update
// is included in that commit, then we'll use this commitment
// height as this commitment will include these updates for
// their new remote commitment.
if pendingRemoteCommit != nil {
if logIdx < pendingRemoteCommit.messageIndices.Remote {
height = pendingRemoteCommit.height
heightSet = true
}
}
// Insert the update into the log. The log update index doesn't
// need to be incremented (hence the restore calls), because its
// final value was properly persisted with the last local
// commitment update.
switch payDesc.EntryType {
case FeeUpdate:
if heightSet {
payDesc.addCommitHeights.Remote = height
payDesc.removeCommitHeights.Remote = height
}
lc.updateLogs.Remote.restoreUpdate(payDesc)
default:
if heightSet {
payDesc.removeCommitHeights.Remote = height
}
lc.updateLogs.Remote.restoreUpdate(payDesc)
lc.updateLogs.Local.markHtlcModified(
payDesc.ParentIndex,
)
}
}
return nil
}
// restorePeerLocalUpdates restores the acked local log updates the peer still
// needs to sign for.
func (lc *LightningChannel) restorePeerLocalUpdates(updates []channeldb.LogUpdate,
remoteCommitmentHeight uint64) error {
lc.log.Debugf("Restoring %v local updates that the peer should sign",
len(updates))
for _, logUpdate := range updates {
logUpdate := logUpdate
payDesc, err := lc.localLogUpdateToPayDesc(
&logUpdate, lc.updateLogs.Remote,
remoteCommitmentHeight,
)
if err != nil {
return err
}
lc.updateLogs.Local.restoreUpdate(payDesc)
// Since Add updates are not stored and FeeUpdates don't have a
// corresponding entry in the remote update log, we only need to
// mark the htlc as modified if the update was Settle, Fail, or
// MalformedFail.
if payDesc.EntryType != FeeUpdate {
lc.updateLogs.Remote.markHtlcModified(
payDesc.ParentIndex,
)
}
}
return nil
}
// restorePendingLocalUpdates restores the local log updates leading up to the
// given pending remote commitment.
func (lc *LightningChannel) restorePendingLocalUpdates(
pendingRemoteCommitDiff *channeldb.CommitDiff,
pendingRemoteKeys *CommitmentKeyRing) error {
pendingCommit := pendingRemoteCommitDiff.Commitment
pendingHeight := pendingCommit.CommitHeight
lc.log.Debugf("Restoring pending remote commitment %v at commit "+
"height %v", pendingCommit.CommitTx.TxHash(), pendingHeight)
auxResult, err := fn.MapOptionZ(
lc.leafStore,
func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] {
return s.FetchLeavesFromCommit(
NewAuxChanState(lc.channelState), pendingCommit,
*pendingRemoteKeys, lntypes.Remote,
)
},
).Unpack()
if err != nil {
return fmt.Errorf("unable to fetch aux leaves: %w", err)
}
// If we did have a dangling commit, then we'll examine which updates
// we included in that state and re-insert them into our update log.
for _, logUpdate := range pendingRemoteCommitDiff.LogUpdates {
logUpdate := logUpdate
payDesc, err := lc.logUpdateToPayDesc(
&logUpdate, lc.updateLogs.Remote, pendingHeight,
chainfee.SatPerKWeight(pendingCommit.FeePerKw),
pendingRemoteKeys,
lc.channelState.RemoteChanCfg.DustLimit,
auxResult.AuxLeaves,
)
if err != nil {
return err
}
// Earlier versions did not write the log index to disk for fee
// updates, so they will be unset. To account for this we set
// them to the current update log index.
if payDesc.EntryType == FeeUpdate && payDesc.LogIndex == 0 &&
lc.updateLogs.Local.logIndex > 0 {
payDesc.LogIndex = lc.updateLogs.Local.logIndex
lc.log.Debugf("Found FeeUpdate on "+
"pendingRemoteCommitDiff without logIndex, "+
"using %v", payDesc.LogIndex)
}
// At this point the restored update's logIndex must be equal
// to the update log's index, otherwise something is horribly
// wrong.
if payDesc.LogIndex != lc.updateLogs.Local.logIndex {
panic(fmt.Sprintf("log index mismatch: "+
"%v vs %v", payDesc.LogIndex,
lc.updateLogs.Local.logIndex))
}
switch payDesc.EntryType {
case Add:
// The HtlcIndex of the added HTLC _must_ be equal to
// the log's htlcCounter at this point. If it is not we
// panic to catch this.
// TODO(halseth): remove when cause of htlc entry bug
// is found.
if payDesc.HtlcIndex !=
lc.updateLogs.Local.htlcCounter {
panic(fmt.Sprintf("htlc index mismatch: "+
"%v vs %v", payDesc.HtlcIndex,
lc.updateLogs.Local.htlcCounter))
}
lc.updateLogs.Local.appendHtlc(payDesc)
case FeeUpdate:
lc.updateLogs.Local.appendUpdate(payDesc)
default:
lc.updateLogs.Local.appendUpdate(payDesc)
lc.updateLogs.Remote.markHtlcModified(
payDesc.ParentIndex,
)
}
}
return nil
}
// HtlcRetribution contains all the items necessary to sweep a revoked HTLC
// output from a revoked commitment transaction broadcast by the remote
// party.
type HtlcRetribution struct {
// SignDesc is a sign descriptor capable of generating the necessary
// signatures to satisfy the revocation clause of the HTLC's public key
// script.
SignDesc input.SignDescriptor
// OutPoint is the target outpoint of this HTLC pointing to the
// breached commitment transaction.
OutPoint wire.OutPoint
// SecondLevelWitnessScript is the witness script that will be created
// if the second level HTLC transaction for this output is
// broadcast/confirmed. We provide this because, if the remote party
// attempts to go to the second level to claim the HTLC, we'll need to
// update the SignDesc above accordingly to sweep properly.
SecondLevelWitnessScript []byte
// SecondLevelTapTweak is the tap tweak value needed to spend the
// second level output in case the breaching party attempts to publish
// it.
SecondLevelTapTweak [32]byte
// IsIncoming is a boolean flag that indicates whether or not this
// HTLC was accepted from the counterparty. A false value indicates that
// this HTLC was offered by us. This flag is used to determine the
// exact witness type that should be used to sweep the output.
IsIncoming bool
// ResolutionBlob is a blob used for aux channels that permits a
// spender of this output to claim all funds.
ResolutionBlob fn.Option[tlv.Blob]
}
// BreachRetribution contains all the data necessary to bring a channel
// counterparty to justice claiming ALL lingering funds within the channel in
// the scenario that they broadcast a revoked commitment transaction. A
// BreachRetribution is created by the closeObserver if it detects an
// uncooperative close of the channel which uses a revoked commitment
// transaction. The BreachRetribution is then sent over the ContractBreach
// channel in order to allow the subscriber of the channel to dispatch justice.
type BreachRetribution struct {
// BreachTxHash is the transaction hash which breached the channel
// contract by spending from the funding multi-sig with a revoked
// commitment transaction.
BreachTxHash chainhash.Hash
// BreachHeight records the block height confirming the breach
// transaction, used as a height hint when registering for
// confirmations.
BreachHeight uint32
// ChainHash is the chain that the contract breach was identified
// within. This is also the resident chain of the contract (the chain
// the contract was created on).
ChainHash chainhash.Hash
// RevokedStateNum is the revoked state number which was broadcast.
RevokedStateNum uint64
// LocalOutputSignDesc is a SignDescriptor which is capable of
// generating the signature necessary to sweep the output within the
// breach transaction that pays directly us.
//
// NOTE: A nil value indicates that the local output is considered dust
// according to the remote party's dust limit.
LocalOutputSignDesc *input.SignDescriptor
// LocalOutpoint is the outpoint of the output paying to us (the local
// party) within the breach transaction.
LocalOutpoint wire.OutPoint
// LocalDelay is the CSV delay for the to_remote script on the breached
// commitment.
LocalDelay uint32
// RemoteOutputSignDesc is a SignDescriptor which is capable of
// generating the signature required to claim the funds as described
// within the revocation clause of the remote party's commitment
// output.
//
// NOTE: A nil value indicates that the remote output is considered
// dust according to the remote party's dust limit.
RemoteOutputSignDesc *input.SignDescriptor
// RemoteOutpoint is the outpoint of the output paying to the remote
// party within the breach transaction.
RemoteOutpoint wire.OutPoint
// RemoteDelay specifies the CSV delay applied to to-local scripts on
// the breaching commitment transaction.
RemoteDelay uint32
// HtlcRetributions is a slice of HTLC retributions, one for each
// active HTLC output within the breached commitment transaction.
HtlcRetributions []HtlcRetribution
// KeyRing contains the derived public keys used to construct the
// breaching commitment transaction. This allows downstream clients to
// have access to the public keys used in the scripts.
KeyRing *CommitmentKeyRing
// LocalResolutionBlob is a blob used for aux channels that permits an
// honest party to sweep the local commitment output.
LocalResolutionBlob fn.Option[tlv.Blob]
// RemoteResolutionBlob is a blob used for aux channels that permits an
// honest party to sweep the remote commitment output.
RemoteResolutionBlob fn.Option[tlv.Blob]
}
// NewBreachRetribution creates a new fully populated BreachRetribution for the
// passed channel, at a particular revoked state number. If the spend
// transaction that the breach retribution should target is known, then it can
// be provided via the spendTx parameter. Otherwise, if the spendTx parameter is
// nil, then the revocation log will be checked to see if it contains the info
// required to construct the BreachRetribution. If the revocation log is missing
// the required fields then ErrRevLogDataMissing will be returned.
func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64,
breachHeight uint32, spendTx *wire.MsgTx,
leafStore fn.Option[AuxLeafStore],
auxResolver fn.Option[AuxContractResolver]) (*BreachRetribution,
error) {
// Query the on-disk revocation log for the snapshot which was recorded
// at this particular state num. Based on whether a legacy revocation
// log is returned or not, we will process them differently.
revokedLog, revokedLogLegacy, err := chanState.FindPreviousState(
stateNum,
)
if err != nil {
return nil, err
}
// Sanity check that at least one of the logs is returned.
if revokedLog == nil && revokedLogLegacy == nil {
return nil, ErrNoRevocationLogFound
}
// With the state number broadcast known, we can now derive/restore the
// proper revocation preimage necessary to sweep the remote party's
// output.
revocationPreimage, err := chanState.RevocationStore.LookUp(stateNum)
if err != nil {
return nil, err
}
commitmentSecret, commitmentPoint := btcec.PrivKeyFromBytes(
revocationPreimage[:],
)
// With the commitment point generated, we can now generate the four
// keys we'll need to reconstruct the commitment state.
keyRing := DeriveCommitmentKeys(
commitmentPoint, lntypes.Remote, chanState.ChanType,
&chanState.LocalChanCfg, &chanState.RemoteChanCfg,
)
// Next, reconstruct the scripts as they were present at this state
// number so we can have the proper witness script to sign and include
// within the final witness.
var leaseExpiry uint32
if chanState.ChanType.HasLeaseExpiration() {
leaseExpiry = chanState.ThawHeight
}
auxResult, err := fn.MapOptionZ(
leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] {
return s.FetchLeavesFromRevocation(revokedLog)
},
).Unpack()
if err != nil {
return nil, fmt.Errorf("unable to fetch aux leaves: %w", err)
}
// Since it is the remote breach we are reconstructing, the output
// going to us will be a to-remote script with our local params.
remoteAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
},
)(auxResult.AuxLeaves)
isRemoteInitiator := !chanState.IsInitiator
ourScript, ourDelay, err := CommitScriptToRemote(
chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey,
leaseExpiry, remoteAuxLeaf,
)
if err != nil {
return nil, err
}
localAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
},
)(auxResult.AuxLeaves)
theirDelay := uint32(chanState.RemoteChanCfg.CsvDelay)
theirScript, err := CommitScriptToSelf(
chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey,
keyRing.RevocationKey, theirDelay, leaseExpiry, localAuxLeaf,
)
if err != nil {
return nil, err
}
// Define an empty breach retribution that will be overwritten based on
// the version of the revocation log found.
var br *BreachRetribution
// Define our and their amounts, which will be overwritten below.
var ourAmt, theirAmt int64
// If the returned *RevocationLog is non-nil, use it to derive the info
// we need.
if revokedLog != nil {
br, ourAmt, theirAmt, err = createBreachRetribution(
revokedLog, spendTx, chanState, keyRing,
commitmentSecret, leaseExpiry, auxResult.AuxLeaves,
)
if err != nil {
return nil, err
}
} else {
// The returned revocation log is in legacy format, which is a
// *ChannelCommitment.
//
// NOTE: this branch is kept for compatibility such that for
// old nodes which refuse to migrate the legacy revocation log
// data can still function. This branch can be deleted once we
// are confident that no legacy format is in use.
br, ourAmt, theirAmt, err = createBreachRetributionLegacy(
revokedLogLegacy, chanState, keyRing, commitmentSecret,
ourScript, theirScript, leaseExpiry,
)
if err != nil {
return nil, err
}
}
// Conditionally instantiate a sign descriptor for each of the
// commitment outputs. If either is considered dust using the remote
// party's dust limit, the respective sign descriptor will be nil.
//
// If our balance exceeds the remote party's dust limit, instantiate
// the sign descriptor for our output.
if ourAmt >= int64(chanState.RemoteChanCfg.DustLimit) {
// As we're about to sweep our own output w/o a delay, we'll
// obtain the witness script for the success/delay path.
witnessScript, err := ourScript.WitnessScriptForPath(
input.ScriptPathDelay,
)
if err != nil {
return nil, err
}
br.LocalOutputSignDesc = &input.SignDescriptor{
SingleTweak: keyRing.LocalCommitKeyTweak,
KeyDesc: chanState.LocalChanCfg.PaymentBasePoint,
WitnessScript: witnessScript,
Output: &wire.TxOut{
PkScript: ourScript.PkScript(),
Value: ourAmt,
},
HashType: sweepSigHash(chanState.ChanType),
}
// For taproot channels, we'll make sure to set the script path
// spend (as our output on their revoked tx still needs the
// delay), and set the control block.
if scriptTree, ok := ourScript.(input.TapscriptDescriptor); ok {
//nolint:ll
br.LocalOutputSignDesc.SignMethod = input.TaprootScriptSpendSignMethod
ctrlBlock, err := scriptTree.CtrlBlockForPath(
input.ScriptPathDelay,
)
if err != nil {
return nil, err
}
//nolint:ll
br.LocalOutputSignDesc.ControlBlock, err = ctrlBlock.ToBytes()
if err != nil {
return nil, err
}
}
// At this point, we'll check to see if we need any extra
// resolution data for this output.
resolveReq := ResolutionReq{
ChanPoint: chanState.FundingOutpoint,
ChanType: chanState.ChanType,
ShortChanID: chanState.ShortChanID(),
Initiator: chanState.IsInitiator,
FundingBlob: chanState.CustomBlob,
Type: input.TaprootRemoteCommitSpend,
CloseType: Breach,
CommitTx: spendTx,
SignDesc: *br.LocalOutputSignDesc,
KeyRing: keyRing,
CsvDelay: ourDelay,
BreachCsvDelay: fn.Some(theirDelay),
CommitFee: chanState.RemoteCommitment.CommitFee,
}
if revokedLog != nil {
resolveReq.CommitBlob = revokedLog.CustomBlob.ValOpt()
}
resolveBlob := fn.MapOptionZ(
auxResolver,
func(a AuxContractResolver) fn.Result[tlv.Blob] {
return a.ResolveContract(resolveReq)
},
)
if err := resolveBlob.Err(); err != nil {
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
br.LocalResolutionBlob = resolveBlob.OkToSome()
}
// Similarly, if their balance exceeds the remote party's dust limit,
// assemble the sign descriptor for their output, which we can sweep.
if theirAmt >= int64(chanState.RemoteChanCfg.DustLimit) {
// As we're trying to defend the channel against a breach
// attempt from the remote party, we want to obtain the
// revocation witness script here.
witnessScript, err := theirScript.WitnessScriptForPath(
input.ScriptPathRevocation,
)
if err != nil {
return nil, err
}
br.RemoteOutputSignDesc = &input.SignDescriptor{
KeyDesc: chanState.LocalChanCfg.
RevocationBasePoint,
DoubleTweak: commitmentSecret,
WitnessScript: witnessScript,
Output: &wire.TxOut{
PkScript: theirScript.PkScript(),
Value: theirAmt,
},
HashType: sweepSigHash(chanState.ChanType),
}
// For taproot channels, the remote output (the revoked output)
// is spent with a script path to ensure all information 3rd
// parties need to sweep anchors is revealed on chain.
scriptTree, ok := theirScript.(input.TapscriptDescriptor)
if ok {
//nolint:ll
br.RemoteOutputSignDesc.SignMethod = input.TaprootScriptSpendSignMethod
ctrlBlock, err := scriptTree.CtrlBlockForPath(
input.ScriptPathRevocation,
)
if err != nil {
return nil, err
}
//nolint:ll
br.RemoteOutputSignDesc.ControlBlock, err = ctrlBlock.ToBytes()
if err != nil {
return nil, err
}
}
// At this point, we'll check to see if we need any extra
// resolution data for this output.
resolveReq := ResolutionReq{
ChanPoint: chanState.FundingOutpoint,
ChanType: chanState.ChanType,
ShortChanID: chanState.ShortChanID(),
Initiator: chanState.IsInitiator,
FundingBlob: chanState.CustomBlob,
Type: input.TaprootCommitmentRevoke,
CloseType: Breach,
CommitTx: spendTx,
SignDesc: *br.RemoteOutputSignDesc,
KeyRing: keyRing,
CsvDelay: theirDelay,
BreachCsvDelay: fn.Some(theirDelay),
CommitFee: chanState.RemoteCommitment.CommitFee,
}
if revokedLog != nil {
resolveReq.CommitBlob = revokedLog.CustomBlob.ValOpt()
}
resolveBlob := fn.MapOptionZ(
auxResolver,
func(a AuxContractResolver) fn.Result[tlv.Blob] {
return a.ResolveContract(resolveReq)
},
)
if err := resolveBlob.Err(); err != nil {
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
br.RemoteResolutionBlob = resolveBlob.OkToSome()
}
// Finally, with all the necessary data constructed, we can populate the
// remaining fields of the BreachRetribution struct, which houses all the
// data necessary to swiftly bring justice to the cheating remote party.
br.BreachHeight = breachHeight
br.RevokedStateNum = stateNum
br.LocalDelay = ourDelay
br.RemoteDelay = theirDelay
return br, nil
}
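// exampleBuildRetribution is an illustrative sketch and is not part of the
// original file: upon seeing the funding output spent at the given height by
// a transaction that commits to a revoked state number, a caller can
// construct the retribution needed to sweep every output. The aux components
// are left empty here for simplicity.
func exampleBuildRetribution(chanState *channeldb.OpenChannel,
stateNum uint64, spendHeight uint32,
spendTx *wire.MsgTx) (*BreachRetribution, error) {
return NewBreachRetribution(
chanState, stateNum, spendHeight, spendTx,
fn.None[AuxLeafStore](), fn.None[AuxContractResolver](),
)
}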
// createHtlcRetribution is a helper function to construct an HtlcRetribution
// based on the passed params.
func createHtlcRetribution(chanState *channeldb.OpenChannel,
keyRing *CommitmentKeyRing, commitHash chainhash.Hash,
commitmentSecret *btcec.PrivateKey, leaseExpiry uint32,
htlc *channeldb.HTLCEntry,
auxLeaves fn.Option[CommitAuxLeaves]) (HtlcRetribution, error) {
var emptyRetribution HtlcRetribution
theirDelay := uint32(chanState.RemoteChanCfg.CsvDelay)
isRemoteInitiator := !chanState.IsInitiator
// We'll generate the original second level witness script now, as
// we'll need it if we're revoking an HTLC output on the remote
// commitment transaction, and *they* go to the second level.
secondLevelAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] {
return fn.MapOption(func(val uint16) input.AuxTapLeaf {
idx := input.HtlcIndex(val)
if htlc.Incoming.Val {
leaves := l.IncomingHtlcLeaves[idx]
return leaves.SecondLevelLeaf
}
return l.OutgoingHtlcLeaves[idx].SecondLevelLeaf
})(htlc.HtlcIndex.ValOpt())
},
)(auxLeaves)
secondLevelScript, err := SecondLevelHtlcScript(
chanState.ChanType, isRemoteInitiator,
keyRing.RevocationKey, keyRing.ToLocalKey, theirDelay,
leaseExpiry, fn.FlattenOption(secondLevelAuxLeaf),
)
if err != nil {
return emptyRetribution, err
}
// If this is an incoming HTLC, then this means that they were the
// sender of the HTLC (relative to us). So we'll re-generate the sender
// HTLC script. Otherwise, if this was an outgoing HTLC that we sent,
// then from the PoV of the remote commitment state, they're the
// receiver of this HTLC.
htlcLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) fn.Option[input.AuxTapLeaf] {
return fn.MapOption(func(val uint16) input.AuxTapLeaf {
idx := input.HtlcIndex(val)
if htlc.Incoming.Val {
leaves := l.IncomingHtlcLeaves[idx]
return leaves.AuxTapLeaf
}
return l.OutgoingHtlcLeaves[idx].AuxTapLeaf
})(htlc.HtlcIndex.ValOpt())
},
)(auxLeaves)
scriptInfo, err := genHtlcScript(
chanState.ChanType, htlc.Incoming.Val, lntypes.Remote,
htlc.RefundTimeout.Val, htlc.RHash.Val, keyRing,
fn.FlattenOption(htlcLeaf),
)
if err != nil {
return emptyRetribution, err
}
signDesc := input.SignDescriptor{
KeyDesc: chanState.LocalChanCfg.
RevocationBasePoint,
DoubleTweak: commitmentSecret,
WitnessScript: scriptInfo.WitnessScriptToSign(),
Output: &wire.TxOut{
PkScript: scriptInfo.PkScript(),
Value: int64(htlc.Amt.Val.Int()),
},
HashType: sweepSigHash(chanState.ChanType),
}
// For taproot HTLC outputs, we need to set the sign method to key
// spend, and also set the tap tweak root needed to derive the proper
// private key.
if scriptTree, ok := scriptInfo.(input.TapscriptDescriptor); ok {
signDesc.SignMethod = input.TaprootKeySpendSignMethod
signDesc.TapTweak = scriptTree.TapTweak()
}
// The second level script we sign will always be the success path.
secondLevelWitnessScript, err := secondLevelScript.WitnessScriptForPath(
input.ScriptPathSuccess,
)
if err != nil {
return emptyRetribution, err
}
// If this is a taproot output, we'll also need to obtain the second
// level tap tweak as well.
var secondLevelTapTweak [32]byte
if scriptTree, ok := secondLevelScript.(input.TapscriptDescriptor); ok {
copy(secondLevelTapTweak[:], scriptTree.TapTweak())
}
return HtlcRetribution{
SignDesc: signDesc,
OutPoint: wire.OutPoint{
Hash: commitHash,
Index: uint32(htlc.OutputIndex.Val),
},
SecondLevelWitnessScript: secondLevelWitnessScript,
IsIncoming: htlc.Incoming.Val,
SecondLevelTapTweak: secondLevelTapTweak,
}, nil
}
// createBreachRetribution creates a partially initialized BreachRetribution
// using a RevocationLog. Returns the constructed retribution, our amount,
// their amount, and a possible non-nil error. If the spendTx parameter is
// non-nil, then it will be used to glean the breach transaction's to-local and
// to-remote output amounts. Otherwise, the RevocationLog will be checked to
// see if these fields are present there. If they are not, then
// ErrRevLogDataMissing is returned.
func createBreachRetribution(revokedLog *channeldb.RevocationLog,
spendTx *wire.MsgTx, chanState *channeldb.OpenChannel,
keyRing *CommitmentKeyRing, commitmentSecret *btcec.PrivateKey,
leaseExpiry uint32,
auxLeaves fn.Option[CommitAuxLeaves]) (*BreachRetribution, int64, int64,
error) {
commitHash := revokedLog.CommitTxHash
// Create the htlc retributions.
htlcRetributions := make([]HtlcRetribution, len(revokedLog.HTLCEntries))
for i, htlc := range revokedLog.HTLCEntries {
hr, err := createHtlcRetribution(
chanState, keyRing, commitHash.Val,
commitmentSecret, leaseExpiry, htlc, auxLeaves,
)
if err != nil {
return nil, 0, 0, err
}
htlcRetributions[i] = hr
}
var ourAmt, theirAmt int64
// Construct our outpoint.
ourOutpoint := wire.OutPoint{
Hash: commitHash.Val,
}
if revokedLog.OurOutputIndex.Val != channeldb.OutputIndexEmpty {
ourOutpoint.Index = uint32(revokedLog.OurOutputIndex.Val)
// If the spend transaction is provided, then we use it to get
// the value of our output.
if spendTx != nil {
// Sanity check that OurOutputIndex is within range.
if int(ourOutpoint.Index) >= len(spendTx.TxOut) {
return nil, 0, 0, fmt.Errorf("%w: ours=%v, "+
"len(TxOut)=%v",
ErrOutputIndexOutOfRange,
ourOutpoint.Index, len(spendTx.TxOut),
)
}
// Read the amounts from the breach transaction.
//
// NOTE: ourAmt here includes commit fee and anchor
// amount (if enabled).
ourAmt = spendTx.TxOut[ourOutpoint.Index].Value
} else {
// Otherwise, we check to see if the revocation log
// contains our output amount. Due to a previous
// migration, this field may be empty in which case an
// error will be returned.
b, err := revokedLog.OurBalance.ValOpt().UnwrapOrErr(
ErrRevLogDataMissing,
)
if err != nil {
return nil, 0, 0, err
}
ourAmt = int64(b.Int().ToSatoshis())
}
}
// Construct their outpoint.
theirOutpoint := wire.OutPoint{
Hash: commitHash.Val,
}
if revokedLog.TheirOutputIndex.Val != channeldb.OutputIndexEmpty {
theirOutpoint.Index = uint32(revokedLog.TheirOutputIndex.Val)
// If the spend transaction is provided, then we use it to get
// the value of the remote party's output.
if spendTx != nil {
// Sanity check that TheirOutputIndex is within range.
if int(revokedLog.TheirOutputIndex.Val) >=
len(spendTx.TxOut) {
return nil, 0, 0, fmt.Errorf("%w: theirs=%v, "+
"len(TxOut)=%v",
ErrOutputIndexOutOfRange,
revokedLog.TheirOutputIndex,
len(spendTx.TxOut),
)
}
// Read the amounts from the breach transaction.
theirAmt = spendTx.TxOut[theirOutpoint.Index].Value
} else {
// Otherwise, we check to see if the revocation log
// contains the remote party's output amount. Due to a
// previous migration, this field may be empty in which
// case an error will be returned.
b, err := revokedLog.TheirBalance.ValOpt().UnwrapOrErr(
ErrRevLogDataMissing,
)
if err != nil {
return nil, 0, 0, err
}
theirAmt = int64(b.Int().ToSatoshis())
}
}
return &BreachRetribution{
BreachTxHash: commitHash.Val,
ChainHash: chanState.ChainHash,
LocalOutpoint: ourOutpoint,
RemoteOutpoint: theirOutpoint,
HtlcRetributions: htlcRetributions,
KeyRing: keyRing,
}, ourAmt, theirAmt, nil
}
// createBreachRetributionLegacy creates a partially initialized
// BreachRetribution using a ChannelCommitment. Returns the constructed
// retribution, our amount, their amount, and a possible non-nil error.
func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment,
chanState *channeldb.OpenChannel, keyRing *CommitmentKeyRing,
commitmentSecret *btcec.PrivateKey,
ourScript, theirScript input.ScriptDescriptor,
leaseExpiry uint32) (*BreachRetribution, int64, int64, error) {
commitHash := revokedLog.CommitTx.TxHash()
ourOutpoint := wire.OutPoint{
Hash: commitHash,
}
theirOutpoint := wire.OutPoint{
Hash: commitHash,
}
// In order to fully populate the breach retribution struct, we'll need
// to find the exact index of the commitment outputs.
for i, txOut := range revokedLog.CommitTx.TxOut {
switch {
case bytes.Equal(txOut.PkScript, ourScript.PkScript()):
ourOutpoint.Index = uint32(i)
case bytes.Equal(txOut.PkScript, theirScript.PkScript()):
theirOutpoint.Index = uint32(i)
}
}
// With the commitment outputs located, we'll now generate all the
// retribution structs for each of the HTLC transactions active on the
// remote commitment transaction.
htlcRetributions := make([]HtlcRetribution, len(revokedLog.Htlcs))
for i, htlc := range revokedLog.Htlcs {
// If the HTLC is dust, then we'll skip it as it doesn't have
// an output on the commitment transaction.
if HtlcIsDust(
chanState.ChanType, htlc.Incoming, lntypes.Remote,
chainfee.SatPerKWeight(revokedLog.FeePerKw),
htlc.Amt.ToSatoshis(),
chanState.RemoteChanCfg.DustLimit,
) {
continue
}
entry, err := channeldb.NewHTLCEntryFromHTLC(htlc)
if err != nil {
return nil, 0, 0, err
}
hr, err := createHtlcRetribution(
chanState, keyRing, commitHash,
commitmentSecret, leaseExpiry, entry,
fn.None[CommitAuxLeaves](),
)
if err != nil {
return nil, 0, 0, err
}
htlcRetributions[i] = hr
}
// Compute the balances in satoshis.
ourAmt := int64(revokedLog.LocalBalance.ToSatoshis())
theirAmt := int64(revokedLog.RemoteBalance.ToSatoshis())
return &BreachRetribution{
BreachTxHash: commitHash,
ChainHash: chanState.ChainHash,
LocalOutpoint: ourOutpoint,
RemoteOutpoint: theirOutpoint,
HtlcRetributions: htlcRetributions,
KeyRing: keyRing,
}, ourAmt, theirAmt, nil
}
// HtlcIsDust determines if an HTLC output is dust or not depending on two
// bits: if the HTLC is incoming and if the HTLC will be placed on our
// commitment transaction, or theirs. These two pieces of information are
// required as we currently use second-level HTLC transactions as off-chain
// covenants. Depending on the two bits, we'll either be using a timeout or
// success transaction which have different weights.
func HtlcIsDust(chanType channeldb.ChannelType,
incoming bool, whoseCommit lntypes.ChannelParty,
feePerKw chainfee.SatPerKWeight, htlcAmt, dustLimit btcutil.Amount,
) bool {
// First we'll determine the fee required for this HTLC based on if this is
// an incoming HTLC or not, and also on whose commitment transaction it
// will be placed on.
var htlcFee btcutil.Amount
switch {
// If this is an incoming HTLC on our commitment transaction, then the
// second-level transaction will be a success transaction.
case incoming && whoseCommit.IsLocal():
htlcFee = HtlcSuccessFee(chanType, feePerKw)
// If this is an incoming HTLC on their commitment transaction, then
// we'll be using a second-level timeout transaction as they've added
// this HTLC.
case incoming && whoseCommit.IsRemote():
htlcFee = HtlcTimeoutFee(chanType, feePerKw)
// If this is an outgoing HTLC on our commitment transaction, then
// we'll be using a timeout transaction as we're the sender of the
// HTLC.
case !incoming && whoseCommit.IsLocal():
htlcFee = HtlcTimeoutFee(chanType, feePerKw)
// If this is an outgoing HTLC on their commitment transaction, then
// we'll be using an HTLC success transaction as they're the receiver
// of this HTLC.
case !incoming && whoseCommit.IsRemote():
htlcFee = HtlcSuccessFee(chanType, feePerKw)
}
return (htlcAmt - htlcFee) < dustLimit
}
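// exampleDustCheck is an illustrative sketch and is not part of the original
// file. It assumes the BOLT-3 second-level weights used by legacy
// (non-anchor) channels, e.g. a timeout transaction weight of 663 WU: at
// 2,500 sat/kw the second-level fee is 663 * 2500 / 1000 = 1657 sats, so an
// outgoing 2,000-sat HTLC on our commitment retains 2000 - 1657 = 343 sats.
// Against a 354-sat dust limit the check below therefore reports true (dust).
func exampleDustCheck() bool {
return HtlcIsDust(
channeldb.SingleFunderBit, false, lntypes.Local,
chainfee.SatPerKWeight(2500), btcutil.Amount(2000),
btcutil.Amount(354),
)
}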
// HtlcView represents the "active" HTLCs at a particular point within the
// history of the HTLC update log.
type HtlcView struct {
// NextHeight is the height of the commitment transaction that will be
// created using this view.
NextHeight uint64
// Updates is a Dual of the Local and Remote HTLCs.
Updates lntypes.Dual[[]*paymentDescriptor]
// FeePerKw is the fee rate in sat/kw of the commitment transaction.
FeePerKw chainfee.SatPerKWeight
}
// AuxOurUpdates returns the outgoing HTLCs as a read-only copy of
// AuxHtlcDescriptors.
func (v *HtlcView) AuxOurUpdates() []AuxHtlcDescriptor {
return fn.Map(v.Updates.Local, newAuxHtlcDescriptor)
}
// AuxTheirUpdates returns the incoming HTLCs as a read-only copy of
// AuxHtlcDescriptors.
func (v *HtlcView) AuxTheirUpdates() []AuxHtlcDescriptor {
return fn.Map(v.Updates.Remote, newAuxHtlcDescriptor)
}
// FetchLatestAuxHTLCView returns the latest HTLC view of the lightning channel
// as a safe copy that can be used outside the wallet code under concurrent
// access.
func (lc *LightningChannel) FetchLatestAuxHTLCView() AuxHtlcView {
// This read lock is important, because we access both the local and
// remote log indexes as well as the underlying payment descriptors of
// the HTLCs when creating the view.
lc.RLock()
defer lc.RUnlock()
return newAuxHtlcView(lc.fetchHTLCView(
lc.updateLogs.Remote.logIndex, lc.updateLogs.Local.logIndex,
))
}
// fetchHTLCView returns all the candidate HTLC updates which should be
// considered for inclusion within a commitment based on the passed HTLC log
// indexes.
func (lc *LightningChannel) fetchHTLCView(theirLogIndex,
ourLogIndex uint64) *HtlcView {
var ourHTLCs []*paymentDescriptor
for e := lc.updateLogs.Local.Front(); e != nil; e = e.Next() {
htlc := e.Value
// This HTLC is active from this point-of-view iff the log
// index of the state update is below the specified index in
// our update log.
if htlc.LogIndex < ourLogIndex {
ourHTLCs = append(ourHTLCs, htlc)
}
}
var theirHTLCs []*paymentDescriptor
for e := lc.updateLogs.Remote.Front(); e != nil; e = e.Next() {
htlc := e.Value
// If this is an incoming HTLC, then it is only active from
// this point-of-view if the index of the HTLC addition in
// their log is below the specified view index.
if htlc.LogIndex < theirLogIndex {
theirHTLCs = append(theirHTLCs, htlc)
}
}
return &HtlcView{
Updates: lntypes.Dual[[]*paymentDescriptor]{
Local: ourHTLCs,
Remote: theirHTLCs,
},
}
}
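// As a small illustration: if ourLogIndex is 4, only local updates
// with LogIndex 0 through 3 are considered active in the returned
// view; any local update appended at index 4 or later is excluded.
// The same cutoff applies to the remote log via theirLogIndex.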
// fetchCommitmentView returns a populated commitment which expresses the state
// of the channel from the point of view of a local or remote chain, evaluating
// the HTLC log up to the passed indexes. This function is used to construct
// both local and remote commitment transactions in order to sign or verify new
// commitment updates. A fully populated commitment is returned which reflects
// the proper balances for both sides at this point in the commitment chain.
func (lc *LightningChannel) fetchCommitmentView(
whoseCommitChain lntypes.ChannelParty,
ourLogIndex, ourHtlcIndex, theirLogIndex, theirHtlcIndex uint64,
keyRing *CommitmentKeyRing) (*commitment, error) {
commitChain := lc.commitChains.Local
dustLimit := lc.channelState.LocalChanCfg.DustLimit
if whoseCommitChain.IsRemote() {
commitChain = lc.commitChains.Remote
dustLimit = lc.channelState.RemoteChanCfg.DustLimit
}
nextHeight := commitChain.tip().height + 1
// Run through all the HTLCs that will be covered by this transaction
// in order to update their commitment addition height, and to adjust
// the balances on the commitment transaction accordingly. Note that
// these balances will be *before* taking a commitment fee from the
// initiator.
htlcView := lc.fetchHTLCView(theirLogIndex, ourLogIndex)
ourBalance, theirBalance, _, filteredHTLCView, err := lc.computeView(
htlcView, whoseCommitChain, true,
fn.None[chainfee.SatPerKWeight](),
)
if err != nil {
return nil, err
}
feePerKw := filteredHTLCView.FeePerKw
htlcView.NextHeight = nextHeight
filteredHTLCView.NextHeight = nextHeight
// Actually generate unsigned commitment transaction for this view.
commitTx, err := lc.commitBuilder.createUnsignedCommitmentTx(
ourBalance, theirBalance, whoseCommitChain, feePerKw,
nextHeight, htlcView, filteredHTLCView, keyRing,
commitChain.tip(),
)
if err != nil {
return nil, err
}
// We'll assert that there hasn't been a mistake during fee calculation
// leading to a fee too low.
var totalOut btcutil.Amount
for _, txOut := range commitTx.txn.TxOut {
totalOut += btcutil.Amount(txOut.Value)
}
fee := lc.channelState.Capacity - totalOut
var witnessWeight int64
if lc.channelState.ChanType.IsTaproot() {
witnessWeight = input.TaprootKeyPathWitnessSize
} else {
witnessWeight = input.WitnessCommitmentTxWeight
}
// Since the transaction is not signed yet, we manually add the
// expected witness weight to the weight calculation.
uTx := btcutil.NewTx(commitTx.txn)
weight := blockchain.GetTransactionWeight(uTx) + witnessWeight
effFeeRate := chainfee.SatPerKWeight(fee) * 1000 /
chainfee.SatPerKWeight(weight)
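// To illustrate with made-up numbers: a fee of 1_810 sats on a
// commitment weighing 1_000 WU (witness weight included) yields an
// effective rate of 1_810 * 1_000 / 1_000 = 1_810 sat/kw, which would
// comfortably clear the absolute fee rate floor checked below.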
if effFeeRate < chainfee.AbsoluteFeePerKwFloor {
return nil, fmt.Errorf("height=%v, for ChannelPoint(%v) "+
"attempts to create commitment with feerate %v: %v",
nextHeight, lc.channelState.FundingOutpoint,
effFeeRate, spew.Sdump(commitTx))
}
// Given the custom blob of the past state, and this new HTLC view,
// we'll generate a new blob for the latest commitment.
newCommitBlob, err := fn.MapOptionZ(
lc.leafStore,
func(s AuxLeafStore) fn.Result[fn.Option[tlv.Blob]] {
return updateAuxBlob(
s, lc.channelState,
commitChain.tip().customBlob, htlcView,
whoseCommitChain, ourBalance, theirBalance,
*keyRing,
)
},
).Unpack()
if err != nil {
return nil, fmt.Errorf("unable to fetch aux leaves: %w", err)
}
messageIndices := lntypes.Dual[uint64]{
Local: ourLogIndex,
Remote: theirLogIndex,
}
// With the commitment view created, store the resulting balances and
// transaction with the other parameters for this height.
c := &commitment{
ourBalance: commitTx.ourBalance,
theirBalance: commitTx.theirBalance,
txn: commitTx.txn,
fee: commitTx.fee,
messageIndices: messageIndices,
ourHtlcIndex: ourHtlcIndex,
theirHtlcIndex: theirHtlcIndex,
height: nextHeight,
feePerKw: feePerKw,
dustLimit: dustLimit,
whoseCommit: whoseCommitChain,
customBlob: newCommitBlob,
}
// In order to ensure _none_ of the HTLCs associated with this new
// commitment are mutated, we'll manually copy over each HTLC to its
// respective slice.
c.outgoingHTLCs = make(
[]paymentDescriptor, len(filteredHTLCView.Updates.Local),
)
for i, htlc := range filteredHTLCView.Updates.Local {
c.outgoingHTLCs[i] = *htlc
}
c.incomingHTLCs = make(
[]paymentDescriptor, len(filteredHTLCView.Updates.Remote),
)
for i, htlc := range filteredHTLCView.Updates.Remote {
c.incomingHTLCs[i] = *htlc
}
// Finally, we'll populate all the HTLC indexes so we can track the
// locations of each HTLC in the commitment state. We pass in the sorted
// slice of CLTV deltas in order to properly locate HTLCs that otherwise
// have the same payment hash and amount.
err = c.populateHtlcIndexes(lc.channelState.ChanType, commitTx.cltvs)
if err != nil {
return nil, err
}
return c, nil
}
// fundingTxIn returns the funding output as a transaction input. The input
// returned by this function uses a max sequence number, so it doesn't signal
// RBF (BIP 125) replaceability by default.
func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn {
return *wire.NewTxIn(&chanState.FundingOutpoint, nil, nil)
}
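// Note that wire.NewTxIn leaves the sequence at its default of
// wire.MaxTxInSequenceNum (0xffffffff), which is what opts the input
// out of BIP 125 replaceability as described above.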
// evaluateHTLCView processes all update entries in both HTLC update logs,
// producing a final view which is the result of properly applying all adds,
// settles, timeouts and fee updates found in both logs. The resulting view
// returned reflects the current state of HTLCs within the remote or local
// commitment chain, and the current commitment fee rate.
//
// The return values of this function are as follows:
// 1. The new htlcView reflecting the current channel state.
// 2. A Dual of the updates which have not yet been committed in
// 'whoseCommitChain's commitment chain.
// 3. A Dual of the per-party balance deltas that result from applying
// the updates in the view.
func (lc *LightningChannel) evaluateHTLCView(view *HtlcView,
whoseCommitChain lntypes.ChannelParty, nextHeight uint64) (*HtlcView,
lntypes.Dual[[]*paymentDescriptor], lntypes.Dual[int64], error) {
// We initialize the view's fee rate to the fee rate of the unfiltered
// view. If any fee updates are found when evaluating the view, it will
// be updated.
newView := &HtlcView{
FeePerKw: view.FeePerKw,
NextHeight: nextHeight,
}
noUncommitted := lntypes.Dual[[]*paymentDescriptor]{}
// The fee rate of our view is always the last UpdateFee message from
// the channel's OpeningParty.
openerUpdates := view.Updates.GetForParty(lc.channelState.Initiator())
feeUpdates := fn.Filter(openerUpdates, func(u *paymentDescriptor) bool {
return u.EntryType == FeeUpdate
})
lastFeeUpdate := fn.Last(feeUpdates)
lastFeeUpdate.WhenSome(func(pd *paymentDescriptor) {
newView.FeePerKw = chainfee.SatPerKWeight(
pd.Amount.ToSatoshis(),
)
})
// We use two maps, one for the local log and one for the remote log to
// keep track of which entries we need to skip when creating the final
// htlc view. We skip an entry whenever we find a settle or a timeout
// modifying an entry.
skip := lntypes.Dual[fn.Set[uint64]]{
Local: fn.NewSet[uint64](),
Remote: fn.NewSet[uint64](),
}
balanceDeltas := lntypes.Dual[int64]{}
parties := [2]lntypes.ChannelParty{lntypes.Local, lntypes.Remote}
for _, party := range parties {
// First we run through non-add entries in both logs,
// populating the skip sets.
resolutions := fn.Filter(
view.Updates.GetForParty(party),
func(pd *paymentDescriptor) bool {
switch pd.EntryType {
case Settle, Fail, MalformedFail:
return true
default:
return false
}
},
)
for _, entry := range resolutions {
addEntry, err := lc.fetchParent(
entry, whoseCommitChain, party.CounterParty(),
)
if err != nil {
noDeltas := lntypes.Dual[int64]{}
return nil, noUncommitted, noDeltas, err
}
skipSet := skip.GetForParty(party.CounterParty())
skipSet.Add(addEntry.HtlcIndex)
rmvHeight := entry.removeCommitHeights.GetForParty(
whoseCommitChain,
)
if rmvHeight == 0 {
switch {
// If an incoming HTLC is being settled, then
// this means that the preimage has been
// received by the settling party. Therefore, we
// increase the settling party's balance by the
// HTLC amount.
case entry.EntryType == Settle:
delta := int64(entry.Amount)
balanceDeltas.ModifyForParty(
party,
func(acc int64) int64 {
return acc + delta
},
)
// Otherwise, this HTLC is being failed out,
// therefore the value of the HTLC should
// return to the failing party's counterparty.
case entry.EntryType != Settle:
delta := int64(entry.Amount)
balanceDeltas.ModifyForParty(
party.CounterParty(),
func(acc int64) int64 {
return acc + delta
},
)
}
}
}
}
// Next we take a second pass through all the log entries, skipping any
// settled HTLCs, and debiting the chain state balance due to any newly
// added HTLCs.
for _, party := range parties {
liveAdds := fn.Filter(
view.Updates.GetForParty(party),
func(pd *paymentDescriptor) bool {
isAdd := pd.EntryType == Add
shouldSkip := skip.GetForParty(party).
Contains(pd.HtlcIndex)
return isAdd && !shouldSkip
},
)
for _, entry := range liveAdds {
// Skip the entries that have already had their add
// commit height set for this commit chain.
addHeight := entry.addCommitHeights.GetForParty(
whoseCommitChain,
)
if addHeight == 0 {
// If this is a new incoming (un-committed)
// HTLC, then we need to update their balance
// accordingly by subtracting the HTLC amount,
// as those funds are now pending. Similarly,
// we need to debit our balance if this is an
// outgoing HTLC to reflect the pending
// balance.
balanceDeltas.ModifyForParty(
party,
func(acc int64) int64 {
return acc - int64(entry.Amount)
},
)
}
}
newView.Updates.SetForParty(party, liveAdds)
}
// Create a function that is capable of identifying whether or not the
// paymentDescriptor has been committed in the commitment chain
// corresponding to whoseCommitChain.
isUncommitted := func(update *paymentDescriptor) bool {
switch update.EntryType {
case Add, FeeUpdate:
return update.addCommitHeights.GetForParty(
whoseCommitChain,
) == 0
case Settle, Fail, MalformedFail:
return update.removeCommitHeights.GetForParty(
whoseCommitChain,
) == 0
default:
panic("invalid paymentDescriptor EntryType")
}
}
// Collect all of the updates that haven't had their commit heights set
// for the commitment chain corresponding to whoseCommitChain.
uncommittedUpdates := lntypes.MapDual(
view.Updates,
func(us []*paymentDescriptor) []*paymentDescriptor {
return fn.Filter(us, isUncommitted)
},
)
return newView, uncommittedUpdates, balanceDeltas, nil
}
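// To illustrate with a hypothetical trace: if the remote party settles
// one of our outgoing 10_000 msat HTLCs, the Settle entry places the
// parent Add's HtlcIndex into the local skip set and credits the
// remote party's balance delta with +10_000 msat, so neither the Add
// nor the Settle appears as a live update in the returned view.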
// fetchParent is a helper that looks up update log parent entries in the
// appropriate log.
func (lc *LightningChannel) fetchParent(entry *paymentDescriptor,
whoseCommitChain, whoseUpdateLog lntypes.ChannelParty,
) (*paymentDescriptor, error) {
var (
updateLog *updateLog
logName string
)
if whoseUpdateLog.IsRemote() {
updateLog = lc.updateLogs.Remote
logName = "remote"
} else {
updateLog = lc.updateLogs.Local
logName = "local"
}
addEntry := updateLog.lookupHtlc(entry.ParentIndex)
switch {
// We check if the parent entry is not found at this point.
// This could happen for old versions of lnd, and we return an
// error to gracefully shut down the state machine if such an
// entry is still in the logs.
case addEntry == nil:
return nil, fmt.Errorf("unable to find parent entry "+
"%d in %v update log: %v\nUpdatelog: %v",
entry.ParentIndex, logName,
lnutils.SpewLogClosure(entry),
lnutils.SpewLogClosure(updateLog))
// The parent add height should never be zero at this point. If
// that's the case we probably forgot to send a new commitment.
case addEntry.addCommitHeights.GetForParty(whoseCommitChain) == 0:
return nil, fmt.Errorf("parent entry %d for update %d "+
"had zero %v add height", entry.ParentIndex,
entry.LogIndex, whoseCommitChain)
}
return addEntry, nil
}
// genRemoteHtlcSigJobs generates a series of HTLC signature jobs for the
// sig pool, along with a channel that, if closed, will cancel any jobs after
// they have been submitted to the sigPool. This method is to be used when
// generating a new commitment for the remote party. The jobs generated by
// this method can be submitted to the sigPool to generate all the signatures
// asynchronously and in parallel.
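// A typical caller (see SignNextCommitment below) sorts the returned jobs by
// OutputIndex to match the BIP 69 ordering on the commitment transaction,
// submits them to the sig pool, and then collects each job's result from its
// Resp channel, closing cancelChan if any job fails.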
func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing,
chanState *channeldb.OpenChannel, leaseExpiry uint32,
remoteCommitView *commitment,
leafStore fn.Option[AuxLeafStore]) ([]SignJob, []AuxSigJob,
chan struct{}, error) {
var (
isRemoteInitiator = !chanState.IsInitiator
localChanCfg = chanState.LocalChanCfg
remoteChanCfg = chanState.RemoteChanCfg
chanType = chanState.ChanType
)
txHash := remoteCommitView.txn.TxHash()
dustLimit := remoteChanCfg.DustLimit
feePerKw := remoteCommitView.feePerKw
sigHashType := HtlcSigHashType(chanType)
// With the keys generated, we'll make a slice with enough capacity to
// hold potentially all the HTLCs. The actual slice may be a bit
// smaller (than its total capacity) as some HTLCs may be dust.
numSigs := len(remoteCommitView.incomingHTLCs) +
len(remoteCommitView.outgoingHTLCs)
sigBatch := make([]SignJob, 0, numSigs)
auxSigBatch := make([]AuxSigJob, 0, numSigs)
var err error
cancelChan := make(chan struct{})
diskCommit := remoteCommitView.toDiskCommit(lntypes.Remote)
auxResult, err := fn.MapOptionZ(
leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] {
return s.FetchLeavesFromCommit(
NewAuxChanState(chanState), *diskCommit,
*keyRing, lntypes.Remote,
)
},
).Unpack()
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to fetch aux leaves: "+
"%w", err)
}
// For each outgoing and incoming HTLC, if the HTLC isn't considered a
// dust output after taking into account second-level HTLC fees, then a
// sigJob will be generated and appended to the current batch.
for _, htlc := range remoteCommitView.incomingHTLCs {
if HtlcIsDust(
chanType, true, lntypes.Remote, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,
) {
continue
}
// If the HTLC isn't dust, then we'll create an empty sign job
// to add to the batch momentarily.
var sigJob SignJob
sigJob.Cancel = cancelChan
sigJob.Resp = make(chan SignJobResp, 1)
// As this is an incoming HTLC and we're signing the commitment
// transaction of the remote node, we'll need to generate an
// HTLC timeout transaction for them. The output of the timeout
// transaction needs to account for fees, so we'll compute the
// required fee and output now.
htlcFee := HtlcTimeoutFee(chanType, feePerKw)
outputAmt := htlc.Amount.ToSatoshis() - htlcFee
auxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.IncomingHtlcLeaves
return leaves[htlc.HtlcIndex].SecondLevelLeaf
},
)(auxResult.AuxLeaves)
// With the fee calculated, we can properly create the HTLC
// timeout transaction using the HTLC amount minus the fee.
op := wire.OutPoint{
Hash: txHash,
Index: uint32(htlc.remoteOutputIndex),
}
sigJob.Tx, err = CreateHtlcTimeoutTx(
chanType, isRemoteInitiator, op, outputAmt,
htlc.Timeout, uint32(remoteChanCfg.CsvDelay),
leaseExpiry, keyRing.RevocationKey, keyRing.ToLocalKey,
auxLeaf,
)
if err != nil {
return nil, nil, nil, err
}
// Construct a full hash cache as we may be signing a segwit v1
// sighash.
txOut := remoteCommitView.txn.TxOut[htlc.remoteOutputIndex]
prevFetcher := txscript.NewCannedPrevOutputFetcher(
txOut.PkScript, int64(htlc.Amount.ToSatoshis()),
)
hashCache := txscript.NewTxSigHashes(sigJob.Tx, prevFetcher)
// Finally, we'll generate a sign descriptor to generate a
// signature to give to the remote party for this commitment
// transaction. Note we use the raw HTLC amount.
sigJob.SignDesc = input.SignDescriptor{
KeyDesc: localChanCfg.HtlcBasePoint,
SingleTweak: keyRing.LocalHtlcKeyTweak,
WitnessScript: htlc.theirWitnessScript,
Output: txOut,
PrevOutputFetcher: prevFetcher,
HashType: sigHashType,
SigHashes: hashCache,
InputIndex: 0,
}
sigJob.OutputIndex = htlc.remoteOutputIndex
// If this is a taproot channel, then we'll need to set the
// method type to ensure we generate a valid signature.
if chanType.IsTaproot() {
//nolint:ll
sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod
}
sigBatch = append(sigBatch, sigJob)
auxSigBatch = append(auxSigBatch, NewAuxSigJob(
sigJob, *keyRing, true, newAuxHtlcDescriptor(&htlc),
remoteCommitView.customBlob, auxLeaf, cancelChan,
))
}
for _, htlc := range remoteCommitView.outgoingHTLCs {
if HtlcIsDust(
chanType, false, lntypes.Remote, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,
) {
continue
}
sigJob := SignJob{}
sigJob.Cancel = cancelChan
sigJob.Resp = make(chan SignJobResp, 1)
// As this is an outgoing HTLC and we're signing the commitment
// transaction of the remote node, we'll need to generate an
// HTLC success transaction for them. The output of the success
// transaction needs to account for fees, so we'll compute the
// required fee and output now.
htlcFee := HtlcSuccessFee(chanType, feePerKw)
outputAmt := htlc.Amount.ToSatoshis() - htlcFee
auxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
return leaves[htlc.HtlcIndex].SecondLevelLeaf
},
)(auxResult.AuxLeaves)
// With the proper output amount calculated, we can now
// generate the success transaction using the remote party's
// CSV delay.
op := wire.OutPoint{
Hash: txHash,
Index: uint32(htlc.remoteOutputIndex),
}
sigJob.Tx, err = CreateHtlcSuccessTx(
chanType, isRemoteInitiator, op, outputAmt,
uint32(remoteChanCfg.CsvDelay), leaseExpiry,
keyRing.RevocationKey, keyRing.ToLocalKey,
auxLeaf,
)
if err != nil {
return nil, nil, nil, err
}
// Construct a full hash cache as we may be signing a segwit v1
// sighash.
txOut := remoteCommitView.txn.TxOut[htlc.remoteOutputIndex]
prevFetcher := txscript.NewCannedPrevOutputFetcher(
txOut.PkScript, int64(htlc.Amount.ToSatoshis()),
)
hashCache := txscript.NewTxSigHashes(sigJob.Tx, prevFetcher)
// Finally, we'll generate a sign descriptor to generate a
// signature to give to the remote party for this commitment
// transaction. Note we use the raw HTLC amount.
sigJob.SignDesc = input.SignDescriptor{
KeyDesc: localChanCfg.HtlcBasePoint,
SingleTweak: keyRing.LocalHtlcKeyTweak,
WitnessScript: htlc.theirWitnessScript,
Output: txOut,
PrevOutputFetcher: prevFetcher,
HashType: sigHashType,
SigHashes: hashCache,
InputIndex: 0,
}
sigJob.OutputIndex = htlc.remoteOutputIndex
// If this is a taproot channel, then we'll need to set the
// method type to ensure we generate a valid signature.
if chanType.IsTaproot() {
//nolint:ll
sigJob.SignDesc.SignMethod = input.TaprootScriptSpendSignMethod
}
sigBatch = append(sigBatch, sigJob)
auxSigBatch = append(auxSigBatch, NewAuxSigJob(
sigJob, *keyRing, false, newAuxHtlcDescriptor(&htlc),
remoteCommitView.customBlob, auxLeaf, cancelChan,
))
}
return sigBatch, auxSigBatch, cancelChan, nil
}
// createCommitDiff will create a commit diff given a new pending commitment
// for the remote party and the necessary signatures for the remote party to
// validate this new state. This function is called right before sending the
// new commitment to the remote party. The commit diff returned contains all
// information necessary for retransmission.
func (lc *LightningChannel) createCommitDiff(newCommit *commitment,
commitSig lnwire.Sig, htlcSigs []lnwire.Sig,
auxSigs []fn.Option[tlv.Blob]) (*channeldb.CommitDiff, error) {
var (
logUpdates []channeldb.LogUpdate
ackAddRefs []channeldb.AddRef
settleFailRefs []channeldb.SettleFailRef
openCircuitKeys []models.CircuitKey
closedCircuitKeys []models.CircuitKey
)
// We'll now run through our local update log to locate the items which
// were only just committed within this pending state. This will be the
// set of items we need to retransmit if we reconnect and find that
// they didn't process this new state fully.
for e := lc.updateLogs.Local.Front(); e != nil; e = e.Next() {
pd := e.Value
// If this entry wasn't committed at the exact height of this
// remote commitment, then we'll skip it as it was already
// lingering in the log.
if pd.addCommitHeights.Remote != newCommit.height &&
pd.removeCommitHeights.Remote != newCommit.height {
continue
}
// We'll map the type of the paymentDescriptor to one of the
// four messages that it corresponds to. With this set of
// messages obtained, we can simply read from disk and re-send
// them in the case of a needed channel sync.
switch pd.EntryType {
case Add:
// Gather any references for circuits opened by this Add
// HTLC.
if pd.OpenCircuitKey != nil {
openCircuitKeys = append(
openCircuitKeys, *pd.OpenCircuitKey,
)
}
case Settle, Fail, MalformedFail:
// Gather the fwd pkg references from any settle or fail
// packets, if they exist.
if pd.SourceRef != nil {
ackAddRefs = append(ackAddRefs, *pd.SourceRef)
}
if pd.DestRef != nil {
settleFailRefs = append(
settleFailRefs, *pd.DestRef,
)
}
if pd.ClosedCircuitKey != nil {
closedCircuitKeys = append(
closedCircuitKeys, *pd.ClosedCircuitKey,
)
}
case FeeUpdate:
// Nothing special to do.
}
logUpdates = append(logUpdates, pd.toLogUpdate())
}
// With the set of log updates mapped into wire messages, we'll now
// convert the in-memory commit into a format suitable for writing to
// disk.
diskCommit := newCommit.toDiskCommit(lntypes.Remote)
// We prepare the commit sig message to be sent to the remote party.
commitSigMsg := &lnwire.CommitSig{
ChanID: lnwire.NewChanIDFromOutPoint(
lc.channelState.FundingOutpoint,
),
CommitSig: commitSig,
HtlcSigs: htlcSigs,
}
// Encode and check the size of the custom records now.
auxCustomRecords, err := fn.MapOptionZ(
lc.auxSigner,
func(s AuxSigner) fn.Result[lnwire.CustomRecords] {
blobOption, err := s.PackSigs(auxSigs).Unpack()
if err != nil {
return fn.Err[lnwire.CustomRecords](err)
}
// We now serialize the commit sig message without the
// custom records to make sure we have space for them.
var buf bytes.Buffer
err = commitSigMsg.Encode(&buf, 0)
if err != nil {
return fn.Err[lnwire.CustomRecords](err)
}
// The number of available bytes is the max message size
// minus the size of the message without the custom
// records. We also subtract 8 bytes for encoding
// overhead of the custom records (just some safety
// padding).
available := lnwire.MaxMsgBody - buf.Len() - 8
blob := blobOption.UnwrapOr(nil)
if len(blob) > available {
err = fmt.Errorf("aux sigs size %d exceeds "+
"max allowed size of %d", len(blob),
available)
return fn.Err[lnwire.CustomRecords](err)
}
records, err := lnwire.ParseCustomRecords(blob)
if err != nil {
return fn.Err[lnwire.CustomRecords](err)
}
return fn.Ok(records)
},
).Unpack()
if err != nil {
return nil, fmt.Errorf("error packing aux sigs: %w", err)
}
commitSigMsg.CustomRecords = auxCustomRecords
return &channeldb.CommitDiff{
Commitment: *diskCommit,
CommitSig: commitSigMsg,
LogUpdates: logUpdates,
OpenedCircuitKeys: openCircuitKeys,
ClosedCircuitKeys: closedCircuitKeys,
AddAcks: ackAddRefs,
SettleFailAcks: settleFailRefs,
}, nil
}
// getUnsignedAckedUpdates returns all remote log updates that we haven't
// signed for yet ourselves.
func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate {
// Fetch the last remote update that we have signed for.
lastRemoteCommitted :=
lc.commitChains.Remote.tail().messageIndices.Remote
// Fetch the last remote update that we have acked.
lastLocalCommitted :=
lc.commitChains.Local.tail().messageIndices.Remote
// We'll now run through the remote update log to locate the items that
// we haven't signed for yet. This will be the set of items we need to
// restore if we reconnect in order to produce the signature that the
// remote party expects.
var logUpdates []channeldb.LogUpdate
for e := lc.updateLogs.Remote.Front(); e != nil; e = e.Next() {
pd := e.Value
// Skip all remote updates that we have already included in our
// commit chain.
if pd.LogIndex < lastRemoteCommitted {
continue
}
// Skip all remote updates that we haven't acked yet. At the
// moment this function is called, there shouldn't be any, but
// we check it anyway to make this function more generally
// usable.
if pd.LogIndex >= lastLocalCommitted {
continue
}
logUpdates = append(logUpdates, pd.toLogUpdate())
}
return logUpdates
}
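// As a concrete (hypothetical) example: if we have signed for remote
// updates up to log index 5 (lastRemoteCommitted) and have acked
// updates up to log index 8 (lastLocalCommitted), then exactly the
// remote updates with LogIndex 5, 6 and 7 are returned.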
// CalcFeeBuffer calculates a FeeBuffer in accordance with the recommended
// amount specified in BOLT 02. It accounts for two times the current fee rate
// plus an additional htlc at this higher fee rate, which allows our peer to
// add an htlc even if our channel is drained locally.
// See: https://github.com/lightning/bolts/blob/master/02-peer-protocol.md
func CalcFeeBuffer(feePerKw chainfee.SatPerKWeight,
commitWeight lntypes.WeightUnit) lnwire.MilliSatoshi {
// Account for a 100% increase in the fee rate.
bufferFeePerKw := 2 * feePerKw
feeBuffer := lnwire.NewMSatFromSatoshis(
// Account for an additional htlc at the higher fee level.
bufferFeePerKw.FeeForWeight(commitWeight + input.HTLCWeight),
)
return feeBuffer
}
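// As a quick sanity check with illustrative numbers: for feePerKw =
// 2_500 sat/kw and a commitWeight of 724 WU, the buffer rate is
// 5_000 sat/kw, and assuming input.HTLCWeight = 172 WU the buffer is
// 5_000 * (724 + 172) / 1_000 = 4_480 sats, i.e. 4_480_000 msat.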
// BufferType is used to determine what kind of additional buffer should be left
// when evaluating the usable balance of a channel.
type BufferType uint8
const (
// NoBuffer means no additional buffer is accounted for. This is
// important when verifying an already locked-in commitment state.
NoBuffer BufferType = iota
// FeeBuffer accounts for several edge cases. One of them is where
// a locally drained channel might become unusable due to the non-opener
// of the channel not being able to add a non-dust htlc to the channel
// state because we as a channel opener cannot pay the additional fees
// an htlc would require on the commitment tx.
// See: https://github.com/lightningnetwork/lightning-rfc/issues/728
//
// Moreover it mitigates the situation where htlcs are added
// simultaneously to the commitment transaction. This cannot be avoided
// until the feature __option_simplified_update__ is available in the
// protocol and deployed widely in the network.
// More information about the issue and the simplified commitment flow
// can be found here:
// https://github.com/lightningnetwork/lnd/issues/7657
// https://github.com/lightning/bolts/pull/867
//
// The last advantage is that we can react to fee spikes (up or down)
// by accounting for at least twice the current fee rate (BOLT 02). It
// also accounts for decreases in the fee rate, because former dust
// htlcs might now become normal outputs, so the overall fee might
// increase although the fee rate decreases (this is only true for
// non-anchor channels, because there htlcs have to account for the
// fee of their second-level covenant transactions).
FeeBuffer
// AdditionalHtlc just accounts for an additional htlc which is helpful
// when deciding about a fee update of the commitment transaction.
// Always leaving room for an additional htlc makes sure that even
// though we are the opener of a channel, a new fee update will always
// allow an htlc from our peer to be added to the commitment tx.
AdditionalHtlc
)
// String returns a human readable name for the buffer type.
func (b BufferType) String() string {
switch b {
case NoBuffer:
return "nobuffer"
case FeeBuffer:
return "feebuffer"
case AdditionalHtlc:
return "additionalhtlc"
default:
return "unknown"
}
}
// applyCommitFee applies the commitFee including a buffer to the balance amount
// and verifies that it does not become negative. This function returns the new
// balance, the exact buffer amount (excluding the commitment fee) and the
// commitment fee.
func (lc *LightningChannel) applyCommitFee(
balance lnwire.MilliSatoshi, commitWeight lntypes.WeightUnit,
feePerKw chainfee.SatPerKWeight,
buffer BufferType) (lnwire.MilliSatoshi, lnwire.MilliSatoshi,
lnwire.MilliSatoshi, error) {
commitFee := feePerKw.FeeForWeight(commitWeight)
commitFeeMsat := lnwire.NewMSatFromSatoshis(commitFee)
var bufferAmt lnwire.MilliSatoshi
switch buffer {
// The FeeBuffer is subtracted from the balance. It is of predefined
// size and keeps room for an up to 2x increase in fees of the
// commitment tx and an additional htlc at this fee level reserved for
// the peer.
case FeeBuffer:
// Make sure that we are the initiator of the channel before we
// apply the FeeBuffer.
if !lc.channelState.IsInitiator {
return 0, 0, 0, ErrFeeBufferNotInitiator
}
// The FeeBuffer already includes the commitFee.
bufferAmt = CalcFeeBuffer(feePerKw, commitWeight)
if bufferAmt < balance {
newBalance := balance - bufferAmt
return newBalance, bufferAmt - commitFeeMsat,
commitFeeMsat, nil
}
// The AdditionalHtlc buffer type does NOT keep a FeeBuffer but solely
// keeps space for an additional htlc on the commitment tx which our
// peer can add.
case AdditionalHtlc:
additionalHtlcFee := lnwire.NewMSatFromSatoshis(
feePerKw.FeeForWeight(input.HTLCWeight),
)
bufferAmt = commitFeeMsat + additionalHtlcFee
if bufferAmt < balance {
newBalance := balance - bufferAmt
return newBalance, additionalHtlcFee, commitFeeMsat, nil
}
// The default case does not account for any buffer on the local balance
// but just subtracts the commit fee.
default:
if commitFeeMsat < balance {
newBalance := balance - commitFeeMsat
return newBalance, 0, commitFeeMsat, nil
}
}
// We still return the amount and bufferAmt here to log them at a later
// stage.
return balance, bufferAmt, commitFeeMsat, ErrBelowChanReserve
}
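// Continuing the illustrative numbers from CalcFeeBuffer above: with a
// balance of 1_000_000_000 msat, feePerKw = 2_500 sat/kw and a
// commitWeight of 724 WU, the FeeBuffer case subtracts the full
// 4_480_000 msat buffer, returning a new balance of 995_520_000 msat,
// a buffer of 2_670_000 msat (excluding the 1_810_000 msat commit
// fee), and the commit fee itself.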
// validateCommitmentSanity is used to validate the current state of the
// commitment transaction in terms of the ChannelConstraints that we and our
// remote peer agreed upon during the funding workflow. The
// predict[Our|Their]Add parameters should be set to a valid
// paymentDescriptor if we are validating the state resulting from adding a
// new HTLC, or nil otherwise.
func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter,
ourLogCounter uint64, whoseCommitChain lntypes.ChannelParty,
buffer BufferType, predictOurAdd, predictTheirAdd *paymentDescriptor,
) error {
// First fetch the initial balance before applying any updates.
commitChain := lc.commitChains.Local
if whoseCommitChain.IsRemote() {
commitChain = lc.commitChains.Remote
}
ourInitialBalance := commitChain.tip().ourBalance
theirInitialBalance := commitChain.tip().theirBalance
// Fetch all updates not committed.
view := lc.fetchHTLCView(theirLogCounter, ourLogCounter)
// If we are checking if we can add a new HTLC, we add this to the
// appropriate update log, in order to validate the sanity of the
// commitment resulting from _actually adding_ this HTLC to the state.
if predictOurAdd != nil {
view.Updates.Local = append(view.Updates.Local, predictOurAdd)
}
if predictTheirAdd != nil {
view.Updates.Remote = append(
view.Updates.Remote, predictTheirAdd,
)
}
ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView(
view, whoseCommitChain, false,
fn.None[chainfee.SatPerKWeight](),
)
if err != nil {
return err
}
feePerKw := filteredView.FeePerKw
// Ensure that the fee being applied is enough to be relayed across the
// network in a reasonable time frame.
if feePerKw < chainfee.FeePerKwFloor {
return fmt.Errorf("commitment fee per kw %v below fee floor %v",
feePerKw, chainfee.FeePerKwFloor)
}
// The channel opener has to account for the commitment fee. Depending
// on the buffer type, this may also include an additional buffer on
// top of the fee. Only if we are the opener of the channel do we
// enforce this buffer on our local balance.
var (
bufferAmt lnwire.MilliSatoshi
commitFee lnwire.MilliSatoshi
)
if lc.channelState.IsInitiator {
ourBalance, bufferAmt, commitFee, err = lc.applyCommitFee(
ourBalance, commitWeight, feePerKw, buffer,
)
if err != nil {
lc.log.Errorf("Cannot pay for the CommitmentFee of "+
"the ChannelState: ourBalance is negative "+
"after applying the fee: ourBalance=%v, "+
"commitFee=%v, feeBuffer=%v (type=%v) "+
"local_chan_initiator", ourBalance,
commitFee, bufferAmt, buffer)
return err
}
} else {
// No FeeBuffer is enforced when we are not the initiator of the
// channel. We cannot do this because, if our peer does not enforce
// the FeeBuffer (older LND software), the peer might bring its
// balance below the FeeBuffer, making the channel stuck since we
// will locally never put another outgoing HTLC on the channel state.
// The FeeBuffer should ONLY be enforced if we locally pay for the
// commitment transaction.
theirBalance, bufferAmt, commitFee, err = lc.applyCommitFee(
theirBalance, commitWeight, feePerKw, NoBuffer,
)
if err != nil {
lc.log.Errorf("Cannot pay for the CommitmentFee "+
"of the ChannelState: theirBalance is "+
"negative after applying the fee: "+
"theirBalance=%v, commitFee=%v, feeBuffer=%v "+
"(type=%v) remote_chan_initiator",
theirBalance, commitFee, bufferAmt, buffer)
return err
}
}
// The commitment fee was accounted for successfully; now make sure we
// still have enough left to account for the channel reserve.
// If the added HTLCs will decrease the balance, make sure they won't
// dip the local and remote balances below the channel reserves.
ourReserve := lnwire.NewMSatFromSatoshis(
lc.channelState.LocalChanCfg.ChanReserve,
)
theirReserve := lnwire.NewMSatFromSatoshis(
lc.channelState.RemoteChanCfg.ChanReserve,
)
switch {
// TODO(ziggie): Allow the peer to dip us below the channel reserve
// when our local balance would increase during this commitment dance,
// or allow us to dip the peer below its reserve when their balance
// would increase during this commitment dance. This is needed for
// splicing when e.g. a new channel (bigger capacity) has a higher
// required reserve and the peer would need to add an additional htlc
// to push the missing amount to our side, and vice versa.
// See: https://github.com/lightningnetwork/lnd/issues/8249
case ourBalance < ourInitialBalance && ourBalance < ourReserve:
lc.log.Debugf("Funds below chan reserve: ourBalance=%v, "+
"ourReserve=%v, commitFee=%v, feeBuffer=%v "+
"chan_initiator=%v", ourBalance, ourReserve,
commitFee, bufferAmt, lc.channelState.IsInitiator)
return fmt.Errorf("%w: our balance below chan reserve",
ErrBelowChanReserve)
case theirBalance < theirInitialBalance && theirBalance < theirReserve:
lc.log.Debugf("Funds below chan reserve: theirBalance=%v, "+
"theirReserve=%v", theirBalance, theirReserve)
return fmt.Errorf("%w: their balance below chan reserve",
ErrBelowChanReserve)
}
// validateUpdates take a set of updates, and validates them against
// the passed channel constraints.
validateUpdates := func(updates []*paymentDescriptor,
constraints *channeldb.ChannelConfig) error {
// We keep track of the number of HTLCs in flight for the
// commitment, and the amount in flight.
var numInFlight uint16
var amtInFlight lnwire.MilliSatoshi
// Go through all updates, checking that they don't violate the
// channel constraints.
for _, entry := range updates {
if entry.EntryType == Add {
// An HTLC is being added, this will add to the
// number and amount in flight.
amtInFlight += entry.Amount
numInFlight++
// Check that the HTLC amount is positive.
if entry.Amount == 0 {
return ErrInvalidHTLCAmt
}
// Check that the value of the HTLC they added
// is above our minimum.
if entry.Amount < constraints.MinHTLC {
return ErrBelowMinHTLC
}
}
}
// Now that we know the total value of added HTLCs, we check
// that it satisfies the MaxPendingAmount constraint.
if amtInFlight > constraints.MaxPendingAmount {
return ErrMaxPendingAmount
}
// In this step, we verify that the total number of active
// HTLCs does not exceed the constraint of the maximum number
// of HTLCs in flight.
if numInFlight > constraints.MaxAcceptedHtlcs {
return ErrMaxHTLCNumber
}
return nil
}
// First check that the remote updates won't violate the remote's
// channel constraints.
err = validateUpdates(
filteredView.Updates.Remote, &lc.channelState.RemoteChanCfg,
)
if err != nil {
return err
}
// Secondly check that our updates won't violate our channel
// constraints.
err = validateUpdates(
filteredView.Updates.Local, &lc.channelState.LocalChanCfg,
)
if err != nil {
return err
}
return nil
}
// CommitSigs holds the set of related signatures for a new commitment
// transaction state.
type CommitSigs struct {
// CommitSig is the normal commitment signature. This will only be a
// non-zero commitment signature for non-taproot channels.
CommitSig lnwire.Sig
// HtlcSigs is the set of signatures for all HTLCs in the commitment
// transaction. Depending on the channel type, these will either be
// ECDSA or Schnorr signatures.
HtlcSigs []lnwire.Sig
// PartialSig is the musig2 partial signature for taproot commitment
// transactions.
PartialSig lnwire.OptPartialSigWithNonceTLV
// AuxSigBlob is the blob containing all the auxiliary signatures for
// this new commitment state.
AuxSigBlob tlv.Blob
}
// NewCommitState wraps the various signatures needed to properly
// propose/accept a new commitment state. This includes the signer's nonce for
// musig2 channels.
type NewCommitState struct {
*CommitSigs
// PendingHTLCs is the set of new/pending HTLCs produced by this
// commitment state.
PendingHTLCs []channeldb.HTLC
}
// SignNextCommitment signs a new commitment which includes any previous
// unsettled HTLCs, any new HTLCs, and any modifications to prior HTLCs
// committed in previous commitment updates. Signing a new commitment
// decrements the available revocation window by 1. After a successful method
// call, the remote party's commitment chain is extended by a new commitment
// which includes all updates to the HTLC log prior to this method invocation.
// The first return parameter is the signature for the commitment transaction
// itself, while the second parameter is a slice of all HTLC signatures (if
// any). The HTLC signatures are sorted according to the BIP 69 order of the
// HTLCs on the commitment transaction. Finally, the new set of pending HTLCs
// for the remote party's commitment are also returned.
//
//nolint:funlen
func (lc *LightningChannel) SignNextCommitment(
ctx context.Context) (*NewCommitState, error) {
lc.Lock()
defer lc.Unlock()
// Check for empty commit sig. This should never happen, but we don't
// dare to fail hard here. We assume peers can deal with the empty sig
// and continue channel operation. We log an error so that the bug
// causing this can be tracked down.
if !lc.oweCommitment(lntypes.Local) {
lc.log.Errorf("sending empty commit sig")
}
var (
sig lnwire.Sig
partialSig *lnwire.PartialSigWithNonce
htlcSigs []lnwire.Sig
)
// If we're awaiting an ACK to a commitment signature, or if we
// don't yet have the initial next revocation point of the remote
// party, then we're unable to create new states. Each time we create a
// new state, we consume a prior revocation point.
commitPoint := lc.channelState.RemoteNextRevocation
unacked := lc.commitChains.Remote.hasUnackedCommitment()
if unacked || commitPoint == nil {
lc.log.Tracef("waiting for remote ack=%v, nil "+
"RemoteNextRevocation: %v", unacked, commitPoint == nil)
return nil, ErrNoWindow
}
// Determine the last update on the remote log that has been locked in.
remoteACKedIndex := lc.commitChains.Local.tail().messageIndices.Remote
remoteHtlcIndex := lc.commitChains.Local.tail().theirHtlcIndex
// Before we extend this new commitment to the remote commitment chain,
// ensure that we aren't violating any of the constraints the remote
// party set up when we initially set up the channel. If we are, then
// we'll abort this state transition.
// We do not enforce the FeeBuffer here because when we reach this
// point all updates will have to get locked-in so we enforce the
// minimum requirement.
err := lc.validateCommitmentSanity(
remoteACKedIndex, lc.updateLogs.Local.logIndex, lntypes.Remote,
NoBuffer, nil, nil,
)
if err != nil {
return nil, err
}
// Grab the next commitment point for the remote party. This will be
// used within fetchCommitmentView to derive all the keys necessary to
// construct the commitment state.
keyRing := DeriveCommitmentKeys(
commitPoint, lntypes.Remote, lc.channelState.ChanType,
&lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg,
)
// Create a new commitment view which will calculate the evaluated
// state of the remote node's new commitment including our latest added
// HTLCs. The view includes the latest balances for both sides on the
// remote node's chain, and also updates the addition height of any new
// HTLC log entries. When creating a new remote view, we include
// _all_ of our changes (pending or committed) but only the remote
// node's changes up to the last change we've ACK'd.
newCommitView, err := lc.fetchCommitmentView(
lntypes.Remote, lc.updateLogs.Local.logIndex,
lc.updateLogs.Local.htlcCounter, remoteACKedIndex,
remoteHtlcIndex, keyRing,
)
if err != nil {
return nil, err
}
lc.log.Tracef("extending remote chain to height %v, "+
"local_log=%v, remote_log=%v",
newCommitView.height,
lc.updateLogs.Local.logIndex, remoteACKedIndex)
lc.log.Tracef("remote chain: our_balance=%v, "+
"their_balance=%v, commit_tx: %v",
newCommitView.ourBalance,
newCommitView.theirBalance,
lnutils.SpewLogClosure(newCommitView.txn))
// With the commitment view constructed, if there are any HTLC's, we'll
// need to generate signatures of each of them for the remote party's
// commitment state. We do so in two phases: first we generate and
// submit the set of signature jobs to the worker pool.
var leaseExpiry uint32
if lc.channelState.ChanType.HasLeaseExpiration() {
leaseExpiry = lc.channelState.ThawHeight
}
sigBatch, auxSigBatch, cancelChan, err := genRemoteHtlcSigJobs(
keyRing, lc.channelState, leaseExpiry, newCommitView,
lc.leafStore,
)
if err != nil {
return nil, err
}
// We'll need to send over the signatures to the remote party in the
// order as they appear on the commitment transaction after BIP 69
// sorting.
slices.SortFunc(sigBatch, func(i, j SignJob) int {
return cmp.Compare(i.OutputIndex, j.OutputIndex)
})
slices.SortFunc(auxSigBatch, func(i, j AuxSigJob) int {
return cmp.Compare(i.OutputIndex, j.OutputIndex)
})
lc.sigPool.SubmitSignBatch(sigBatch)
err = fn.MapOptionZ(lc.auxSigner, func(a AuxSigner) error {
return a.SubmitSecondLevelSigBatch(
NewAuxChanState(lc.channelState), newCommitView.txn,
auxSigBatch,
)
})
if err != nil {
return nil, fmt.Errorf("error submitting second level sig "+
"batch: %w", err)
}
// While the jobs are being carried out, we'll sign their version of
// the new commitment transaction, simultaneously waiting for the rest
// of the HTLC signatures to be processed.
//
// TODO(roasbeef): abstract into CommitSigner interface?
if lc.channelState.ChanType.IsTaproot() {
// In this case, we'll send out a partial signature as this is
// a musig2 channel. The encoded normal ECDSA signature will be
// just blank.
remoteSession := lc.musigSessions.RemoteSession
musig, err := remoteSession.SignCommit(
newCommitView.txn,
)
if err != nil {
close(cancelChan)
return nil, err
}
partialSig = musig.ToWireSig()
} else {
lc.signDesc.SigHashes = input.NewTxSigHashesV0Only(
newCommitView.txn,
)
rawSig, err := lc.Signer.SignOutputRaw(
newCommitView.txn, lc.signDesc,
)
if err != nil {
close(cancelChan)
return nil, err
}
sig, err = lnwire.NewSigFromSignature(rawSig)
if err != nil {
close(cancelChan)
return nil, err
}
}
// Iterate through all the responses to gather each of the signatures
// in the order they were submitted.
htlcSigs = make([]lnwire.Sig, 0, len(sigBatch))
auxSigs := make([]fn.Option[tlv.Blob], 0, len(auxSigBatch))
for i := range sigBatch {
htlcSigJob := sigBatch[i]
var jobResp SignJobResp
select {
case jobResp = <-htlcSigJob.Resp:
case <-ctx.Done():
return nil, errQuit
}
// If an error occurred, then we'll cancel any other active
// jobs.
if jobResp.Err != nil {
close(cancelChan)
return nil, jobResp.Err
}
htlcSigs = append(htlcSigs, jobResp.Sig)
if lc.auxSigner.IsNone() {
continue
}
auxHtlcSigJob := auxSigBatch[i]
var auxJobResp AuxSigJobResp
select {
case auxJobResp = <-auxHtlcSigJob.Resp:
case <-ctx.Done():
return nil, errQuit
}
// If an error occurred, then we'll cancel any other active
// jobs.
if auxJobResp.Err != nil {
close(cancelChan)
return nil, auxJobResp.Err
}
auxSigs = append(auxSigs, auxJobResp.SigBlob)
}
// As we're about to propose a new commitment state for the remote
// party, we'll write this pending state to disk before we exit, so we
// can retransmit it if necessary.
commitDiff, err := lc.createCommitDiff(
newCommitView, sig, htlcSigs, auxSigs,
)
if err != nil {
return nil, err
}
err = lc.channelState.AppendRemoteCommitChain(commitDiff)
if err != nil {
return nil, err
}
// TODO(roasbeef): check that one eclair bug
// * need to retransmit on first state still?
// * after initial reconnect
// Extend the remote commitment chain by one with the addition of our
// latest commitment update.
lc.commitChains.Remote.addCommitment(newCommitView)
auxSigBlob, err := commitDiff.CommitSig.CustomRecords.Serialize()
if err != nil {
return nil, fmt.Errorf("unable to serialize aux sig blob: %w",
err)
}
return &NewCommitState{
CommitSigs: &CommitSigs{
CommitSig: sig,
HtlcSigs: htlcSigs,
PartialSig: lnwire.MaybePartialSigWithNonce(partialSig),
AuxSigBlob: auxSigBlob,
},
PendingHTLCs: commitDiff.Commitment.Htlcs,
}, nil
}
// resignMusigCommit is used to resign a commitment transaction for taproot
// channels when we need to retransmit a signature after a channel reestablish
// message. Taproot channels use musig2, which means we must use fresh nonces
// each time. After we receive the channel reestablish message, we learn the
// nonce we need to use for the remote party. As a result, we need to generate
// the partial signature again with the new nonce.
func (lc *LightningChannel) resignMusigCommit(
commitTx *wire.MsgTx) (lnwire.OptPartialSigWithNonceTLV, error) {
remoteSession := lc.musigSessions.RemoteSession
musig, err := remoteSession.SignCommit(commitTx)
if err != nil {
var none lnwire.OptPartialSigWithNonceTLV
return none, err
}
partialSig := lnwire.MaybePartialSigWithNonce(musig.ToWireSig())
return partialSig, nil
}
// ProcessChanSyncMsg processes a ChannelReestablish message sent by the remote
// connection upon re-establishment of our connection with them. This method
// will return a single message if we are currently out of sync, otherwise a
// nil lnwire.Message will be returned. If it is decided that our level of
// de-synchronization is irreconcilable, then an error indicating the issue
// will be returned. In the case that an error is returned, the channel should
// be force closed, as we cannot continue updates.
//
// One of two message sets will be returned:
//
// - CommitSig+Updates: if we have a pending remote commit which they claim to
// have not received
// - RevokeAndAck: if we sent a revocation message that they claim to have
// not received
//
// If we detect a scenario where we need to send a CommitSig+Updates, this
// method also returns two sets of models.CircuitKeys identifying the circuits
// that were opened and closed, respectively, as a result of signing the
// previous commitment txn. This allows the link to clear its mailbox of those
// circuits in case they are still in memory, and ensure the switch's circuit
// map has been updated by deleting the closed circuits.
//
//nolint:funlen
func (lc *LightningChannel) ProcessChanSyncMsg(ctx context.Context,
msg *lnwire.ChannelReestablish) ([]lnwire.Message, []models.CircuitKey,
[]models.CircuitKey, error) {
// Now we'll examine the state we have, vs what was contained in the
// chain sync message. If we're de-synchronized, then we'll send a
// batch of messages which when applied will kick start the chain
// resync.
var (
updates []lnwire.Message
openedCircuits []models.CircuitKey
closedCircuits []models.CircuitKey
)
// If the remote party included the optional fields, then we'll verify
// their correctness first, as it will influence our decisions below.
hasRecoveryOptions := msg.LocalUnrevokedCommitPoint != nil
if hasRecoveryOptions && msg.RemoteCommitTailHeight != 0 {
// We'll check that they've really sent a valid commit
// secret from our shachain for our prior height, but only if
// this isn't the first state.
heightSecret, err := lc.channelState.RevocationProducer.AtIndex(
msg.RemoteCommitTailHeight - 1,
)
if err != nil {
return nil, nil, nil, err
}
commitSecretCorrect := bytes.Equal(
heightSecret[:], msg.LastRemoteCommitSecret[:],
)
// If the commit secret they sent is incorrect then we'll fail
// the channel as the remote node has an inconsistent state.
if !commitSecretCorrect {
// In this case, we'll return an error to indicate the
// remote node sent us the wrong values. This will let
// the caller act accordingly.
lc.log.Errorf("sync failed: remote provided invalid " +
"commit secret!")
return nil, nil, nil, ErrInvalidLastCommitSecret
}
}
// If this is a taproot channel, then we expect that the remote party
// has sent the next verification nonce. If they haven't, then we'll
// bail out, otherwise we'll init our local session then continue as
// normal.
switch {
case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce.IsNone():
return nil, nil, nil, fmt.Errorf("remote verification nonce " +
"not sent")
case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce.IsSome():
if lc.opts.skipNonceInit {
// Don't call InitRemoteMusigNonces if we have already
// done so.
break
}
nextNonce, err := msg.LocalNonce.UnwrapOrErrV(errNoNonce)
if err != nil {
return nil, nil, nil, err
}
err = lc.InitRemoteMusigNonces(&musig2.Nonces{
PubNonce: nextNonce,
})
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to init "+
"remote nonce: %w", err)
}
}
// If we detect that this is a restored channel, then we can skip a
// portion of the verification, as we already know that we're unable to
// proceed with any updates.
isRestoredChan := lc.channelState.HasChanStatus(
channeldb.ChanStatusRestored,
)
// Take note of our current commit chain heights before we begin adding
// more to them.
var (
localTailHeight = lc.commitChains.Local.tail().height
remoteTailHeight = lc.commitChains.Remote.tail().height
remoteTipHeight = lc.commitChains.Remote.tip().height
)
// We'll now check that their view of our local chain is up-to-date.
// This means checking that the tail height of our local chain matches
// what they believe it to be. Note that the tail and tip height will
// always be the same for the local chain at this stage, as we won't
// store any received commitment to disk before it is ACKed.
switch {
// If their reported height for our local chain tail is ahead of our
// view, then we're behind!
case msg.RemoteCommitTailHeight > localTailHeight || isRestoredChan:
lc.log.Errorf("sync failed with local data loss: remote "+
"believes our tail height is %v, while we have %v!",
msg.RemoteCommitTailHeight, localTailHeight)
if isRestoredChan {
lc.log.Warnf("detected restored triggering DLP")
}
// We must check that we had recovery options to ensure the
// commitment secret matched up, and the remote is just not
// lying about its height.
if !hasRecoveryOptions {
// At this point the remote is either lying about
// its height, or we are actually behind but the remote
// doesn't support data loss protection. In either case
// it is not safe for us to keep using the channel, so
// we mark it borked and fail the channel.
lc.log.Errorf("sync failed: local data loss, but no " +
"recovery option.")
return nil, nil, nil, ErrCannotSyncCommitChains
}
// In this case, we've likely lost data and shouldn't proceed
// with channel updates.
return nil, nil, nil, &ErrCommitSyncLocalDataLoss{
ChannelPoint: lc.channelState.FundingOutpoint,
CommitPoint: msg.LocalUnrevokedCommitPoint,
}
// If the height of our commitment chain reported by the remote party
// is behind our view of the chain, then they probably lost some state,
// and we'll force close the channel.
case msg.RemoteCommitTailHeight+1 < localTailHeight:
lc.log.Errorf("sync failed: remote believes our tail height is "+
"%v, while we have %v!",
msg.RemoteCommitTailHeight, localTailHeight)
return nil, nil, nil, ErrCommitSyncRemoteDataLoss
// Their view of our commit chain is consistent with our view.
case msg.RemoteCommitTailHeight == localTailHeight:
// In sync, don't have to do anything.
// We owe them a revocation if the tail of our current commitment chain
// is one greater than what they _think_ our commitment tail is. In
// this case we'll re-send the last revocation message that we sent.
// This will be the revocation message for our prior chain tail.
case msg.RemoteCommitTailHeight+1 == localTailHeight:
lc.log.Debugf("sync: remote believes our tail height is %v, "+
"while we have %v, we owe them a revocation",
msg.RemoteCommitTailHeight, localTailHeight)
heightToRetransmit := localTailHeight - 1
revocationMsg, err := lc.generateRevocation(heightToRetransmit)
if err != nil {
return nil, nil, nil, err
}
updates = append(updates, revocationMsg)
// Next, as a precaution, we'll check a special edge case. If
// they initiated a state transition, we sent the revocation,
// but died before the signature was sent. We re-transmit our
// revocation, but also initiate a state transition to re-sync
// them.
if lc.OweCommitment() {
newCommit, err := lc.SignNextCommitment(ctx)
switch {
// If we signed this state, then we'll accumulate
// another update to send over.
case err == nil:
customRecords, err := lnwire.ParseCustomRecords(
newCommit.AuxSigBlob,
)
if err != nil {
sErr := fmt.Errorf("error parsing aux "+
"sigs: %w", err)
return nil, nil, nil, sErr
}
commitSig := &lnwire.CommitSig{
ChanID: lnwire.NewChanIDFromOutPoint(
lc.channelState.FundingOutpoint,
),
CommitSig: newCommit.CommitSig,
HtlcSigs: newCommit.HtlcSigs,
PartialSig: newCommit.PartialSig,
CustomRecords: customRecords,
}
updates = append(updates, commitSig)
// If we get a failure due to not knowing their next
// point, then this is fine as they'll either send
// ChannelReady, or revoke their next state to allow
// us to continue forwards.
case err == ErrNoWindow:
// Otherwise, this is an error and we'll treat it as
// such.
default:
return nil, nil, nil, err
}
}
// There should be no other possible states.
default:
lc.log.Errorf("sync failed: remote believes our tail height is "+
"%v, while we have %v!",
msg.RemoteCommitTailHeight, localTailHeight)
return nil, nil, nil, ErrCannotSyncCommitChains
}
// Now check if our view of the remote chain is consistent with what
// they tell us.
switch {
// The remote's view of their next commit height is 2+ states ahead of
// ours; we most likely lost data, or the remote is trying to
// trick us. Since we have no way of verifying whether they are lying
// or not, we will fail the channel, but should not force close it
// automatically.
case msg.NextLocalCommitHeight > remoteTipHeight+1:
lc.log.Errorf("sync failed: remote's next commit height is %v, "+
"while we believe it is %v!",
msg.NextLocalCommitHeight, remoteTipHeight+1)
return nil, nil, nil, ErrCannotSyncCommitChains
// They are waiting for a state they have already ACKed.
case msg.NextLocalCommitHeight <= remoteTailHeight:
lc.log.Errorf("sync failed: remote's next commit height is %v, "+
"while we believe it is %v!",
msg.NextLocalCommitHeight, remoteTipHeight+1)
// They previously ACKed our current tail, and now they are
// waiting for it. They probably lost state.
return nil, nil, nil, ErrCommitSyncRemoteDataLoss
// They have received our latest commitment, life is good.
case msg.NextLocalCommitHeight == remoteTipHeight+1:
// We owe them a commitment if the tip of their chain (from our PoV) is
// equal to what they think their next commit height should be. We'll
// re-send all the updates necessary to recreate this state, along
// with the commit sig.
case msg.NextLocalCommitHeight == remoteTipHeight:
lc.log.Debugf("sync: remote's next commit height is %v, while "+
"we believe it is %v, we owe them a commitment",
msg.NextLocalCommitHeight, remoteTipHeight+1)
// Grab the current remote chain tip from the database. This
// commit diff contains all the information required to re-sync
// our states.
commitDiff, err := lc.channelState.RemoteCommitChainTip()
if err != nil {
return nil, nil, nil, err
}
var commitUpdates []lnwire.Message
// Next, we'll need to send over any updates we sent as part of
// this new proposed commitment state.
for _, logUpdate := range commitDiff.LogUpdates {
commitUpdates = append(
commitUpdates, logUpdate.UpdateMsg,
)
}
// If this is a taproot channel, then we need to regenerate the
// musig2 signature for the remote party, using their fresh
// nonce.
if lc.channelState.ChanType.IsTaproot() {
partialSig, err := lc.resignMusigCommit(
commitDiff.Commitment.CommitTx,
)
if err != nil {
return nil, nil, nil, err
}
commitDiff.CommitSig.PartialSig = partialSig
}
// With the batch of updates accumulated, we'll now re-send the
// original CommitSig message required to re-sync their remote
// commitment chain with our local version of their chain.
commitUpdates = append(commitUpdates, commitDiff.CommitSig)
// NOTE: If a revocation is not owed, then updates is empty.
if lc.channelState.LastWasRevoke {
// If lastWasRevoke is set to true, a revocation was sent last and we
// need to reorder the updates so that the revocation stored in
// updates comes after the LogUpdates+CommitSig.
//
// ---logupdates--->
// ---commitsig---->
// ---revocation--->
updates = append(commitUpdates, updates...)
} else {
// Otherwise, the revocation should come before LogUpdates
// + CommitSig.
//
// ---revocation--->
// ---logupdates--->
// ---commitsig---->
updates = append(updates, commitUpdates...)
}
openedCircuits = commitDiff.OpenedCircuitKeys
closedCircuits = commitDiff.ClosedCircuitKeys
// The commit chain can have at most two elements, so there should be no
// other possible states; if we reach this point, something is wrong.
default:
lc.log.Errorf("sync failed: remote's next commit height is %v, "+
"while we believe it is %v!",
msg.NextLocalCommitHeight, remoteTipHeight)
return nil, nil, nil, ErrCannotSyncCommitChains
}
// If we didn't have recovery options, then the final check cannot be
// performed, and we'll return early.
if !hasRecoveryOptions {
return updates, openedCircuits, closedCircuits, nil
}
// At this point we have determined that either the commit heights are
// in sync, or that we are in a state we can recover from. As a final
// check, we ensure that the commitment point sent to us by the remote
// is valid.
var commitPoint *btcec.PublicKey
switch {
// If their height is one beyond what we know their current height to
// be, then we need to compare their current unrevoked commitment point
// as that's what they should send.
case msg.NextLocalCommitHeight == remoteTailHeight+1:
commitPoint = lc.channelState.RemoteCurrentRevocation
// Alternatively, if their height is two beyond what we know their best
// height to be, then they're holding onto two commitments, and the
// highest unrevoked point is their next revocation.
//
// TODO(roasbeef): verify this in the spec...
case msg.NextLocalCommitHeight == remoteTailHeight+2:
commitPoint = lc.channelState.RemoteNextRevocation
}
// Only if this is a tweakless channel will we attempt to verify the
// commitment point, as otherwise it has no validity requirements.
tweakless := lc.channelState.ChanType.IsTweakless()
if !tweakless && commitPoint != nil &&
!commitPoint.IsEqual(msg.LocalUnrevokedCommitPoint) {
lc.log.Errorf("sync failed: remote sent invalid commit point "+
"for height %v!",
msg.NextLocalCommitHeight)
return nil, nil, nil, ErrInvalidLocalUnrevokedCommitPoint
}
return updates, openedCircuits, closedCircuits, nil
}
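// For context, a reconnect handler would simply replay the returned
// messages in order, since they were appended above in the correct
// revocation/commit-sig ordering (sketch only; the send hook shown is
// assumed, not defined in this file):
//
//	updates, opened, closed, err := lc.ProcessChanSyncMsg(ctx, msg)
//	if err != nil {
//		// E.g. ErrCommitSyncRemoteDataLoss or
//		// ErrCannotSyncCommitChains: escalate the failed sync.
//		return err
//	}
//	for _, m := range updates {
//		send(m) // hypothetical transport hook
//	}
//	_, _ = opened, closed // circuit bookkeeping elided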
// computeView takes the given HtlcView, and calculates the balances, filtered
// view (settling unsettled HTLCs), commitment weight and feePerKw, after
// applying the HTLCs to the latest commitment. The returned balances are the
// balances *before* subtracting the commitment fee from the initiator's
// balance. It accepts a "dry run" feerate argument to calculate a potential
// commitment transaction fee.
//
// If the updateState boolean is set true, the add and remove heights of the
// HTLCs will be set to the next commitment height.
func (lc *LightningChannel) computeView(view *HtlcView,
whoseCommitChain lntypes.ChannelParty, updateState bool,
dryRunFee fn.Option[chainfee.SatPerKWeight]) (lnwire.MilliSatoshi,
lnwire.MilliSatoshi, lntypes.WeightUnit, *HtlcView, error) {
commitChain := lc.commitChains.Local
dustLimit := lc.channelState.LocalChanCfg.DustLimit
if whoseCommitChain.IsRemote() {
commitChain = lc.commitChains.Remote
dustLimit = lc.channelState.RemoteChanCfg.DustLimit
}
// Since the fetched htlc view will include all updates added after the
// last committed state, we start with the balances reflecting that
// state.
ourBalance := commitChain.tip().ourBalance
theirBalance := commitChain.tip().theirBalance
// Add the fee from the previous commitment state back to the
// initiator's balance, so that the fee can be recalculated and
// re-applied in case fee estimation parameters have changed or the
// number of outstanding HTLCs has changed.
if lc.channelState.IsInitiator {
ourBalance += lnwire.NewMSatFromSatoshis(
commitChain.tip().fee)
} else {
theirBalance += lnwire.NewMSatFromSatoshis(
commitChain.tip().fee)
}
nextHeight := commitChain.tip().height + 1
// Initialize feePerKw to the last committed fee for this chain as we'll
// need this to determine which HTLCs are dust, and also the final fee
// rate.
view.FeePerKw = commitChain.tip().feePerKw
view.NextHeight = nextHeight
// We evaluate the view at this stage, meaning settled and failed HTLCs
// will remove their corresponding added HTLCs. The resulting filtered
// view will only have Add entries left, making it easy to compare the
// channel constraints to the final commitment state. If any fee
// updates are found in the logs, the commitment fee rate should be
// changed, so we'll also set the feePerKw to this new value.
filteredHTLCView, uncommitted, deltas, err := lc.evaluateHTLCView(
view, whoseCommitChain, nextHeight,
)
if err != nil {
return 0, 0, 0, nil, err
}
// Add the balance deltas to the balances we got from the commitment
// state.
if deltas.Local >= 0 {
ourBalance += lnwire.MilliSatoshi(deltas.Local)
} else {
ourBalance -= lnwire.MilliSatoshi(-1 * deltas.Local)
}
if deltas.Remote >= 0 {
theirBalance += lnwire.MilliSatoshi(deltas.Remote)
} else {
theirBalance -= lnwire.MilliSatoshi(-1 * deltas.Remote)
}
if updateState {
for _, party := range lntypes.BothParties {
for _, u := range uncommitted.GetForParty(party) {
u.setCommitHeight(whoseCommitChain, nextHeight)
if whoseCommitChain == lntypes.Local &&
u.EntryType == Settle {
lc.recordSettlement(party, u.Amount)
}
}
}
}
feePerKw := filteredHTLCView.FeePerKw
// Here we override the view's fee-rate if a dry-run fee-rate was
// passed in.
if !updateState {
feePerKw = dryRunFee.UnwrapOr(feePerKw)
}
// We first need to check that ourBalance and theirBalance are not
// negative, because MilliSatoshi is an unsigned type and can underflow
// in the code above. This should never happen for views which do not
// include new updates (remote or local).
if int64(ourBalance) < 0 {
err := fmt.Errorf("%w: our balance", ErrBelowChanReserve)
return 0, 0, 0, nil, err
}
if int64(theirBalance) < 0 {
err := fmt.Errorf("%w: their balance", ErrBelowChanReserve)
return 0, 0, 0, nil, err
}
// Now go through all HTLCs at this stage to calculate the total weight
// needed to compute the transaction fee.
var totalHtlcWeight lntypes.WeightUnit
for _, htlc := range filteredHTLCView.Updates.Local {
if HtlcIsDust(
lc.channelState.ChanType, false, whoseCommitChain,
feePerKw, htlc.Amount.ToSatoshis(), dustLimit,
) {
continue
}
totalHtlcWeight += input.HTLCWeight
}
for _, htlc := range filteredHTLCView.Updates.Remote {
if HtlcIsDust(
lc.channelState.ChanType, true, whoseCommitChain,
feePerKw, htlc.Amount.ToSatoshis(), dustLimit,
) {
continue
}
totalHtlcWeight += input.HTLCWeight
}
totalCommitWeight := CommitWeight(lc.channelState.ChanType) +
totalHtlcWeight
return ourBalance, theirBalance, totalCommitWeight, filteredHTLCView, nil
}
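// To illustrate the dry-run fee parameter, a caller can evaluate a
// hypothetical fee rate without mutating any state by passing
// updateState=false together with a populated fn.Option (sketch; the
// rate shown is illustrative):
//
//	dryRate := fn.Some(chainfee.SatPerKWeight(2500))
//	ourBal, theirBal, weight, _, err := lc.computeView(
//		view, lntypes.Remote, false, dryRate,
//	)
//	if err == nil {
//		// The commitment fee implied by the dry-run rate:
//		fee := chainfee.SatPerKWeight(2500).FeeForWeight(weight)
//		_, _, _ = ourBal, theirBal, fee
//	}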
// recordSettlement updates the lifetime payment flow values in persistent state
// of the LightningChannel, adding amt to the total received by the redeemer.
func (lc *LightningChannel) recordSettlement(
redeemer lntypes.ChannelParty, amt lnwire.MilliSatoshi) {
if redeemer == lntypes.Local {
lc.channelState.TotalMSatReceived += amt
} else {
lc.channelState.TotalMSatSent += amt
}
}
// genHtlcSigValidationJobs generates a series of signature verification
// jobs meant to verify all the signatures for HTLCs attached to a newly
// created commitment state. The jobs generated are fully populated, and
// can be sent directly into the pool of workers.
//
//nolint:funlen
func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel,
localCommitmentView *commitment, keyRing *CommitmentKeyRing,
htlcSigs []lnwire.Sig, leaseExpiry uint32,
leafStore fn.Option[AuxLeafStore], auxSigner fn.Option[AuxSigner],
sigBlob fn.Option[tlv.Blob]) ([]VerifyJob, []AuxVerifyJob, error) {
var (
isLocalInitiator = chanState.IsInitiator
localChanCfg = chanState.LocalChanCfg
chanType = chanState.ChanType
)
txHash := localCommitmentView.txn.TxHash()
feePerKw := localCommitmentView.feePerKw
sigHashType := HtlcSigHashType(chanType)
// With the required state generated, we'll create a slice with large
// enough capacity to hold verification jobs for all HTLCs in this view.
// In the case that we have some dust outputs, then the actual length
// will be smaller than the total capacity.
numHtlcs := len(localCommitmentView.incomingHTLCs) +
len(localCommitmentView.outgoingHTLCs)
verifyJobs := make([]VerifyJob, 0, numHtlcs)
auxVerifyJobs := make([]AuxVerifyJob, 0, numHtlcs)
diskCommit := localCommitmentView.toDiskCommit(lntypes.Local)
auxResult, err := fn.MapOptionZ(
leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] {
return s.FetchLeavesFromCommit(
NewAuxChanState(chanState), *diskCommit,
*keyRing, lntypes.Local,
)
},
).Unpack()
if err != nil {
return nil, nil, fmt.Errorf("unable to fetch aux leaves: %w",
err)
}
// If we have a sig blob, then we'll attempt to map that to individual
// blobs for each HTLC we might need a signature for.
auxHtlcSigs, err := fn.MapOptionZ(
auxSigner, func(a AuxSigner) fn.Result[[]fn.Option[tlv.Blob]] {
return a.UnpackSigs(sigBlob)
},
).Unpack()
if err != nil {
return nil, nil, fmt.Errorf("error unpacking aux sigs: %w",
err)
}
// We'll iterate through each output in the commitment transaction,
// populating the sigHash closure function if it's detected to be an
// HTLC output. Given the sighash and the signing key, we'll be able to
// validate each signature within the worker pool.
i := 0
for index := range localCommitmentView.txn.TxOut {
var (
htlcIndex uint64
sigHash func() ([]byte, error)
sig input.Signature
htlc *paymentDescriptor
incoming bool
auxLeaf input.AuxTapLeaf
err error
)
outputIndex := int32(index)
switch {
// If this output index is found within the incoming HTLC
// index, then this means that we need to generate an HTLC
// success transaction in order to validate the signature.
//nolint:ll
case localCommitmentView.incomingHTLCIndex[outputIndex] != nil:
htlc = localCommitmentView.incomingHTLCIndex[outputIndex]
htlcIndex = htlc.HtlcIndex
incoming = true
sigHash = func() ([]byte, error) {
op := wire.OutPoint{
Hash: txHash,
Index: uint32(htlc.localOutputIndex),
}
htlcFee := HtlcSuccessFee(chanType, feePerKw)
outputAmt := htlc.Amount.ToSatoshis() - htlcFee
auxLeaf := fn.FlatMapOption(func(
l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.IncomingHtlcLeaves
idx := htlc.HtlcIndex
return leaves[idx].SecondLevelLeaf
})(auxResult.AuxLeaves)
successTx, err := CreateHtlcSuccessTx(
chanType, isLocalInitiator, op,
outputAmt, uint32(localChanCfg.CsvDelay),
leaseExpiry, keyRing.RevocationKey,
keyRing.ToLocalKey, auxLeaf,
)
if err != nil {
return nil, err
}
htlcAmt := int64(htlc.Amount.ToSatoshis())
if chanType.IsTaproot() {
// TODO(roasbeef): add abstraction in
// front
prevFetcher := txscript.NewCannedPrevOutputFetcher( //nolint:ll
htlc.ourPkScript, htlcAmt,
)
hashCache := txscript.NewTxSigHashes(
successTx, prevFetcher,
)
tapLeaf := txscript.NewBaseTapLeaf(
htlc.ourWitnessScript,
)
return txscript.CalcTapscriptSignaturehash( //nolint:ll
hashCache, sigHashType,
successTx, 0, prevFetcher,
tapLeaf,
)
}
hashCache := input.NewTxSigHashesV0Only(successTx)
sigHash, err := txscript.CalcWitnessSigHash(
htlc.ourWitnessScript, hashCache,
sigHashType, successTx, 0,
htlcAmt,
)
if err != nil {
return nil, err
}
return sigHash, nil
}
// Make sure there are more signatures left.
if i >= len(htlcSigs) {
return nil, nil, fmt.Errorf("not enough HTLC " +
"signatures")
}
// If this is a taproot channel, then we'll convert it
// to a schnorr signature, so we get the correct type
// from ToSignature below.
if chanType.IsTaproot() {
htlcSigs[i].ForceSchnorr()
}
// With the sighash generated, we'll also store the
// signature so it can be written to disk if this state
// is valid.
sig, err = htlcSigs[i].ToSignature()
if err != nil {
return nil, nil, err
}
htlc.sig = sig
// Otherwise, if this is an outgoing HTLC, then we'll need to
// generate a timeout transaction so we can verify the
// signature presented.
//nolint:ll
case localCommitmentView.outgoingHTLCIndex[outputIndex] != nil:
htlc = localCommitmentView.outgoingHTLCIndex[outputIndex]
htlcIndex = htlc.HtlcIndex
sigHash = func() ([]byte, error) {
op := wire.OutPoint{
Hash: txHash,
Index: uint32(htlc.localOutputIndex),
}
htlcFee := HtlcTimeoutFee(chanType, feePerKw)
outputAmt := htlc.Amount.ToSatoshis() - htlcFee
auxLeaf := fn.FlatMapOption(func(
l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
idx := htlc.HtlcIndex
return leaves[idx].SecondLevelLeaf
})(auxResult.AuxLeaves)
timeoutTx, err := CreateHtlcTimeoutTx(
chanType, isLocalInitiator, op,
outputAmt, htlc.Timeout,
uint32(localChanCfg.CsvDelay),
leaseExpiry, keyRing.RevocationKey,
keyRing.ToLocalKey, auxLeaf,
)
if err != nil {
return nil, err
}
htlcAmt := int64(htlc.Amount.ToSatoshis())
if chanType.IsTaproot() {
// TODO(roasbeef): add abstraction in
// front
prevFetcher := txscript.NewCannedPrevOutputFetcher( //nolint:ll
htlc.ourPkScript, htlcAmt,
)
hashCache := txscript.NewTxSigHashes(
timeoutTx, prevFetcher,
)
tapLeaf := txscript.NewBaseTapLeaf(
htlc.ourWitnessScript,
)
return txscript.CalcTapscriptSignaturehash( //nolint:ll
hashCache, sigHashType,
timeoutTx, 0, prevFetcher,
tapLeaf,
)
}
hashCache := input.NewTxSigHashesV0Only(
timeoutTx,
)
sigHash, err := txscript.CalcWitnessSigHash(
htlc.ourWitnessScript, hashCache,
sigHashType, timeoutTx, 0,
htlcAmt,
)
if err != nil {
return nil, err
}
return sigHash, nil
}
// Make sure there are more signatures left.
if i >= len(htlcSigs) {
return nil, nil, fmt.Errorf("not enough HTLC " +
"signatures")
}
// If this is a taproot channel, then we'll convert it
// to a schnorr signature, so we get the correct type
// from ToSignature below.
if chanType.IsTaproot() {
htlcSigs[i].ForceSchnorr()
}
// With the sighash generated, we'll also store the
// signature so it can be written to disk if this state
// is valid.
sig, err = htlcSigs[i].ToSignature()
if err != nil {
return nil, nil, err
}
htlc.sig = sig
default:
continue
}
verifyJobs = append(verifyJobs, VerifyJob{
HtlcIndex: htlcIndex,
PubKey: keyRing.RemoteHtlcKey,
Sig: sig,
SigHash: sigHash,
})
if len(auxHtlcSigs) > i {
auxSig := auxHtlcSigs[i]
auxVerifyJob := NewAuxVerifyJob(
auxSig, *keyRing, incoming,
newAuxHtlcDescriptor(htlc),
localCommitmentView.customBlob, auxLeaf,
)
if htlc.CustomRecords == nil {
htlc.CustomRecords = make(lnwire.CustomRecords)
}
// As this HTLC has a custom signature associated with
// it, store it in the custom records map so we can
// write to disk later.
sigType := htlcCustomSigType.TypeVal()
htlc.CustomRecords[uint64(sigType)] = auxSig.UnwrapOr(
nil,
)
auxVerifyJobs = append(auxVerifyJobs, auxVerifyJob)
}
i++
}
// If we received a number of HTLC signatures that doesn't match our
// commitment, we'll return an error now.
if len(htlcSigs) != i {
return nil, nil, fmt.Errorf("number of htlc sig mismatch. "+
"Expected %v sigs, got %v", i, len(htlcSigs))
}
return verifyJobs, auxVerifyJobs, nil
}
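// The jobs produced above are consumed by handing them to a sig pool and
// draining the response channel, mirroring the usage in
// ReceiveNewCommitment below (sketch):
//
//	cancelChan := make(chan struct{})
//	resps := lc.sigPool.SubmitVerifyBatch(verifyJobs, cancelChan)
//	for range verifyJobs {
//		if htlcErr := <-resps; htlcErr != nil {
//			close(cancelChan) // abort outstanding jobs
//			// surface an InvalidHtlcSigError to the caller...
//		}
//	}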
// InvalidCommitSigError is a struct that implements the error interface to
// report a failure to validate a commitment signature for a remote peer.
// We'll use the items in this struct to generate a rich error message for
// the remote peer when we receive an invalid signature from it. Doing so
// can greatly aid in debugging cross implementation issues.
type InvalidCommitSigError struct {
commitHeight uint64
commitSig []byte
sigHash []byte
commitTx []byte
}
// Error returns a detailed error string including the exact transaction that
// caused an invalid commitment signature.
func (i *InvalidCommitSigError) Error() string {
return fmt.Sprintf("rejected commitment: commit_height=%v, "+
"invalid_commit_sig=%x, commit_tx=%x, sig_hash=%x", i.commitHeight,
i.commitSig[:], i.commitTx, i.sigHash[:])
}
// A compile time flag to ensure that InvalidCommitSigError implements the
// error interface.
var _ error = (*InvalidCommitSigError)(nil)
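// Callers that receive an error from commitment validation can recover the
// rich debugging payload with errors.As (sketch):
//
//	var sigErr *InvalidCommitSigError
//	if errors.As(err, &sigErr) {
//		// Relay commit_tx and sig_hash to the peer (or to the logs)
//		// to help track down cross implementation signing issues.
//	}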
// InvalidPartialCommitSigError is used when we encounter an invalid musig2
// partial signature.
type InvalidPartialCommitSigError struct {
InvalidCommitSigError
*invalidPartialSigError
}
// Error returns a detailed error string including the exact transaction
// that caused an invalid partial commitment signature.
func (i *InvalidPartialCommitSigError) Error() string {
return fmt.Sprintf("rejected commitment: commit_height=%v, "+
"commit_tx=%x -- %v", i.commitHeight, i.commitTx,
i.invalidPartialSigError)
}
// InvalidHtlcSigError is a struct that implements the error interface to
// report a failure to validate an htlc signature from a remote peer. We'll
// use the items in this struct to generate a rich error message for the
// remote peer when we receive an invalid signature from it. Doing so can
// greatly aid in debugging cross implementation issues.
type InvalidHtlcSigError struct {
commitHeight uint64
htlcSig []byte
htlcIndex uint64
sigHash []byte
commitTx []byte
}
// Error returns a detailed error string including the exact transaction that
// caused an invalid htlc signature.
func (i *InvalidHtlcSigError) Error() string {
return fmt.Sprintf("rejected commitment: commit_height=%v, "+
"invalid_htlc_sig=%x, commit_tx=%x, sig_hash=%x", i.commitHeight,
i.htlcSig, i.commitTx, i.sigHash[:])
}
// A compile time flag to ensure that InvalidHtlcSigError implements the
// error interface.
var _ error = (*InvalidHtlcSigError)(nil)
// ReceiveNewCommitment processes a signature for a new commitment state sent by
// the remote party. This method should be called in response to the
// remote party initiating a new change, or when the remote party sends a
// signature fully accepting a new state we've initiated. If we are able to
// successfully validate the signature, then the generated commitment is added
// to our local commitment chain. Once we send a revocation for our prior
// state, then this newly added commitment becomes our current accepted channel
// state.
//
//nolint:funlen
func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error {
lc.Lock()
defer lc.Unlock()
// Check for empty commit sig. Because of a previously existing bug, it
// is possible that we receive an empty commit sig from nodes running an
// older version. This is a relaxation of the spec, but it is still
// possible to handle it. To not break any channels with those older
// nodes, we just log the event. This check is also not totally
// reliable, because it could be that we've sent out a new sig, but the
// remote hasn't received it yet. We could then falsely assume that they
// should add our updates to their remote commitment tx.
if !lc.oweCommitment(lntypes.Remote) {
lc.log.Warnf("empty commit sig message received")
}
// Determine the last update on the local log that has been locked in.
localACKedIndex := lc.commitChains.Remote.tail().messageIndices.Local
localHtlcIndex := lc.commitChains.Remote.tail().ourHtlcIndex
// Ensure that this new local update from the remote node respects all
// the constraints we specified during initial channel setup. If not,
// then we'll abort the channel as they've violated our constraints.
//
// We do not enforce the FeeBuffer here because when we reach this
// point all updates will have to get locked-in (we already received
// the UpdateAddHTLC msg from our peer prior to receiving the
// commit-sig).
err := lc.validateCommitmentSanity(
lc.updateLogs.Remote.logIndex, localACKedIndex, lntypes.Local,
NoBuffer, nil, nil,
)
if err != nil {
return err
}
// We're receiving a new commitment which attempts to extend our local
// commitment chain height by one, so fetch the proper commitment point
// as this will be needed to derive the keys required to construct the
// commitment.
nextHeight := lc.currentHeight + 1
commitSecret, err := lc.channelState.RevocationProducer.AtIndex(nextHeight)
if err != nil {
return err
}
commitPoint := input.ComputeCommitmentPoint(commitSecret[:])
keyRing := DeriveCommitmentKeys(
commitPoint, lntypes.Local, lc.channelState.ChanType,
&lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg,
)
// With the current commitment point re-calculated, construct the new
// commitment view which includes all the entries (pending or committed)
// we know of in the remote node's HTLC log, but only our local changes
// up to the last change the remote node has ACK'd.
localCommitmentView, err := lc.fetchCommitmentView(
lntypes.Local, localACKedIndex, localHtlcIndex,
lc.updateLogs.Remote.logIndex, lc.updateLogs.Remote.htlcCounter,
keyRing,
)
if err != nil {
return err
}
lc.log.Tracef("extending local chain to height %v, "+
"local_log=%v, remote_log=%v",
localCommitmentView.height,
localACKedIndex, lc.updateLogs.Remote.logIndex)
lc.log.Tracef("local chain: our_balance=%v, "+
"their_balance=%v, commit_tx: %v",
localCommitmentView.ourBalance, localCommitmentView.theirBalance,
lnutils.SpewLogClosure(localCommitmentView.txn))
var auxSigBlob fn.Option[tlv.Blob]
if commitSigs.AuxSigBlob != nil {
auxSigBlob = fn.Some(commitSigs.AuxSigBlob)
}
// As an optimization, we'll generate a series of jobs for the worker
// pool to verify each of the HTLC signatures presented. Once
// generated, we'll submit these jobs to the worker pool.
var leaseExpiry uint32
if lc.channelState.ChanType.HasLeaseExpiration() {
leaseExpiry = lc.channelState.ThawHeight
}
verifyJobs, auxVerifyJobs, err := genHtlcSigValidationJobs(
lc.channelState, localCommitmentView, keyRing,
commitSigs.HtlcSigs, leaseExpiry, lc.leafStore, lc.auxSigner,
auxSigBlob,
)
if err != nil {
return err
}
cancelChan := make(chan struct{})
verifyResps := lc.sigPool.SubmitVerifyBatch(verifyJobs, cancelChan)
localCommitTx := localCommitmentView.txn
// While the HTLC verification jobs are proceeding asynchronously,
// we'll ensure that the newly constructed commitment state has a valid
// signature.
//
// To do that, we'll construct the sighash of the commitment
// transaction corresponding to this newly proposed state update. If
// this is a taproot channel, then in order to validate the sighash,
// we'll need to call into the relevant tapscript methods.
if lc.channelState.ChanType.IsTaproot() {
localSession := lc.musigSessions.LocalSession
partialSig, err := commitSigs.PartialSig.UnwrapOrErrV(
errNoPartialSig,
)
if err != nil {
return err
}
// As we want to ensure we never write nonces to disk, we'll
// use the shachain state to generate a nonce for our next
// local state. Similar to generateRevocation, we do height + 2
// (next height + 1) here, as this is for the _next_ local
// state, and we're about to accept height + 1.
localCtrNonce := WithLocalCounterNonce(
nextHeight+1, lc.taprootNonceProducer,
)
nextVerificationNonce, err := localSession.VerifyCommitSig(
localCommitTx, &partialSig, localCtrNonce,
)
if err != nil {
close(cancelChan)
var sigErr invalidPartialSigError
if errors.As(err, &sigErr) {
// If we fail to validate their commitment
// signature, we'll generate a special error to
// send over the protocol. We'll include the
// exact signature and commitment we failed to
// verify against in order to aid debugging.
var txBytes bytes.Buffer
_ = localCommitTx.Serialize(&txBytes)
return &InvalidPartialCommitSigError{
invalidPartialSigError: &sigErr,
InvalidCommitSigError: InvalidCommitSigError{ //nolint:ll
commitHeight: nextHeight,
commitTx: txBytes.Bytes(),
},
}
}
return err
}
// Now that we have the next verification nonce for our local
// session, we'll refresh it to yield a new session we'll use
// for the next incoming signature.
newLocalSession, err := lc.musigSessions.LocalSession.Refresh(
nextVerificationNonce,
)
if err != nil {
return err
}
lc.musigSessions.LocalSession = newLocalSession
} else {
multiSigScript := lc.signDesc.WitnessScript
prevFetcher := txscript.NewCannedPrevOutputFetcher(
multiSigScript, int64(lc.channelState.Capacity),
)
hashCache := txscript.NewTxSigHashes(localCommitTx, prevFetcher)
sigHash, err := txscript.CalcWitnessSigHash(
multiSigScript, hashCache, txscript.SigHashAll,
localCommitTx, 0, int64(lc.channelState.Capacity),
)
if err != nil {
// TODO(roasbeef): fetchview has already mutated the
// HTLCs... * need to either roll-back, or make pure
return err
}
verifyKey := lc.channelState.RemoteChanCfg.MultiSigKey.PubKey
cSig, err := commitSigs.CommitSig.ToSignature()
if err != nil {
return err
}
if !cSig.Verify(sigHash, verifyKey) {
close(cancelChan)
// If we fail to validate their commitment signature,
// we'll generate a special error to send over the
// protocol. We'll include the exact signature and
// commitment we failed to verify against in order to
// aid debugging.
var txBytes bytes.Buffer
_ = localCommitTx.Serialize(&txBytes)
return &InvalidCommitSigError{
commitHeight: nextHeight,
commitSig: commitSigs.CommitSig.ToSignatureBytes(), //nolint:ll
sigHash: sigHash,
commitTx: txBytes.Bytes(),
}
}
}
// With the primary commitment transaction validated, we'll check each
// of the HTLC validation jobs.
for i := 0; i < len(verifyJobs); i++ {
// In the case that a single signature is invalid, we'll exit
// early and cancel all the outstanding verification jobs.
htlcErr := <-verifyResps
if htlcErr != nil {
close(cancelChan)
sig, err := lnwire.NewSigFromSignature(
htlcErr.Sig,
)
if err != nil {
return err
}
sigHash, err := htlcErr.SigHash()
if err != nil {
return err
}
var txBytes bytes.Buffer
err = localCommitTx.Serialize(&txBytes)
if err != nil {
return err
}
return &InvalidHtlcSigError{
commitHeight: nextHeight,
htlcSig: sig.ToSignatureBytes(),
htlcIndex: htlcErr.HtlcIndex,
sigHash: sigHash,
commitTx: txBytes.Bytes(),
}
}
}
// Now that we know all the normal sigs are valid, we'll also verify
// the aux jobs, if any exist.
err = fn.MapOptionZ(lc.auxSigner, func(a AuxSigner) error {
return a.VerifySecondLevelSigs(
NewAuxChanState(lc.channelState), localCommitTx,
auxVerifyJobs,
)
})
if err != nil {
return fmt.Errorf("unable to validate aux sigs: %w", err)
}
// The signature checks out, so we can now add the new commitment to
// our local commitment chain. For regular channels, we can just
// serialize the ECDSA sig. For taproot channels, we'll serialize the
// partial sig that includes the nonce that was used for signing.
if lc.channelState.ChanType.IsTaproot() {
partialSig, err := commitSigs.PartialSig.UnwrapOrErrV(
errNoPartialSig,
)
if err != nil {
return err
}
var sigBytes [lnwire.PartialSigWithNonceLen]byte
b := bytes.NewBuffer(sigBytes[0:0])
if err := partialSig.Encode(b); err != nil {
return err
}
localCommitmentView.sig = sigBytes[:]
} else {
localCommitmentView.sig = commitSigs.CommitSig.ToSignatureBytes() //nolint:ll
}
lc.commitChains.Local.addCommitment(localCommitmentView)
return nil
}
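// For context, the receiving side of a full state update pairs this method
// with RevokeCurrentCommitment below (sketch of the link-level dance;
// error handling elided):
//
//	// remote -> us: CommitSig
//	_ = lc.ReceiveNewCommitment(commitSigs)
//	// us -> remote: RevokeAndAck for the prior state
//	rev, htlcs, finalHtlcs, _ := lc.RevokeCurrentCommitment()
//	_, _, _ = rev, htlcs, finalHtlcs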
// IsChannelClean returns true if neither side has pending commitments, neither
// side has HTLCs, and all updates are locked in irrevocably. Internally, it
// utilizes the oweCommitment function by calling it for local and remote
// evaluation. We check if we have a pending commitment for our local state
// since this function may be called by sub-systems that are not the link (e.g.
// the rpcserver), and the ReceiveNewCommitment & RevokeCurrentCommitment calls
// are not atomic, even though link processing ensures no updates can happen in
// between.
func (lc *LightningChannel) IsChannelClean() bool {
lc.RLock()
defer lc.RUnlock()
// Check whether we have a pending commitment for our local state.
if lc.commitChains.Local.hasUnackedCommitment() {
return false
}
// Check whether our counterparty has a pending commitment for their
// state.
if lc.commitChains.Remote.hasUnackedCommitment() {
return false
}
// We call ActiveHtlcs to ensure there are no HTLCs on either
// commitment.
if len(lc.channelState.ActiveHtlcs()) != 0 {
return false
}
// Now check that both local and remote commitments are signing the
// same updates.
if lc.oweCommitment(lntypes.Local) {
return false
}
if lc.oweCommitment(lntypes.Remote) {
return false
}
// If we reached this point, the channel has no HTLCs and both
// commitments sign the same updates.
return true
}
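// A subsystem outside the link (e.g. the rpcserver mentioned above) might
// use this as a precondition before an operation that requires a quiescent
// channel (sketch):
//
//	if !lc.IsChannelClean() {
//		return fmt.Errorf("channel still has pending updates")
//	}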
// OweCommitment returns a boolean value reflecting whether we need to send
// out a commitment signature because there are outstanding local updates and/or
// updates in the local commit tx that aren't reflected in the remote commit tx
// yet.
func (lc *LightningChannel) OweCommitment() bool {
lc.RLock()
defer lc.RUnlock()
return lc.oweCommitment(lntypes.Local)
}
// NeedCommitment returns a boolean value reflecting whether we are waiting on
// a commitment signature because there are outstanding remote updates and/or
// updates in the remote commit tx that aren't reflected in the local commit tx
// yet.
func (lc *LightningChannel) NeedCommitment() bool {
lc.RLock()
defer lc.RUnlock()
return lc.oweCommitment(lntypes.Remote)
}
// oweCommitment is the internal version of OweCommitment. This function expects
// to be executed with a lock held.
func (lc *LightningChannel) oweCommitment(issuer lntypes.ChannelParty) bool {
var (
remoteUpdatesPending, localUpdatesPending bool
lastLocalCommit = lc.commitChains.Local.tip()
lastRemoteCommit = lc.commitChains.Remote.tip()
perspective string
)
if issuer.IsLocal() {
perspective = "local"
// There are local updates pending if our local update log is
// not in sync with our remote commitment tx.
localUpdatesPending = lc.updateLogs.Local.logIndex !=
lastRemoteCommit.messageIndices.Local
// There are remote updates pending if their remote commitment
// tx (our local commitment tx) contains updates that we don't
// have added to our remote commitment tx yet.
remoteUpdatesPending = lastLocalCommit.messageIndices.Remote !=
lastRemoteCommit.messageIndices.Remote
} else {
perspective = "remote"
// There are local updates pending (local updates from the
// perspective of the remote party) if the remote party has
// updates to their remote tx pending for which they haven't
// signed yet.
localUpdatesPending = lc.updateLogs.Remote.logIndex !=
lastLocalCommit.messageIndices.Remote
// There are remote updates pending (remote updates from the
// perspective of the remote party) if we have updates on our
// remote commitment tx that they haven't added to theirs yet.
remoteUpdatesPending = lastRemoteCommit.messageIndices.Local !=
lastLocalCommit.messageIndices.Local
}
// If any of the conditions above is true, we owe a commitment
// signature.
oweCommitment := localUpdatesPending || remoteUpdatesPending
lc.log.Tracef("%v owes commit: %v (local updates: %v, "+
"remote updates %v)", perspective, oweCommitment,
localUpdatesPending, remoteUpdatesPending)
return oweCommitment
}
// NumPendingUpdates returns the number of updates originated by whoseUpdates
// that have not been committed to the *tip* of whoseCommit's commitment chain.
func (lc *LightningChannel) NumPendingUpdates(whoseUpdates lntypes.ChannelParty,
whoseCommit lntypes.ChannelParty) uint64 {
lc.RLock()
defer lc.RUnlock()
lastCommit := lc.commitChains.GetForParty(whoseCommit).tip()
updateIndex := lc.updateLogs.GetForParty(whoseUpdates).logIndex
return updateIndex - lastCommit.messageIndices.GetForParty(whoseUpdates)
}
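// For example, if our local update log is at logIndex 7 while the tip of
// the remote commitment chain has only committed our updates up to index
// 4, then NumPendingUpdates(lntypes.Local, lntypes.Remote) = 7 - 4 = 3.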
// RevokeCurrentCommitment revokes the next lowest unrevoked commitment
// transaction in the local commitment chain. As a result the edge of our
// revocation window is extended by one, and the tail of our local commitment
// chain is advanced by a single commitment. This now lowest unrevoked
// commitment becomes our currently accepted state within the channel. This
// method also returns the set of HTLCs currently active within the commitment
// transaction and the HTLCs that were resolved. This return value allows callers
// to act once an HTLC has been locked into our commitment transaction.
func (lc *LightningChannel) RevokeCurrentCommitment() (*lnwire.RevokeAndAck,
[]channeldb.HTLC, map[uint64]bool, error) {
lc.Lock()
defer lc.Unlock()
revocationMsg, err := lc.generateRevocation(lc.currentHeight)
if err != nil {
return nil, nil, nil, err
}
lc.log.Tracef("revoking height=%v, now at height=%v",
lc.commitChains.Local.tail().height,
lc.currentHeight+1)
// Advance our tail, as we've revoked our previous state.
lc.commitChains.Local.advanceTail()
lc.currentHeight++
// Additionally, generate a channel delta for this state transition for
// persistent storage.
chainTail := lc.commitChains.Local.tail()
newCommitment := chainTail.toDiskCommit(lntypes.Local)
// Get the unsigned acked remote updates that are currently in memory.
// We need them after a restart to sync our remote commitment with what
// is committed locally.
unsignedAckedUpdates := lc.getUnsignedAckedUpdates()
finalHtlcs, err := lc.channelState.UpdateCommitment(
newCommitment, unsignedAckedUpdates,
)
if err != nil {
return nil, nil, nil, err
}
lc.log.Tracef("state transition accepted: "+
"our_balance=%v, their_balance=%v, unsigned_acked_updates=%v",
chainTail.ourBalance,
chainTail.theirBalance,
len(unsignedAckedUpdates))
revocationMsg.ChanID = lnwire.NewChanIDFromOutPoint(
lc.channelState.FundingOutpoint,
)
return revocationMsg, newCommitment.Htlcs, finalHtlcs, nil
}
// ReceiveRevocation processes a revocation sent by the remote party for the
// lowest unrevoked commitment within their commitment chain. We receive a
// revocation either during the initial session negotiation wherein revocation
// windows are extended, or in response to a state update that we initiate. If
// successful, then the remote commitment chain is advanced by a single
// commitment, and a log compaction is attempted.
//
// The returned values correspond to:
// 1. The forwarding package corresponding to the remote commitment height
// that was revoked.
// 2. The set of HTLCs present on the current valid commitment transaction
// for the remote party.
func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
*channeldb.FwdPkg, []channeldb.HTLC, error) {
lc.Lock()
defer lc.Unlock()
// Ensure that the new pre-image can be placed in preimage store.
store := lc.channelState.RevocationStore
revocation, err := chainhash.NewHash(revMsg.Revocation[:])
if err != nil {
return nil, nil, err
}
if err := store.AddNextEntry(revocation); err != nil {
return nil, nil, err
}
// Verify that if we use the commitment point computed based off of the
// revealed secret to derive a revocation key with our revocation base
// point, then it matches the current revocation of the remote party.
currentCommitPoint := lc.channelState.RemoteCurrentRevocation
derivedCommitPoint := input.ComputeCommitmentPoint(revMsg.Revocation[:])
if !derivedCommitPoint.IsEqual(currentCommitPoint) {
return nil, nil, fmt.Errorf("revocation key mismatch")
}
// Now that we've verified that the prior commitment has been properly
// revoked, we'll advance the revocation state we track for the remote
// party: the new current revocation is what was previously the next
// revocation, and the new next revocation is set to the key included
// in the message.
lc.channelState.RemoteCurrentRevocation = lc.channelState.RemoteNextRevocation
lc.channelState.RemoteNextRevocation = revMsg.NextRevocationKey
lc.log.Tracef("remote party accepted state transition, revoked height "+
"%v, now at %v",
lc.commitChains.Remote.tail().height,
lc.commitChains.Remote.tail().height+1)
// Add one to the remote tail since this will be the height *after* we
// write the revocation to disk; the local height will remain unchanged.
remoteChainTail := lc.commitChains.Remote.tail().height + 1
localChainTail := lc.commitChains.Local.tail().height
source := lc.ShortChanID()
// Determine the set of htlcs that can be forwarded as a result of
// having received the revocation. We will simultaneously construct the
// log updates and payment descriptors, allowing us to persist the log
// updates to disk and optimistically buffer the forwarding package in
// memory.
var (
addUpdatesToForward []channeldb.LogUpdate
settleFailUpdatesToForward []channeldb.LogUpdate
)
var addIndex, settleFailIndex uint16
for e := lc.updateLogs.Remote.Front(); e != nil; e = e.Next() {
pd := e.Value
// Fee updates are local to this particular channel, and should
// never be forwarded.
if pd.EntryType == FeeUpdate {
continue
}
if pd.isForwarded {
continue
}
// For each type of HTLC, we will only consider forwarding it if
// both of the remote and local heights are non-zero. If either
// of these values is zero, it has yet to be committed in both
// the local and remote chains.
committedAdd := pd.addCommitHeights.Remote > 0 &&
pd.addCommitHeights.Local > 0
committedRmv := pd.removeCommitHeights.Remote > 0 &&
pd.removeCommitHeights.Local > 0
// Using the height of the remote and local commitments,
// preemptively compute whether or not to forward this HTLC for
// the case in which this is an Add HTLC, or if this is a
// Settle, Fail, or MalformedFail.
shouldFwdAdd := remoteChainTail == pd.addCommitHeights.Remote &&
localChainTail >= pd.addCommitHeights.Local
shouldFwdRmv := remoteChainTail ==
pd.removeCommitHeights.Remote &&
localChainTail >= pd.removeCommitHeights.Local
// We'll only forward any new HTLC additions iff they are "freshly
// locked in", meaning that the HTLC was only *just* considered
// locked-in at this new state. By doing this we ensure that we
// don't re-forward any already-processed HTLCs after a
// restart.
switch {
case pd.EntryType == Add && committedAdd && shouldFwdAdd:
// Construct a reference specifying the location that
// this forwarded Add will be written in the forwarding
// package constructed at this remote height.
pd.SourceRef = &channeldb.AddRef{
Height: remoteChainTail,
Index: addIndex,
}
addIndex++
pd.isForwarded = true
// At this point we put the update into our list of
// updates that we will eventually put into the
// FwdPkg at this height.
addUpdatesToForward = append(
addUpdatesToForward, pd.toLogUpdate(),
)
case pd.EntryType != Add && committedRmv && shouldFwdRmv:
// Construct a reference specifying the location that
// this forwarded Settle/Fail will be written in the
// forwarding package constructed at this remote height.
pd.DestRef = &channeldb.SettleFailRef{
Source: source,
Height: remoteChainTail,
Index: settleFailIndex,
}
settleFailIndex++
pd.isForwarded = true
// At this point we put the update into our list of
// updates that we will eventually put into the
// FwdPkg at this height.
settleFailUpdatesToForward = append(
settleFailUpdatesToForward, pd.toLogUpdate(),
)
default:
// The update was not "freshly locked in" so we will
// ignore it as we construct the forwarding package.
continue
}
}
// We use the remote commitment chain's tip as it will soon become the tail
// once advanceTail is called.
remoteMessageIndex := lc.commitChains.Remote.tip().messageIndices.Local
localMessageIndex := lc.commitChains.Local.tail().messageIndices.Local
localPeerUpdates := lc.unsignedLocalUpdates(
remoteMessageIndex, localMessageIndex,
)
// Now that we have gathered the set of HTLCs to forward, separated by
// type, construct a forwarding package using the height that the remote
// commitment chain will be extended after persisting the revocation.
fwdPkg := channeldb.NewFwdPkg(
source, remoteChainTail, addUpdatesToForward,
settleFailUpdatesToForward,
)
// We will soon be saving the current remote commitment to revocation
// log bucket, which is `lc.channelState.RemoteCommitment`. After that,
// the `RemoteCommitment` will be replaced with a newer version found
// in `CommitDiff`. Thus we need to compute the output indexes here
// before the change since the indexes are meant for the current,
// revoked remote commitment.
ourOutputIndex, theirOutputIndex, err := findOutputIndexesFromRemote(
revocation, lc.channelState, lc.leafStore,
)
if err != nil {
return nil, nil, err
}
// Now that we have a new verification nonce from them, we can refresh
// our remote musig2 session which allows us to create another state.
if lc.channelState.ChanType.IsTaproot() {
localNonce, err := revMsg.LocalNonce.UnwrapOrErrV(errNoNonce)
if err != nil {
return nil, nil, err
}
session, err := lc.musigSessions.RemoteSession.Refresh(
&musig2.Nonces{
PubNonce: localNonce,
},
)
if err != nil {
return nil, nil, err
}
lc.musigSessions.RemoteSession = session
}
// At this point, the revocation has been accepted, and we've rotated
// the current revocation key+hash for the remote party. Therefore we
// sync now to ensure the revocation producer state is consistent with
// the current commitment height and also to advance the on-disk
// commitment chain.
err = lc.channelState.AdvanceCommitChainTail(
fwdPkg, localPeerUpdates,
ourOutputIndex, theirOutputIndex,
)
if err != nil {
return nil, nil, err
}
// Since they revoked the current lowest height in their commitment
// chain, we can advance their chain by a single commitment.
lc.commitChains.Remote.advanceTail()
// As we've just completed a new state transition, attempt to see if we
// can remove any entries from the update log which have been removed
// from the PoV of both commitment chains.
compactLogs(
lc.updateLogs.Local, lc.updateLogs.Remote, localChainTail,
remoteChainTail,
)
remoteHTLCs := lc.channelState.RemoteCommitment.Htlcs
return fwdPkg, remoteHTLCs, nil
}
// LoadFwdPkgs loads any pending log updates from disk and returns the payment
// descriptors to be processed by the link.
func (lc *LightningChannel) LoadFwdPkgs() ([]*channeldb.FwdPkg, error) {
return lc.channelState.LoadFwdPkgs()
}
// AckAddHtlcs sets a bit in the FwdFilter of a forwarding package belonging to
// this channel, that corresponds to the given AddRef. This method also succeeds
// if no forwarding package is found.
func (lc *LightningChannel) AckAddHtlcs(addRef channeldb.AddRef) error {
return lc.channelState.AckAddHtlcs(addRef)
}
// AckSettleFails sets a bit in the SettleFailFilter of a forwarding package
// belonging to this channel, that corresponds to the given SettleFailRef. This
// method also succeeds if no forwarding package is found.
func (lc *LightningChannel) AckSettleFails(
settleFailRefs ...channeldb.SettleFailRef) error {
return lc.channelState.AckSettleFails(settleFailRefs...)
}
// SetFwdFilter writes the forwarding decision for a given remote commitment
// height.
func (lc *LightningChannel) SetFwdFilter(height uint64,
fwdFilter *channeldb.PkgFilter) error {
return lc.channelState.SetFwdFilter(height, fwdFilter)
}
// RemoveFwdPkgs permanently deletes the forwarding packages at the given heights.
func (lc *LightningChannel) RemoveFwdPkgs(heights ...uint64) error {
return lc.channelState.RemoveFwdPkgs(heights...)
}
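// Taken together, the helpers above cover the forwarding package
// lifecycle. A link might drive them roughly as follows (sketch; the loop
// variables and refs are illustrative):
//
//	fwdPkgs, _ := lc.LoadFwdPkgs() // restore after restart
//	for _, pkg := range fwdPkgs {
//		// Decide which Adds to forward, persist the decision:
//		_ = lc.SetFwdFilter(pkg.Height, pkg.FwdFilter)
//		// As HTLCs are handed off or resolved:
//		// _ = lc.AckAddHtlcs(addRef)
//		// _ = lc.AckSettleFails(settleFailRefs...)
//		// Once fully processed, garbage collect:
//		_ = lc.RemoveFwdPkgs(pkg.Height)
//	}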
// NextRevocationKey returns the commitment point for the _next_ commitment
// height. The pubkey returned by this function is required by the remote party
// along with their revocation base to extend our commitment chain with a
// new commitment.
func (lc *LightningChannel) NextRevocationKey() (*btcec.PublicKey, error) {
lc.RLock()
defer lc.RUnlock()
nextHeight := lc.currentHeight + 1
revocation, err := lc.channelState.RevocationProducer.AtIndex(nextHeight)
if err != nil {
return nil, err
}
return input.ComputeCommitmentPoint(revocation[:]), nil
}
// InitNextRevocation inserts the passed commitment point as the _next_
// revocation to be used when creating a new commitment state for the remote
// party. This function MUST be called before the channel can accept or propose
// any new states.
func (lc *LightningChannel) InitNextRevocation(revKey *btcec.PublicKey) error {
lc.Lock()
defer lc.Unlock()
return lc.channelState.InsertNextRevocation(revKey)
}
// AddHTLC is a wrapper of the `addHTLC` function which always enforces the
// FeeBuffer on the local balance if being the initiator of the channel. This
// method should be called when preparing to send an outgoing HTLC.
//
// The additional openKey argument corresponds to the incoming CircuitKey of the
// committed circuit for this HTLC. This value should never be nil.
//
// NOTE: It is okay for sourceRef to be nil when unit testing the wallet.
func (lc *LightningChannel) AddHTLC(htlc *lnwire.UpdateAddHTLC,
openKey *models.CircuitKey) (uint64, error) {
return lc.addHTLC(htlc, openKey, FeeBuffer)
}
// addHTLC adds an HTLC to the state machine's local update log. It provides
// the ability to enforce a buffer on the local balance when we are the
// initiator of the channel. This is useful when checking the edge cases of a
// channel state e.g. the BOLT 03 test vectors.
//
// The additional openKey argument corresponds to the incoming CircuitKey of the
// committed circuit for this HTLC. This value should never be nil.
//
// NOTE: It is okay for sourceRef to be nil when unit testing the wallet.
func (lc *LightningChannel) addHTLC(htlc *lnwire.UpdateAddHTLC,
openKey *models.CircuitKey, buffer BufferType) (uint64, error) {
lc.Lock()
defer lc.Unlock()
pd := lc.htlcAddDescriptor(htlc, openKey)
if err := lc.validateAddHtlc(pd, buffer); err != nil {
return 0, err
}
lc.updateLogs.Local.appendHtlc(pd)
return pd.HtlcIndex, nil
}
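// A typical outgoing add from the switch looks like the following (sketch;
// the circuit key fields shown are illustrative):
//
//	openKey := &models.CircuitKey{
//		ChanID: incomingChanID,
//		HtlcID: incomingHtlcID,
//	}
//	idx, err := lc.AddHTLC(updateAddHTLC, openKey)
//	if err != nil {
//		// E.g. balance or fee-buffer constraints were violated.
//	}
//	_ = idx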
// GetDustSum takes in a boolean that determines which commitment to evaluate
// the dust sum on. The return value is the sum of dust on the desired
// commitment tx.
//
// NOTE: This over-estimates the dust exposure.
func (lc *LightningChannel) GetDustSum(whoseCommit lntypes.ChannelParty,
dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi {
lc.RLock()
defer lc.RUnlock()
var dustSum lnwire.MilliSatoshi
dustLimit := lc.channelState.LocalChanCfg.DustLimit
commit := lc.channelState.LocalCommitment
if whoseCommit.IsRemote() {
// Calculate dust sum on the remote's commitment.
dustLimit = lc.channelState.RemoteChanCfg.DustLimit
commit = lc.channelState.RemoteCommitment
}
chanType := lc.channelState.ChanType
feeRate := chainfee.SatPerKWeight(commit.FeePerKw)
// Optionally use the dry-run fee-rate.
feeRate = dryRunFee.UnwrapOr(feeRate)
// Grab all of our HTLCs and evaluate against the dust limit.
for e := lc.updateLogs.Local.Front(); e != nil; e = e.Next() {
pd := e.Value
if pd.EntryType != Add {
continue
}
amt := pd.Amount.ToSatoshis()
// If the satoshi amount is under the dust limit, add the msat
// amount to the dust sum.
if HtlcIsDust(
chanType, false, whoseCommit, feeRate, amt, dustLimit,
) {
dustSum += pd.Amount
}
}
// Grab all of their HTLCs and evaluate against the dust limit.
for e := lc.updateLogs.Remote.Front(); e != nil; e = e.Next() {
pd := e.Value
if pd.EntryType != Add {
continue
}
amt := pd.Amount.ToSatoshis()
// If the satoshi amount is under the dust limit, add the msat
// amount to the dust sum.
if HtlcIsDust(
chanType, true, whoseCommit, feeRate,
amt, dustLimit,
) {
dustSum += pd.Amount
}
}
return dustSum
}
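// For instance, to probe dust exposure on the remote commitment at a
// hypothetical higher fee rate without touching channel state (sketch):
//
//	dust := lc.GetDustSum(
//		lntypes.Remote, fn.Some(chainfee.SatPerKWeight(5000)),
//	)
//	_ = dust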
// MayAddOutgoingHtlc validates whether we can add an outgoing htlc to this
// channel. We don't have a circuit for this htlc because we just want to
// test that we have slots for a potential htlc, so we use a "mock" htlc to
// validate a potential commitment state with one more outgoing htlc. If a
// zero htlc amount is provided, we'll attempt to add the smallest possible
// htlc to the channel (either the minimum htlc, or 1 sat).
func (lc *LightningChannel) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error {
lc.Lock()
defer lc.Unlock()
var mockHtlcAmt lnwire.MilliSatoshi
switch {
// If the caller specifically set an amount, we use it.
case amt != 0:
mockHtlcAmt = amt
// In the absence of a specific amount, we want to use the minimum htlc
// value for the channel. However, certain implementations may set this
// value to zero, so we only use this value if it is non-zero.
case lc.channelState.LocalChanCfg.MinHTLC != 0:
mockHtlcAmt = lc.channelState.LocalChanCfg.MinHTLC
// As a last resort, we just add a non-zero amount.
default:
mockHtlcAmt++
}
// Create a "mock" outgoing htlc, using the smallest amount we can add
// to the commitment so that we validate commitment slots rather than
// available balance, since our actual htlc amount is unknown at this
// stage.
pd := lc.htlcAddDescriptor(
&lnwire.UpdateAddHTLC{
Amount: mockHtlcAmt,
},
&models.CircuitKey{},
)
// Enforce the FeeBuffer because we are evaluating whether we can add
// another htlc to the channel state.
if err := lc.validateAddHtlc(pd, FeeBuffer); err != nil {
lc.log.Debugf("May add outgoing htlc rejected: %v", err)
return err
}
return nil
}
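// For example, passing a zero amount probes purely for a free commitment
// slot using the smallest HTLC the channel permits (sketch):
//
//	if err := lc.MayAddOutgoingHtlc(0); err != nil {
//		// No room for another outgoing HTLC right now.
//	}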
// htlcAddDescriptor returns a payment descriptor for the htlc and open key
// provided to add to our local update log.
func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC,
openKey *models.CircuitKey) *paymentDescriptor {
return &paymentDescriptor{
ChanID: htlc.ChanID,
EntryType: Add,
RHash: PaymentHash(htlc.PaymentHash),
Timeout: htlc.Expiry,
Amount: htlc.Amount,
LogIndex: lc.updateLogs.Local.logIndex,
HtlcIndex: lc.updateLogs.Local.htlcCounter,
OnionBlob: htlc.OnionBlob,
OpenCircuitKey: openKey,
BlindingPoint: htlc.BlindingPoint,
CustomRecords: htlc.CustomRecords.Copy(),
}
}
// validateAddHtlc validates the addition of an outgoing htlc to our local and
// remote commitments.
func (lc *LightningChannel) validateAddHtlc(pd *paymentDescriptor,
buffer BufferType) error {
// Make sure adding this HTLC won't violate any of the constraints we
// must keep on the commitment transactions.
remoteACKedIndex := lc.commitChains.Local.tail().messageIndices.Remote
// First we'll check whether this HTLC can be added to the remote
// commitment transaction without violating any of the constraints.
err := lc.validateCommitmentSanity(
remoteACKedIndex, lc.updateLogs.Local.logIndex, lntypes.Remote,
buffer, pd, nil,
)
if err != nil {
return err
}
// We must also check whether it can be added to our own commitment
// transaction, or the remote node will refuse to sign. This is not
// totally bullet proof, as the remote might be adding updates
// concurrently, but if we fail this check it is for sure not
// possible for us to add the HTLC.
err = lc.validateCommitmentSanity(
lc.updateLogs.Remote.logIndex, lc.updateLogs.Local.logIndex,
lntypes.Local, buffer, pd, nil,
)
if err != nil {
return err
}
return nil
}
// ReceiveHTLC adds an HTLC to the state machine's remote update log. This
// method should be called in response to receiving a new HTLC from the remote
// party.
func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64,
error) {
lc.Lock()
defer lc.Unlock()
if htlc.ID != lc.updateLogs.Remote.htlcCounter {
return 0, fmt.Errorf("ID %d on HTLC add does not match "+
"expected next ID %d", htlc.ID,
lc.updateLogs.Remote.htlcCounter)
}
pd := &paymentDescriptor{
ChanID: htlc.ChanID,
EntryType: Add,
RHash: PaymentHash(htlc.PaymentHash),
Timeout: htlc.Expiry,
Amount: htlc.Amount,
LogIndex: lc.updateLogs.Remote.logIndex,
HtlcIndex: lc.updateLogs.Remote.htlcCounter,
OnionBlob: htlc.OnionBlob,
BlindingPoint: htlc.BlindingPoint,
CustomRecords: htlc.CustomRecords.Copy(),
}
localACKedIndex := lc.commitChains.Remote.tail().messageIndices.Local
// Clamp down on the number of HTLCs we can receive by checking the
// commitment sanity.
// We do not enforce the FeeBuffer here because it was introduced to
// protect against the asynchronous sending of htlcs, which has already
// happened at this point. The current lightning protocol does not allow
// us to reject ADDs already sent by the peer.
err := lc.validateCommitmentSanity(
lc.updateLogs.Remote.logIndex, localACKedIndex, lntypes.Local,
NoBuffer, nil, pd,
)
if err != nil {
return 0, err
}
lc.updateLogs.Remote.appendHtlc(pd)
return pd.HtlcIndex, nil
}
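// For example, if we've already received three adds from the peer (so
// htlcCounter == 3), the next UpdateAddHTLC must carry ID 3; an add
// arriving with ID 5 is rejected outright rather than queued.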
// SettleHTLC attempts to settle an existing outstanding received HTLC. The
// remote log index of the HTLC settled is returned in order to facilitate
// creating the corresponding wire message. In the case the supplied preimage
// is invalid, an error is returned.
//
// The additional arguments correspond to:
//
// - sourceRef: specifies the location of the Add HTLC within a forwarding
// package that this HTLC is settling. Every Settle settles exactly one Add,
// so this should never be empty in practice.
//
// - destRef: specifies the location of the Settle HTLC within another
// channel's forwarding package. This value can be nil if the corresponding
// Add HTLC was never locked into an outgoing commitment txn, or this
// HTLC does not originate as a response from the peer on the outgoing
// link, e.g. on-chain resolutions.
//
// - closeKey: identifies the circuit that should be deleted after this Settle
// HTLC is included in a commitment txn. This value should only be nil if
// the HTLC was settled locally before committing a circuit to the circuit
// map.
//
// NOTE: It is okay for sourceRef, destRef, and closeKey to be nil when unit
// testing the wallet.
func (lc *LightningChannel) SettleHTLC(preimage [32]byte,
htlcIndex uint64, sourceRef *channeldb.AddRef,
destRef *channeldb.SettleFailRef, closeKey *models.CircuitKey) error {
lc.Lock()
defer lc.Unlock()
htlc := lc.updateLogs.Remote.lookupHtlc(htlcIndex)
if htlc == nil {
return ErrUnknownHtlcIndex{lc.ShortChanID(), htlcIndex}
}
// Now that we know the HTLC exists, before checking to see if the
// preimage matches, we'll ensure that we haven't already attempted to
// modify the HTLC.
if lc.updateLogs.Remote.htlcHasModification(htlcIndex) {
return ErrHtlcIndexAlreadySettled(htlcIndex)
}
if htlc.RHash != sha256.Sum256(preimage[:]) {
return ErrInvalidSettlePreimage{preimage[:], htlc.RHash[:]}
}
pd := &paymentDescriptor{
ChanID: lc.ChannelID(),
Amount: htlc.Amount,
RPreimage: preimage,
LogIndex: lc.updateLogs.Local.logIndex,
ParentIndex: htlcIndex,
EntryType: Settle,
SourceRef: sourceRef,
DestRef: destRef,
ClosedCircuitKey: closeKey,
}
lc.updateLogs.Local.appendUpdate(pd)
// With the settle added to our local log, we'll now mark the HTLC as
// modified to prevent ourselves from accidentally attempting a
// duplicate settle.
lc.updateLogs.Remote.markHtlcModified(htlcIndex)
return nil
}
// ReceiveHTLCSettle attempts to settle an existing outgoing HTLC indexed by
// an index into the local log. If the specified index doesn't exist within
// the log, an error is returned. Similarly, if the preimage is invalid
// w.r.t. the referenced HTLC, a distinct error is returned.
func (lc *LightningChannel) ReceiveHTLCSettle(preimage [32]byte, htlcIndex uint64) error {
lc.Lock()
defer lc.Unlock()
htlc := lc.updateLogs.Local.lookupHtlc(htlcIndex)
if htlc == nil {
return ErrUnknownHtlcIndex{lc.ShortChanID(), htlcIndex}
}
// Now that we know the HTLC exists, before checking to see if the
// preimage matches, we'll ensure that they haven't already attempted
// to modify the HTLC.
if lc.updateLogs.Local.htlcHasModification(htlcIndex) {
return ErrHtlcIndexAlreadySettled(htlcIndex)
}
if htlc.RHash != sha256.Sum256(preimage[:]) {
return ErrInvalidSettlePreimage{preimage[:], htlc.RHash[:]}
}
pd := &paymentDescriptor{
ChanID: lc.ChannelID(),
Amount: htlc.Amount,
RPreimage: preimage,
ParentIndex: htlc.HtlcIndex,
RHash: htlc.RHash,
LogIndex: lc.updateLogs.Remote.logIndex,
EntryType: Settle,
}
lc.updateLogs.Remote.appendUpdate(pd)
// With the settle added to the remote log, we'll now mark the HTLC as
// modified to prevent the remote party from accidentally attempting a
// duplicate settle.
lc.updateLogs.Local.markHtlcModified(htlcIndex)
return nil
}
// FailHTLC attempts to fail a targeted HTLC by its HTLC index, inserting an
// entry which will remove the target log entry within the next commitment
// update. This method is intended to be called in order to cancel an
// _incoming_ HTLC.
//
// The additional arguments correspond to:
//
// - sourceRef: specifies the location of the Add HTLC within a forwarding
// package that this HTLC is failing. Every Fail fails exactly one Add, so
// this should never be empty in practice.
//
// - destRef: specifies the location of the Fail HTLC within another channel's
// forwarding package. This value can be nil if the corresponding Add HTLC
// was never locked into an outgoing commitment txn, or this HTLC does not
// originate as a response from the peer on the outgoing link, e.g.
// on-chain resolutions.
//
// - closeKey: identifies the circuit that should be deleted after this Fail
// HTLC is included in a commitment txn. This value should only be nil if
// the HTLC was failed locally before committing a circuit to the circuit
// map.
//
// NOTE: It is okay for sourceRef, destRef, and closeKey to be nil when unit
// testing the wallet.
func (lc *LightningChannel) FailHTLC(htlcIndex uint64, reason []byte,
sourceRef *channeldb.AddRef, destRef *channeldb.SettleFailRef,
closeKey *models.CircuitKey) error {
lc.Lock()
defer lc.Unlock()
htlc := lc.updateLogs.Remote.lookupHtlc(htlcIndex)
if htlc == nil {
return ErrUnknownHtlcIndex{lc.ShortChanID(), htlcIndex}
}
// Now that we know the HTLC exists, we'll ensure that we haven't
// already attempted to fail the HTLC.
if lc.updateLogs.Remote.htlcHasModification(htlcIndex) {
return ErrHtlcIndexAlreadyFailed(htlcIndex)
}
pd := &paymentDescriptor{
ChanID: lc.ChannelID(),
Amount: htlc.Amount,
RHash: htlc.RHash,
ParentIndex: htlcIndex,
LogIndex: lc.updateLogs.Local.logIndex,
EntryType: Fail,
FailReason: reason,
SourceRef: sourceRef,
DestRef: destRef,
ClosedCircuitKey: closeKey,
}
lc.updateLogs.Local.appendUpdate(pd)
// With the fail added to our local log, we'll now mark the HTLC in the
// remote log as modified to prevent ourselves from accidentally
// attempting a duplicate fail.
lc.updateLogs.Remote.markHtlcModified(htlcIndex)
return nil
}
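// Example (hypothetical sketch): failing an incoming HTLC, e.g. after an
// invoice lookup fails. The reason blob is assumed to be an already
// obfuscated onion failure produced elsewhere; lnChan, htlcIndex, and the
// reference arguments are placeholders for caller-supplied state.
//
//	reason := obfuscatedFailureReason // opaque, peer-bound failure blob
//	if err := lnChan.FailHTLC(
//		htlcIndex, reason, sourceRef, destRef, closeKey,
//	); err != nil {
//		return fmt.Errorf("unable to fail htlc %d: %w",
//			htlcIndex, err)
//	}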
// MalformedFailHTLC attempts to fail a targeted HTLC by its payment hash,
// inserting an entry which will remove the target log entry within the next
// commitment update. This method is intended to be called in order to cancel
// an _incoming_ HTLC.
//
// The additional sourceRef specifies the location of the Add HTLC within a
// forwarding package that this HTLC is failing. This value should never be
// empty.
//
// NOTE: It is okay for sourceRef to be nil when unit testing the wallet.
func (lc *LightningChannel) MalformedFailHTLC(htlcIndex uint64,
failCode lnwire.FailCode, shaOnionBlob [sha256.Size]byte,
sourceRef *channeldb.AddRef) error {
lc.Lock()
defer lc.Unlock()
htlc := lc.updateLogs.Remote.lookupHtlc(htlcIndex)
if htlc == nil {
return ErrUnknownHtlcIndex{lc.ShortChanID(), htlcIndex}
}
// Now that we know the HTLC exists, we'll ensure that we haven't
// already attempted to fail the HTLC.
if lc.updateLogs.Remote.htlcHasModification(htlcIndex) {
return ErrHtlcIndexAlreadyFailed(htlcIndex)
}
pd := &paymentDescriptor{
ChanID: lc.ChannelID(),
Amount: htlc.Amount,
RHash: htlc.RHash,
ParentIndex: htlcIndex,
LogIndex: lc.updateLogs.Local.logIndex,
EntryType: MalformedFail,
FailCode: failCode,
ShaOnionBlob: shaOnionBlob,
SourceRef: sourceRef,
}
lc.updateLogs.Local.appendUpdate(pd)
// With the fail added to our local log, we'll now mark the HTLC in the
// remote log as modified to prevent ourselves from accidentally
// attempting a duplicate fail.
lc.updateLogs.Remote.markHtlcModified(htlcIndex)
return nil
}
// ReceiveFailHTLC attempts to cancel a targeted HTLC by its log index,
// inserting an entry which will remove the target log entry within the next
// commitment update. This method should be called in response to the upstream
// party cancelling an outgoing HTLC.
func (lc *LightningChannel) ReceiveFailHTLC(htlcIndex uint64, reason []byte,
) error {
lc.Lock()
defer lc.Unlock()
htlc := lc.updateLogs.Local.lookupHtlc(htlcIndex)
if htlc == nil {
return ErrUnknownHtlcIndex{lc.ShortChanID(), htlcIndex}
}
// Now that we know the HTLC exists, we'll ensure that they haven't
// already attempted to fail the HTLC.
if lc.updateLogs.Local.htlcHasModification(htlcIndex) {
return ErrHtlcIndexAlreadyFailed(htlcIndex)
}
pd := &paymentDescriptor{
ChanID: lc.ChannelID(),
Amount: htlc.Amount,
RHash: htlc.RHash,
ParentIndex: htlc.HtlcIndex,
LogIndex: lc.updateLogs.Remote.logIndex,
EntryType: Fail,
FailReason: reason,
}
lc.updateLogs.Remote.appendUpdate(pd)
// With the fail added to the remote log, we'll now mark the HTLC as
// modified to prevent the remote party from accidentally attempting a
// duplicate fail.
lc.updateLogs.Local.markHtlcModified(htlcIndex)
return nil
}
// ChannelPoint returns the outpoint of the original funding transaction which
// created this active channel. This outpoint is used throughout various
// subsystems to uniquely identify an open channel.
func (lc *LightningChannel) ChannelPoint() wire.OutPoint {
return lc.channelState.FundingOutpoint
}
// ChannelID returns the ChannelID of this LightningChannel. This is the same
// ChannelID that is used in update messages for this channel.
func (lc *LightningChannel) ChannelID() lnwire.ChannelID {
return lnwire.NewChanIDFromOutPoint(lc.ChannelPoint())
}
// ShortChanID returns the short channel ID for the channel. The short channel
// ID encodes the exact location in the main chain that the original
// funding output can be found.
func (lc *LightningChannel) ShortChanID() lnwire.ShortChannelID {
return lc.channelState.ShortChanID()
}
// LocalUpfrontShutdownScript returns the local upfront shutdown script for the
// channel. If it was not set, an empty byte array is returned.
func (lc *LightningChannel) LocalUpfrontShutdownScript() lnwire.DeliveryAddress {
return lc.channelState.LocalShutdownScript
}
// RemoteUpfrontShutdownScript returns the remote upfront shutdown script for the
// channel. If it was not set, an empty byte array is returned.
func (lc *LightningChannel) RemoteUpfrontShutdownScript() lnwire.DeliveryAddress {
return lc.channelState.RemoteShutdownScript
}
// AbsoluteThawHeight determines a frozen channel's absolute thaw height. If
// the channel is not frozen, then 0 is returned.
//
// An error is returned if the channel is pending, or is an unconfirmed zero
// conf channel.
func (lc *LightningChannel) AbsoluteThawHeight() (uint32, error) {
return lc.channelState.AbsoluteThawHeight()
}
// SignedCommitTxInputs contains data needed to create a signed commit
// transaction using a signer. See GetSignedCommitTx.
type SignedCommitTxInputs struct {
// CommitTx is the latest version of the commitment state, broadcastable
// by us.
CommitTx *wire.MsgTx
// CommitSig is one half of the signature required to fully complete
// the script for the commitment transaction above. This is the
// signature signed by the remote party for our version of the
// commitment transactions.
CommitSig []byte
// OurKey is our key to be used within the 2-of-2 output script
// for the owner of this channel.
OurKey keychain.KeyDescriptor
// TheirKey is their key to be used within the 2-of-2 output script
// for the owner of this channel.
TheirKey keychain.KeyDescriptor
// SignDesc is the primary sign descriptor that is capable of signing
// the commitment transaction that spends the multi-sig output.
SignDesc *input.SignDescriptor
// Taproot holds fields needed in case of a taproot channel.
// Iff the channel is of taproot type, this field is filled.
Taproot fn.Option[TaprootSignedCommitTxInputs]
}
// TaprootSignedCommitTxInputs contains additional data needed to create a
// signed commit transaction using a signer, used in case of a taproot channel.
// See GetSignedCommitTx.
type TaprootSignedCommitTxInputs struct {
// CommitHeight is the update number that this channel state represents.
// It is the total number of commitment updates up to this point. This
// can be viewed as sort of a "commitment height" as this number is
// monotonically increasing. This number is used to make a signature
// for a taproot channel, since it is used by shachain nonce producer
// (TaprootNonceProducer).
CommitHeight uint64
// TaprootNonceProducer is used to generate a shachain tree for the
// purpose of generating verification nonces for taproot channels.
TaprootNonceProducer shachain.Producer
// TapscriptRoot is the root of the tapscript tree that will be used to
// create the funding output. This is an optional field that should
// only be set for taproot channels.
TapscriptRoot fn.Option[chainhash.Hash]
}
// GetSignedCommitTx creates the witness stack of a channel commitment
// transaction. It can handle all commitment types (taproot, legacy). It is
// exported to give outside tooling the possibility to recreate the witness.
// A key use case is generating the witness data for a commitment transaction
// from a Static Channel Backup (SCB).
func GetSignedCommitTx(inputs SignedCommitTxInputs,
signer input.Signer) (*wire.MsgTx, error) {
commitTx := inputs.CommitTx.Copy()
var witness wire.TxWitness
switch {
// If this is a taproot channel, then we'll need to re-derive the nonce
// we need to generate a new signature.
case inputs.Taproot.IsSome():
// Extract Taproot from fn.Option. It is safe to call
// UnsafeFromSome because we just checked that it is some.
taproot := inputs.Taproot.UnsafeFromSome()
// First, we'll need to re-derive the local nonce we sent to
// the remote party to create this musig session. We pass in
// the same height here as we're generating the nonce needed
// for the _current_ state.
localNonce, err := channeldb.NewMusigVerificationNonce(
inputs.OurKey.PubKey, taproot.CommitHeight,
taproot.TaprootNonceProducer,
)
if err != nil {
return nil, fmt.Errorf("unable to re-derive "+
"verification nonce: %w", err)
}
tapscriptTweak := fn.MapOption(TapscriptRootToTweak)(
taproot.TapscriptRoot,
)
// Now that we have the local nonce, we'll re-create the musig
// session we had for this height.
musigSession := NewPartialMusigSession(
*localNonce, inputs.OurKey, inputs.TheirKey, signer,
inputs.SignDesc.Output, LocalMusigCommit,
tapscriptTweak,
)
var remoteSig lnwire.PartialSigWithNonce
err = remoteSig.Decode(
bytes.NewReader(inputs.CommitSig),
)
if err != nil {
return nil, fmt.Errorf("unable to decode remote "+
"partial sig: %w", err)
}
// Next, we'll manually finalize the session with the signing
// nonce we got from the remote party which is embedded in the
// signature we have.
err = musigSession.FinalizeSession(musig2.Nonces{
PubNonce: remoteSig.Nonce,
})
if err != nil {
return nil, fmt.Errorf("unable to finalize musig "+
"session: %w", err)
}
// Now that the session has been finalized, we can generate our
// half of the signature for the state. We don't capture the
// sig as it's stored within the session.
if _, err := musigSession.SignCommit(commitTx); err != nil {
return nil, fmt.Errorf("unable to sign musig2 "+
"commitment: %w", err)
}
// The final step is now to combine this signature we generated
// above, with the remote party's signature. We only need to
// pass the remote sig, as the local sig was already cached in
// the session.
var partialSig MusigPartialSig
partialSig.FromWireSig(&remoteSig)
finalSig, err := musigSession.CombineSigs(partialSig.sig)
if err != nil {
return nil, fmt.Errorf("unable to combine musig "+
"partial sigs: %w", err)
}
// The witness is the single keyspend schnorr sig.
witness = wire.TxWitness{
finalSig.Serialize(),
}
// Otherwise, the final witness we generate will be a normal p2wsh
// multi-sig spend.
default:
theirSig, err := ecdsa.ParseDERSignature(inputs.CommitSig)
if err != nil {
return nil, err
}
// With this, we then generate the full witness so the caller
// can broadcast a fully signed transaction.
inputs.SignDesc.SigHashes = input.NewTxSigHashesV0Only(commitTx)
ourSig, err := signer.SignOutputRaw(commitTx, inputs.SignDesc)
if err != nil {
return nil, err
}
// With the final signature generated, create the witness stack
// required to spend from the multi-sig output.
witness = input.SpendMultiSig(
inputs.SignDesc.WitnessScript,
inputs.OurKey.PubKey.SerializeCompressed(), ourSig,
inputs.TheirKey.PubKey.SerializeCompressed(), theirSig,
)
}
commitTx.TxIn[0].Witness = witness
return commitTx, nil
}
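// Example (hypothetical sketch): recreating a broadcastable commitment
// transaction from externally stored state, such as during Static Channel
// Backup (SCB) tooling. All inputs shown (storedCommitTx, storedCommitSig,
// ourKey, theirKey, signDesc, signer) are placeholders for values recovered
// from the backup; the Taproot option is left empty for a legacy channel.
//
//	inputs := SignedCommitTxInputs{
//		CommitTx:  storedCommitTx,
//		CommitSig: storedCommitSig,
//		OurKey:    ourKey,
//		TheirKey:  theirKey,
//		SignDesc:  signDesc,
//	}
//	signedTx, err := GetSignedCommitTx(inputs, signer)
//	if err != nil {
//		return fmt.Errorf("unable to sign commit tx: %w", err)
//	}
//	// signedTx now carries a complete witness and can be broadcast.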
// getSignedCommitTx takes the latest commitment transaction and
// populates it with witness data.
func (lc *LightningChannel) getSignedCommitTx() (*wire.MsgTx, error) {
// Fetch the current commitment transaction, along with their signature
// for the transaction.
localCommit := lc.channelState.LocalCommitment
inputs := SignedCommitTxInputs{
CommitTx: localCommit.CommitTx,
CommitSig: localCommit.CommitSig,
OurKey: lc.channelState.LocalChanCfg.MultiSigKey,
TheirKey: lc.channelState.RemoteChanCfg.MultiSigKey,
SignDesc: lc.signDesc,
}
if lc.channelState.ChanType.IsTaproot() {
inputs.Taproot = fn.Some(TaprootSignedCommitTxInputs{
CommitHeight: lc.currentHeight,
TaprootNonceProducer: lc.taprootNonceProducer,
TapscriptRoot: lc.channelState.TapscriptRoot,
})
}
return GetSignedCommitTx(inputs, lc.Signer)
}
// CommitOutputResolution carries the necessary information required to allow
// us to sweep our commitment output in the case that either party goes to
// chain.
type CommitOutputResolution struct {
// SelfOutPoint is the full outpoint that points to our pay-to-self
// output within the closing commitment transaction.
SelfOutPoint wire.OutPoint
// SelfOutputSignDesc is a fully populated sign descriptor capable of
// generating a valid signature to sweep the output paying to us.
SelfOutputSignDesc input.SignDescriptor
// MaturityDelay is the relative time-lock, in blocks for all outputs
// that pay to the local party within the broadcast commitment
// transaction.
MaturityDelay uint32
// ResolutionBlob is a blob used for aux channels that permits a
// spender of the output to properly resolve it in the case of a force
// close.
ResolutionBlob fn.Option[tlv.Blob]
}
// UnilateralCloseSummary describes the details of a detected unilateral
// channel closure. This includes information about the transaction and block
// in which the channel was unilaterally closed, as well as summarization
// details concerning the _state_ of the channel at the point of channel
// closure. Additionally, if we had a commitment output above dust on the
// remote party's commitment transaction, a SignDescriptor with the material
// necessary to sweep the output is returned. Finally, if we had any outgoing
// HTLC's within the commitment transaction, then an OutgoingHtlcResolution
// for each output will be included.
type UnilateralCloseSummary struct {
// SpendDetail is a struct that describes how and when the funding
// output was spent.
*chainntnfs.SpendDetail
// ChannelCloseSummary is a struct describing the final state of the
// channel and in which state it was closed.
channeldb.ChannelCloseSummary
// CommitResolution contains all the data required to sweep the output
// to ourselves. If this is our commitment transaction, then we'll need
// to wait a time delay before we can sweep the output.
//
// NOTE: If our commitment delivery output is below the dust limit,
// then this will be nil.
CommitResolution *CommitOutputResolution
// HtlcResolutions contains a fully populated HtlcResolutions struct
// which contains all the data required to sweep any outgoing HTLC's,
// and also any incoming HTLC's that we know the pre-image to.
HtlcResolutions *HtlcResolutions
// RemoteCommit is the exact commitment state that the remote party
// broadcast.
RemoteCommit channeldb.ChannelCommitment
// AnchorResolution contains the data required to sweep our anchor
// output. If the channel type doesn't include anchors, the value of
// this field will be nil.
AnchorResolution *AnchorResolution
}
// NewUnilateralCloseSummary creates a new summary that provides the caller
// with all the information required to claim all funds on chain in the event
// that the remote party broadcasts their commitment. The commitPoint argument
// should be set to the per_commitment_point corresponding to the spending
// commitment.
//
// NOTE: The remoteCommit argument should be set to the stored commitment for
// this particular state. If we don't have the commitment stored (should only
// happen in case we have lost state) it should be set to an empty struct, in
// which case we will attempt to sweep the non-HTLC output using the passed
// commitPoint.
func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, //nolint:funlen
signer input.Signer, commitSpend *chainntnfs.SpendDetail,
remoteCommit channeldb.ChannelCommitment, commitPoint *btcec.PublicKey,
leafStore fn.Option[AuxLeafStore],
auxResolver fn.Option[AuxContractResolver]) (*UnilateralCloseSummary,
error) {
// First, we'll generate the commitment point and the revocation point
// so we can re-construct the HTLC state and also our payment key.
commitType := lntypes.Remote
keyRing := DeriveCommitmentKeys(
commitPoint, commitType, chanState.ChanType,
&chanState.LocalChanCfg, &chanState.RemoteChanCfg,
)
auxResult, err := fn.MapOptionZ(
leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] {
return s.FetchLeavesFromCommit(
NewAuxChanState(chanState), remoteCommit,
*keyRing, lntypes.Remote,
)
},
).Unpack()
if err != nil {
return nil, fmt.Errorf("unable to fetch aux leaves: %w", err)
}
// Next, we'll obtain HTLC resolutions for all the outgoing HTLC's we
// had on their commitment transaction.
var (
leaseExpiry uint32
selfPoint *wire.OutPoint
localBalance int64
isRemoteInitiator = !chanState.IsInitiator
commitTxBroadcast = commitSpend.SpendingTx
)
if chanState.ChanType.HasLeaseExpiration() {
leaseExpiry = chanState.ThawHeight
}
htlcResolutions, err := extractHtlcResolutions(
chainfee.SatPerKWeight(remoteCommit.FeePerKw), commitType,
signer, remoteCommit.Htlcs, keyRing, &chanState.LocalChanCfg,
&chanState.RemoteChanCfg, commitSpend.SpendingTx,
chanState.ChanType, isRemoteInitiator, leaseExpiry, chanState,
auxResult.AuxLeaves, auxResolver,
)
if err != nil {
return nil, fmt.Errorf("unable to create htlc resolutions: %w",
err)
}
// Before we can generate the proper sign descriptor, we'll need to
// locate the output index of our non-delayed output on the commitment
// transaction.
remoteAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
},
)(auxResult.AuxLeaves)
selfScript, maturityDelay, err := CommitScriptToRemote(
chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey,
leaseExpiry, remoteAuxLeaf,
)
if err != nil {
return nil, fmt.Errorf("unable to create self commit "+
"script: %w", err)
}
for outputIndex, txOut := range commitTxBroadcast.TxOut {
if bytes.Equal(txOut.PkScript, selfScript.PkScript()) {
selfPoint = &wire.OutPoint{
Hash: *commitSpend.SpenderTxHash,
Index: uint32(outputIndex),
}
localBalance = txOut.Value
break
}
}
// With the HTLC's taken care of, we'll generate the sign descriptor
// necessary to sweep our commitment output, but only if we had a
// non-trimmed balance.
var commitResolution *CommitOutputResolution
if selfPoint != nil {
localPayBase := chanState.LocalChanCfg.PaymentBasePoint
// As the remote party has force closed, we just need the
// success witness script.
witnessScript, err := selfScript.WitnessScriptForPath(
input.ScriptPathSuccess,
)
if err != nil {
return nil, err
}
commitResolution = &CommitOutputResolution{
SelfOutPoint: *selfPoint,
SelfOutputSignDesc: input.SignDescriptor{
KeyDesc: localPayBase,
SingleTweak: keyRing.LocalCommitKeyTweak,
WitnessScript: witnessScript,
Output: &wire.TxOut{
Value: localBalance,
PkScript: selfScript.PkScript(),
},
HashType: sweepSigHash(chanState.ChanType),
},
MaturityDelay: maturityDelay,
}
// For taproot channels, we'll need to set some additional
// fields to ensure the output can be swept.
//
//nolint:ll
if scriptTree, ok := selfScript.(input.TapscriptDescriptor); ok {
commitResolution.SelfOutputSignDesc.SignMethod =
input.TaprootScriptSpendSignMethod
ctrlBlock, err := scriptTree.CtrlBlockForPath(
input.ScriptPathSuccess,
)
if err != nil {
return nil, err
}
//nolint:ll
commitResolution.SelfOutputSignDesc.ControlBlock, err = ctrlBlock.ToBytes()
if err != nil {
return nil, err
}
}
// At this point, we'll check to see if we need any extra
// resolution data for this output.
resolveReq := ResolutionReq{
ChanPoint: chanState.FundingOutpoint,
ChanType: chanState.ChanType,
ShortChanID: chanState.ShortChanID(),
Initiator: chanState.IsInitiator,
CommitBlob: chanState.RemoteCommitment.CustomBlob,
FundingBlob: chanState.CustomBlob,
Type: input.TaprootRemoteCommitSpend,
CloseType: RemoteForceClose,
CommitTx: commitTxBroadcast,
ContractPoint: *selfPoint,
SignDesc: commitResolution.SelfOutputSignDesc,
KeyRing: keyRing,
CsvDelay: maturityDelay,
CommitFee: chanState.RemoteCommitment.CommitFee,
}
resolveBlob := fn.MapOptionZ(
auxResolver,
func(a AuxContractResolver) fn.Result[tlv.Blob] {
return a.ResolveContract(resolveReq)
},
)
if err := resolveBlob.Err(); err != nil {
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
commitResolution.ResolutionBlob = resolveBlob.OkToSome()
}
closeSummary := channeldb.ChannelCloseSummary{
ChanPoint: chanState.FundingOutpoint,
ChainHash: chanState.ChainHash,
ClosingTXID: *commitSpend.SpenderTxHash,
CloseHeight: uint32(commitSpend.SpendingHeight),
RemotePub: chanState.IdentityPub,
Capacity: chanState.Capacity,
SettledBalance: btcutil.Amount(localBalance),
CloseType: channeldb.RemoteForceClose,
IsPending: true,
RemoteCurrentRevocation: chanState.RemoteCurrentRevocation,
RemoteNextRevocation: chanState.RemoteNextRevocation,
ShortChanID: chanState.ShortChanID(),
LocalChanConfig: chanState.LocalChanCfg,
}
// Attempt to add a channel sync message to the close summary.
chanSync, err := chanState.ChanSyncMsg()
if err != nil {
walletLog.Errorf("ChannelPoint(%v): unable to create channel sync "+
"message: %v", chanState.FundingOutpoint, err)
} else {
closeSummary.LastChanSyncMsg = chanSync
}
anchorResolution, err := NewAnchorResolution(
chanState, commitTxBroadcast, keyRing, lntypes.Remote,
)
if err != nil {
return nil, err
}
return &UnilateralCloseSummary{
SpendDetail: commitSpend,
ChannelCloseSummary: closeSummary,
CommitResolution: commitResolution,
HtlcResolutions: htlcResolutions,
RemoteCommit: remoteCommit,
AnchorResolution: anchorResolution,
}, nil
}
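// Example (hypothetical sketch): building a close summary once a spend of
// the funding output has been detected and matched to the remote party's
// commitment. chanState, signer, spendDetail, remoteCommit, and commitPoint
// are placeholders assumed to be recovered by the chain watcher; the aux
// options are left empty for a plain channel.
//
//	summary, err := NewUnilateralCloseSummary(
//		chanState, signer, spendDetail, remoteCommit, commitPoint,
//		fn.None[AuxLeafStore](), fn.None[AuxContractResolver](),
//	)
//	if err != nil {
//		return fmt.Errorf("unable to create close summary: %w", err)
//	}
//	if summary.CommitResolution != nil {
//		// Hand the commit output resolution to the sweeper.
//	}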
// IncomingHtlcResolution houses the information required to sweep any incoming
// HTLC's that we know the preimage to. We'll need to sweep an HTLC manually
// using this struct if we need to go on-chain for any reason, or if we detect
// that the remote party broadcasts their commitment transaction.
type IncomingHtlcResolution struct {
// Preimage is the preimage that will be used to satisfy the contract of
// the HTLC.
//
// NOTE: This field will only be populated in the incoming contest
// resolver.
Preimage [32]byte
// SignedSuccessTx is the fully signed HTLC success transaction. This
// transaction (if non-nil) can be broadcast immediately. After the CSV
// delay (included below) has passed, the output created by this
// transaction can be swept on-chain.
//
// NOTE: If this field is nil, then this indicates that we don't need
// to go to the second level to claim this HTLC. Instead, it can be
// claimed directly from the outpoint listed below.
SignedSuccessTx *wire.MsgTx
// SignDetails is non-nil if SignedSuccessTx is non-nil, and the
// channel is of the anchor type. As the above HTLC transaction will be
// signed by the channel peer using SINGLE|ANYONECANPAY for such
// channels, we can use the sign details to add the input-output pair
// of the HTLC transaction to another transaction, thereby aggregating
// multiple HTLC transactions together, and adding fees as needed.
SignDetails *input.SignDetails
// CsvDelay is the relative time lock (expressed in blocks) that must
// pass after the SignedSuccessTx is confirmed in the chain before the
// output can be swept.
//
// NOTE: If SignedSuccessTx is nil, then this field denotes the CSV
// delay needed to spend from the commitment transaction.
CsvDelay uint32
// ClaimOutpoint is the final outpoint that needs to be spent in order
// to fully sweep the HTLC. The SignDescriptor below should be used to
// spend this outpoint. In the case of a second-level HTLC (non-nil
// SignedSuccessTx), then we'll be spending a new transaction.
// Otherwise, it'll be an output in the commitment transaction.
ClaimOutpoint wire.OutPoint
// SweepSignDesc is a sign descriptor that has been populated with the
// necessary items required to spend the sole output of the above
// transaction.
SweepSignDesc input.SignDescriptor
// ResolutionBlob is a blob used for aux channels that permits a
// spender of the output to properly resolve it in the case of a force
// close.
ResolutionBlob fn.Option[tlv.Blob]
}
// OutgoingHtlcResolution houses the information necessary to sweep any
// outgoing HTLC's after their contract has expired. This struct will be needed
// in one of two cases: the local party force closes the commitment transaction
// or the remote party unilaterally closes with their version of the commitment
// transaction.
type OutgoingHtlcResolution struct {
// Expiry is the absolute timeout of the HTLC. This value is expressed in
// block height, meaning after this height the HTLC can be swept.
Expiry uint32
// SignedTimeoutTx is the fully signed HTLC timeout transaction. This
// must be broadcast immediately after timeout has passed. Once this
// has been confirmed, the HTLC output will transition into the
// delay+claim state.
//
// NOTE: If this field is nil, then this indicates that we don't need
// to go to the second level to claim this HTLC. Instead, it can be
// claimed directly from the outpoint listed below.
SignedTimeoutTx *wire.MsgTx
// SignDetails is non-nil if SignedTimeoutTx is non-nil, and the
// channel is of the anchor type. As the above HTLC transaction will be
// signed by the channel peer using SINGLE|ANYONECANPAY for such
// channels, we can use the sign details to add the input-output pair
// of the HTLC transaction to another transaction, thereby aggregating
// multiple HTLC transactions together, and adding fees as needed.
SignDetails *input.SignDetails
// CsvDelay is the relative time lock (expressed in blocks) that must
// pass after the SignedTimeoutTx is confirmed in the chain before the
// output can be swept.
//
// NOTE: If SignedTimeoutTx is nil, then this field denotes the CSV
// delay needed to spend from the commitment transaction.
CsvDelay uint32
// ClaimOutpoint is the final outpoint that needs to be spent in order
// to fully sweep the HTLC. The SignDescriptor below should be used to
// spend this outpoint. In the case of a second-level HTLC (non-nil
// SignedTimeoutTx), then we'll be spending a new transaction.
// Otherwise, it'll be an output in the commitment transaction.
ClaimOutpoint wire.OutPoint
// SweepSignDesc is a sign descriptor that has been populated with the
// necessary items required to spend the sole output of the above
// transaction.
SweepSignDesc input.SignDescriptor
// ResolutionBlob is a blob used for aux channels that permits a
// spender of the output to properly resolve it in the case of a force
// close.
ResolutionBlob fn.Option[tlv.Blob]
}
// HtlcResolutions contains the items necessary to sweep HTLC's on chain
// directly from a commitment transaction. We'll use this in case either party
// broadcasts a commitment transaction with live HTLC's.
type HtlcResolutions struct {
// IncomingHTLCs contains a set of structs that can be used to sweep
// all the incoming HTLC's that we know the preimage to.
IncomingHTLCs []IncomingHtlcResolution
// OutgoingHTLCs contains a set of structs that contains all the info
// needed to sweep an outgoing HTLC we've sent to the remote party
// after an absolute delay has expired.
OutgoingHTLCs []OutgoingHtlcResolution
}
// newOutgoingHtlcResolution generates a new HTLC resolution capable of
// allowing the caller to sweep an outgoing HTLC present on either our own, or
// the remote party's commitment transaction.
func newOutgoingHtlcResolution(signer input.Signer,
localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx,
htlc *channeldb.HTLC, keyRing *CommitmentKeyRing,
feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32,
whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool,
chanType channeldb.ChannelType, chanState *channeldb.OpenChannel,
auxLeaves fn.Option[CommitAuxLeaves],
auxResolver fn.Option[AuxContractResolver],
) (*OutgoingHtlcResolution, error) {
op := wire.OutPoint{
Hash: commitTx.TxHash(),
Index: uint32(htlc.OutputIndex),
}
// First, we'll re-generate the script used to send the HTLC to the
// remote party within their commitment transaction.
auxLeaf := fn.FlatMapOption(func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.OutgoingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf
})(auxLeaves)
htlcScriptInfo, err := genHtlcScript(
chanType, false, whoseCommit, htlc.RefundTimeout, htlc.RHash,
keyRing, auxLeaf,
)
if err != nil {
return nil, err
}
htlcPkScript := htlcScriptInfo.PkScript()
// As this is an outgoing HTLC, we just care about the timeout path
// here.
scriptPath := input.ScriptPathTimeout
htlcWitnessScript, err := htlcScriptInfo.WitnessScriptForPath(
scriptPath,
)
if err != nil {
return nil, err
}
htlcCsvDelay := HtlcSecondLevelInputSequence(chanType)
// If we're spending this HTLC output from the remote node's
// commitment, then we won't need to go to the second level as our
// outputs don't have a CSV delay.
if whoseCommit.IsRemote() {
// With the script generated, we can completely populate the
// SignDescriptor needed to sweep the output.
prevFetcher := txscript.NewCannedPrevOutputFetcher(
htlcPkScript, int64(htlc.Amt.ToSatoshis()),
)
signDesc := input.SignDescriptor{
KeyDesc: localChanCfg.HtlcBasePoint,
SingleTweak: keyRing.LocalHtlcKeyTweak,
WitnessScript: htlcWitnessScript,
Output: &wire.TxOut{
PkScript: htlcPkScript,
Value: int64(htlc.Amt.ToSatoshis()),
},
HashType: sweepSigHash(chanType),
PrevOutputFetcher: prevFetcher,
}
scriptTree, ok := htlcScriptInfo.(input.TapscriptDescriptor)
if ok {
signDesc.SignMethod = input.TaprootScriptSpendSignMethod
ctrlBlock, err := scriptTree.CtrlBlockForPath(
scriptPath,
)
if err != nil {
return nil, err
}
signDesc.ControlBlock, err = ctrlBlock.ToBytes()
if err != nil {
return nil, err
}
}
resReq := ResolutionReq{
ChanPoint: chanState.FundingOutpoint,
ChanType: chanType,
ShortChanID: chanState.ShortChanID(),
Initiator: chanState.IsInitiator,
CommitBlob: chanState.RemoteCommitment.CustomBlob,
FundingBlob: chanState.CustomBlob,
Type: input.TaprootHtlcOfferedRemoteTimeout,
CloseType: RemoteForceClose,
CommitTx: commitTx,
ContractPoint: op,
SignDesc: signDesc,
KeyRing: keyRing,
CsvDelay: htlcCsvDelay,
CltvDelay: fn.Some(htlc.RefundTimeout),
CommitFee: chanState.RemoteCommitment.CommitFee,
HtlcID: fn.Some(htlc.HtlcIndex),
PayHash: fn.Some(htlc.RHash),
}
resolveRes := fn.MapOptionZ(
auxResolver,
func(a AuxContractResolver) fn.Result[tlv.Blob] {
return a.ResolveContract(resReq)
},
)
if err := resolveRes.Err(); err != nil {
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
resolutionBlob := resolveRes.OkToSome()
return &OutgoingHtlcResolution{
Expiry: htlc.RefundTimeout,
ClaimOutpoint: op,
SweepSignDesc: signDesc,
CsvDelay: csvDelay,
ResolutionBlob: resolutionBlob,
}, nil
}
// Otherwise, we'll need to craft a second level HTLC transaction, as
// well as a sign desc to sweep after the CSV delay.
// In order to properly reconstruct the HTLC transaction, we'll need to
// re-calculate the fee required at this state, so we can add the
// correct output value amount to the transaction.
htlcFee := HtlcTimeoutFee(chanType, feePerKw)
secondLevelOutputAmt := htlc.Amt.ToSatoshis() - htlcFee
// With the fee calculated, re-construct the second level timeout
// transaction.
secondLevelAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.OutgoingHtlcLeaves
return leaves[htlc.HtlcIndex].SecondLevelLeaf
},
)(auxLeaves)
timeoutTx, err := CreateHtlcTimeoutTx(
chanType, isCommitFromInitiator, op, secondLevelOutputAmt,
htlc.RefundTimeout, csvDelay, leaseExpiry,
keyRing.RevocationKey, keyRing.ToLocalKey, secondLevelAuxLeaf,
)
if err != nil {
return nil, err
}
// With the transaction created, we can generate a sign descriptor
// that's capable of generating the signature required to spend the
// HTLC output using the timeout transaction.
txOut := commitTx.TxOut[htlc.OutputIndex]
prevFetcher := txscript.NewCannedPrevOutputFetcher(
txOut.PkScript, txOut.Value,
)
hashCache := txscript.NewTxSigHashes(timeoutTx, prevFetcher)
timeoutSignDesc := input.SignDescriptor{
KeyDesc: localChanCfg.HtlcBasePoint,
SingleTweak: keyRing.LocalHtlcKeyTweak,
WitnessScript: htlcWitnessScript,
Output: txOut,
HashType: sweepSigHash(chanType),
PrevOutputFetcher: prevFetcher,
SigHashes: hashCache,
InputIndex: 0,
}
htlcSig, err := input.ParseSignature(htlc.Signature)
if err != nil {
return nil, err
}
// With the sign desc created, we can now construct the full witness
// for the timeout transaction, and populate it as well.
sigHashType := HtlcSigHashType(chanType)
var timeoutWitness wire.TxWitness
if scriptTree, ok := htlcScriptInfo.(input.TapscriptDescriptor); ok {
timeoutSignDesc.SignMethod = input.TaprootScriptSpendSignMethod
timeoutWitness, err = input.SenderHTLCScriptTaprootTimeout(
htlcSig, sigHashType, signer, &timeoutSignDesc,
timeoutTx, keyRing.RevocationKey,
scriptTree.TapScriptTree(),
)
if err != nil {
return nil, err
}
// The control block is always the final element of the witness
// stack. We set this here as eventually the sweeper will need
// to re-sign, so it needs the isolated control block.
//
// TODO(roasbeef): move this into input.go?
ctlrBlkIdx := len(timeoutWitness) - 1
timeoutSignDesc.ControlBlock = timeoutWitness[ctlrBlkIdx]
} else {
timeoutWitness, err = input.SenderHtlcSpendTimeout(
htlcSig, sigHashType, signer, &timeoutSignDesc,
timeoutTx,
)
}
if err != nil {
return nil, err
}
timeoutTx.TxIn[0].Witness = timeoutWitness
// If this is an anchor type channel, the sign details will let us
// re-sign an aggregated tx later.
txSignDetails := HtlcSignDetails(
chanType, timeoutSignDesc, sigHashType, htlcSig,
)
// Finally, we'll generate the script output that the timeout
// transaction creates so we can generate the signDesc required to
// complete the claim process after a delay period.
var (
htlcSweepScript input.ScriptDescriptor
signMethod input.SignMethod
ctrlBlock []byte
)
if !chanType.IsTaproot() {
htlcSweepScript, err = SecondLevelHtlcScript(
chanType, isCommitFromInitiator, keyRing.RevocationKey,
keyRing.ToLocalKey, csvDelay, leaseExpiry,
secondLevelAuxLeaf,
)
if err != nil {
return nil, err
}
} else {
//nolint:ll
secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree(
keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay,
secondLevelAuxLeaf,
)
if err != nil {
return nil, err
}
signMethod = input.TaprootScriptSpendSignMethod
controlBlock, err := secondLevelScriptTree.CtrlBlockForPath(
input.ScriptPathSuccess,
)
if err != nil {
return nil, err
}
ctrlBlock, err = controlBlock.ToBytes()
if err != nil {
return nil, err
}
htlcSweepScript = secondLevelScriptTree
}
// In this case, the witness script that needs to be signed will always
// be that of the success path.
htlcSweepWitnessScript, err := htlcSweepScript.WitnessScriptForPath(
input.ScriptPathSuccess,
)
if err != nil {
return nil, err
}
localDelayTweak := input.SingleTweakBytes(
keyRing.CommitPoint, localChanCfg.DelayBasePoint.PubKey,
)
// In addition to the info in txSignDetails, we also need extra
// information to sweep the second level output after confirmation.
sweepSignDesc := input.SignDescriptor{
KeyDesc: localChanCfg.DelayBasePoint,
SingleTweak: localDelayTweak,
WitnessScript: htlcSweepWitnessScript,
Output: &wire.TxOut{
PkScript: htlcSweepScript.PkScript(),
Value: int64(secondLevelOutputAmt),
},
HashType: sweepSigHash(chanType),
PrevOutputFetcher: txscript.NewCannedPrevOutputFetcher(
htlcSweepScript.PkScript(),
int64(secondLevelOutputAmt),
),
SignMethod: signMethod,
ControlBlock: ctrlBlock,
}
// This might be an aux channel, so we'll go ahead and attempt to
// generate the resolution blob for the channel so we can pass along to
// the sweeping sub-system.
resolveRes := fn.MapOptionZ(
auxResolver, func(a AuxContractResolver) fn.Result[tlv.Blob] {
resReq := ResolutionReq{
ChanPoint: chanState.FundingOutpoint,
ChanType: chanType,
ShortChanID: chanState.ShortChanID(),
Initiator: chanState.IsInitiator,
CommitBlob: chanState.LocalCommitment.CustomBlob, //nolint:ll
FundingBlob: chanState.CustomBlob,
Type: input.TaprootHtlcLocalOfferedTimeout, //nolint:ll
CloseType: LocalForceClose,
CommitTx: commitTx,
ContractPoint: op,
SignDesc: sweepSignDesc,
KeyRing: keyRing,
CsvDelay: htlcCsvDelay,
HtlcAmt: btcutil.Amount(txOut.Value),
CommitCsvDelay: csvDelay,
CltvDelay: fn.Some(htlc.RefundTimeout),
CommitFee: chanState.LocalCommitment.CommitFee, //nolint:ll
HtlcID: fn.Some(htlc.HtlcIndex),
PayHash: fn.Some(htlc.RHash),
AuxSigDesc: fn.Some(AuxSigDesc{
SignDetails: *txSignDetails,
AuxSig: func() []byte {
tlvType := htlcCustomSigType.TypeVal() //nolint:ll
return htlc.CustomRecords[uint64(tlvType)] //nolint:ll
}(),
}),
}
return a.ResolveContract(resReq)
},
)
if err := resolveRes.Err(); err != nil {
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
resolutionBlob := resolveRes.OkToSome()
return &OutgoingHtlcResolution{
Expiry: htlc.RefundTimeout,
SignedTimeoutTx: timeoutTx,
SignDetails: txSignDetails,
CsvDelay: csvDelay,
ResolutionBlob: resolutionBlob,
ClaimOutpoint: wire.OutPoint{
Hash: timeoutTx.TxHash(),
Index: 0,
},
SweepSignDesc: sweepSignDesc,
}, nil
}
// newIncomingHtlcResolution creates a new HTLC resolution capable of allowing
// the caller to sweep an incoming HTLC. If the HTLC is on the caller's
// commitment transaction, then they'll need to broadcast a second-level
// transaction before sweeping the output (and incur a CSV delay). Otherwise,
// they can just sweep the output immediately with knowledge of the pre-image.
//
// TODO(roasbeef) consolidate code with above func
func newIncomingHtlcResolution(signer input.Signer,
localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx,
htlc *channeldb.HTLC, keyRing *CommitmentKeyRing,
feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32,
whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool,
chanType channeldb.ChannelType, chanState *channeldb.OpenChannel,
auxLeaves fn.Option[CommitAuxLeaves],
auxResolver fn.Option[AuxContractResolver],
) (*IncomingHtlcResolution, error) {
op := wire.OutPoint{
Hash: commitTx.TxHash(),
Index: uint32(htlc.OutputIndex),
}
// First, we'll re-generate the script the remote party used to
// send the HTLC to us in their commitment transaction.
auxLeaf := fn.FlatMapOption(func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.IncomingHtlcLeaves[htlc.HtlcIndex].AuxTapLeaf
})(auxLeaves)
scriptInfo, err := genHtlcScript(
chanType, true, whoseCommit, htlc.RefundTimeout, htlc.RHash,
keyRing, auxLeaf,
)
if err != nil {
return nil, err
}
htlcPkScript := scriptInfo.PkScript()
// As this is an incoming HTLC, we're attempting to sweep with the
// success path.
scriptPath := input.ScriptPathSuccess
htlcWitnessScript, err := scriptInfo.WitnessScriptForPath(
scriptPath,
)
if err != nil {
return nil, err
}
htlcCsvDelay := HtlcSecondLevelInputSequence(chanType)
// If we're spending this output from the remote node's commitment,
// then we can skip the second layer and spend the output directly.
if whoseCommit.IsRemote() {
// With the script generated, we can completely populate the
// SignDescriptor needed to sweep the output.
prevFetcher := txscript.NewCannedPrevOutputFetcher(
htlcPkScript, int64(htlc.Amt.ToSatoshis()),
)
signDesc := input.SignDescriptor{
KeyDesc: localChanCfg.HtlcBasePoint,
SingleTweak: keyRing.LocalHtlcKeyTweak,
WitnessScript: htlcWitnessScript,
Output: &wire.TxOut{
PkScript: htlcPkScript,
Value: int64(htlc.Amt.ToSatoshis()),
},
HashType: sweepSigHash(chanType),
PrevOutputFetcher: prevFetcher,
}
//nolint:ll
if scriptTree, ok := scriptInfo.(input.TapscriptDescriptor); ok {
signDesc.SignMethod = input.TaprootScriptSpendSignMethod
ctrlBlock, err := scriptTree.CtrlBlockForPath(
scriptPath,
)
if err != nil {
return nil, err
}
signDesc.ControlBlock, err = ctrlBlock.ToBytes()
if err != nil {
return nil, err
}
}
resReq := ResolutionReq{
ChanPoint: chanState.FundingOutpoint,
ChanType: chanType,
ShortChanID: chanState.ShortChanID(),
Initiator: chanState.IsInitiator,
CommitBlob: chanState.RemoteCommitment.CustomBlob,
Type: input.TaprootHtlcAcceptedRemoteSuccess,
FundingBlob: chanState.CustomBlob,
CloseType: RemoteForceClose,
CommitTx: commitTx,
ContractPoint: op,
SignDesc: signDesc,
KeyRing: keyRing,
HtlcID: fn.Some(htlc.HtlcIndex),
CsvDelay: htlcCsvDelay,
CltvDelay: fn.Some(htlc.RefundTimeout),
CommitFee: chanState.RemoteCommitment.CommitFee,
PayHash: fn.Some(htlc.RHash),
CommitCsvDelay: csvDelay,
HtlcAmt: htlc.Amt.ToSatoshis(),
}
resolveRes := fn.MapOptionZ(
auxResolver,
func(a AuxContractResolver) fn.Result[tlv.Blob] {
return a.ResolveContract(resReq)
},
)
if err := resolveRes.Err(); err != nil {
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
resolutionBlob := resolveRes.OkToSome()
return &IncomingHtlcResolution{
ClaimOutpoint: op,
SweepSignDesc: signDesc,
CsvDelay: htlcCsvDelay,
ResolutionBlob: resolutionBlob,
}, nil
}
secondLevelAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
leaves := l.IncomingHtlcLeaves
return leaves[htlc.HtlcIndex].SecondLevelLeaf
},
)(auxLeaves)
// Otherwise, we'll need to go to the second level to sweep this HTLC.
//
// First, we'll reconstruct the original HTLC success transaction,
// taking into account the fee rate used.
htlcFee := HtlcSuccessFee(chanType, feePerKw)
secondLevelOutputAmt := htlc.Amt.ToSatoshis() - htlcFee
successTx, err := CreateHtlcSuccessTx(
chanType, isCommitFromInitiator, op, secondLevelOutputAmt,
csvDelay, leaseExpiry, keyRing.RevocationKey,
keyRing.ToLocalKey, secondLevelAuxLeaf,
)
if err != nil {
return nil, err
}
// Once we've created the second-level transaction, we'll generate the
// SignDesc needed to spend the HTLC output using the success transaction.
txOut := commitTx.TxOut[htlc.OutputIndex]
prevFetcher := txscript.NewCannedPrevOutputFetcher(
txOut.PkScript, txOut.Value,
)
hashCache := txscript.NewTxSigHashes(successTx, prevFetcher)
successSignDesc := input.SignDescriptor{
KeyDesc: localChanCfg.HtlcBasePoint,
SingleTweak: keyRing.LocalHtlcKeyTweak,
WitnessScript: htlcWitnessScript,
Output: txOut,
HashType: sweepSigHash(chanType),
PrevOutputFetcher: prevFetcher,
SigHashes: hashCache,
InputIndex: 0,
}
htlcSig, err := input.ParseSignature(htlc.Signature)
if err != nil {
return nil, err
}
// Next, we'll construct the full witness needed to satisfy the input of
// the success transaction. Don't specify the preimage yet. The preimage
// will be supplied by the contract resolver, either directly or when it
// becomes known.
var successWitness wire.TxWitness
sigHashType := HtlcSigHashType(chanType)
if scriptTree, ok := scriptInfo.(input.TapscriptDescriptor); ok {
successSignDesc.SignMethod = input.TaprootScriptSpendSignMethod
successWitness, err = input.ReceiverHTLCScriptTaprootRedeem(
htlcSig, sigHashType, nil, signer, &successSignDesc,
successTx, keyRing.RevocationKey,
scriptTree.TapScriptTree(),
)
if err != nil {
return nil, err
}
// The control block is always the final element of the witness
// stack. We set this here as eventually the sweeper will need
// to re-sign, so it needs the isolated control block.
//
// TODO(roasbeef): move this into input.go?
ctlrBlkIdx := len(successWitness) - 1
successSignDesc.ControlBlock = successWitness[ctlrBlkIdx]
} else {
successWitness, err = input.ReceiverHtlcSpendRedeem(
htlcSig, sigHashType, nil, signer, &successSignDesc,
successTx,
)
if err != nil {
return nil, err
}
}
successTx.TxIn[0].Witness = successWitness
// If this is an anchor type channel, the sign details will let us
// re-sign an aggregated tx later.
txSignDetails := HtlcSignDetails(
chanType, successSignDesc, sigHashType, htlcSig,
)
// Finally, we'll generate the script that the second-level transaction
// creates so we can generate the proper signDesc to sweep it after the
// CSV delay has passed.
var (
htlcSweepScript input.ScriptDescriptor
signMethod input.SignMethod
ctrlBlock []byte
)
if !chanType.IsTaproot() {
htlcSweepScript, err = SecondLevelHtlcScript(
chanType, isCommitFromInitiator, keyRing.RevocationKey,
keyRing.ToLocalKey, csvDelay, leaseExpiry,
secondLevelAuxLeaf,
)
if err != nil {
return nil, err
}
} else {
//nolint:ll
secondLevelScriptTree, err := input.TaprootSecondLevelScriptTree(
keyRing.RevocationKey, keyRing.ToLocalKey, csvDelay,
secondLevelAuxLeaf,
)
if err != nil {
return nil, err
}
signMethod = input.TaprootScriptSpendSignMethod
controlBlock, err := secondLevelScriptTree.CtrlBlockForPath(
input.ScriptPathSuccess,
)
if err != nil {
return nil, err
}
ctrlBlock, err = controlBlock.ToBytes()
if err != nil {
return nil, err
}
htlcSweepScript = secondLevelScriptTree
}
// In this case, the witness script that needs to be signed will always
// be that of the success path.
htlcSweepWitnessScript, err := htlcSweepScript.WitnessScriptForPath(
input.ScriptPathSuccess,
)
if err != nil {
return nil, err
}
localDelayTweak := input.SingleTweakBytes(
keyRing.CommitPoint, localChanCfg.DelayBasePoint.PubKey,
)
// In addition to the info in txSignDetails, we also need extra
// information to sweep the second level output after confirmation.
sweepSignDesc := input.SignDescriptor{
KeyDesc: localChanCfg.DelayBasePoint,
SingleTweak: localDelayTweak,
WitnessScript: htlcSweepWitnessScript,
Output: &wire.TxOut{
PkScript: htlcSweepScript.PkScript(),
Value: int64(secondLevelOutputAmt),
},
HashType: sweepSigHash(chanType),
PrevOutputFetcher: txscript.NewCannedPrevOutputFetcher(
htlcSweepScript.PkScript(),
int64(secondLevelOutputAmt),
),
SignMethod: signMethod,
ControlBlock: ctrlBlock,
}
resolveRes := fn.MapOptionZ(
auxResolver, func(a AuxContractResolver) fn.Result[tlv.Blob] {
resReq := ResolutionReq{
ChanPoint: chanState.FundingOutpoint,
ChanType: chanType,
ShortChanID: chanState.ShortChanID(),
Initiator: chanState.IsInitiator,
CommitBlob: chanState.LocalCommitment.CustomBlob, //nolint:ll
Type: input.TaprootHtlcAcceptedLocalSuccess, //nolint:ll
FundingBlob: chanState.CustomBlob,
CloseType: LocalForceClose,
CommitTx: commitTx,
ContractPoint: op,
SignDesc: sweepSignDesc,
KeyRing: keyRing,
HtlcID: fn.Some(htlc.HtlcIndex),
CsvDelay: htlcCsvDelay,
CommitFee: chanState.LocalCommitment.CommitFee, //nolint:ll
PayHash: fn.Some(htlc.RHash),
AuxSigDesc: fn.Some(AuxSigDesc{
SignDetails: *txSignDetails,
AuxSig: func() []byte {
tlvType := htlcCustomSigType.TypeVal() //nolint:ll
return htlc.CustomRecords[uint64(tlvType)] //nolint:ll
}(),
}),
CommitCsvDelay: csvDelay,
HtlcAmt: btcutil.Amount(txOut.Value),
CltvDelay: fn.Some(htlc.RefundTimeout),
}
return a.ResolveContract(resReq)
},
)
if err := resolveRes.Err(); err != nil {
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
resolutionBlob := resolveRes.OkToSome()
return &IncomingHtlcResolution{
SignedSuccessTx: successTx,
SignDetails: txSignDetails,
CsvDelay: csvDelay,
ResolutionBlob: resolutionBlob,
ClaimOutpoint: wire.OutPoint{
Hash: successTx.TxHash(),
Index: 0,
},
SweepSignDesc: sweepSignDesc,
}, nil
}
// HtlcPoint returns the htlc's outpoint on the commitment tx.
func (r *IncomingHtlcResolution) HtlcPoint() wire.OutPoint {
// If we have a success transaction, then the htlc's outpoint
// is the transaction's only input. Otherwise, it's the claim
// point.
if r.SignedSuccessTx != nil {
return r.SignedSuccessTx.TxIn[0].PreviousOutPoint
}
return r.ClaimOutpoint
}
// HtlcPoint returns the htlc's outpoint on the commitment tx.
func (r *OutgoingHtlcResolution) HtlcPoint() wire.OutPoint {
// If we have a timeout transaction, then the htlc's outpoint
// is the transaction's only input. Otherwise, it's the claim
// point.
if r.SignedTimeoutTx != nil {
return r.SignedTimeoutTx.TxIn[0].PreviousOutPoint
}
return r.ClaimOutpoint
}
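// Example (hypothetical sketch): collecting the commitment outpoints of all
// HTLC resolutions, e.g. to register spend notifications with a chain
// watcher. htlcRes is a placeholder for an *HtlcResolutions obtained from a
// close summary.
//
//	var htlcPoints []wire.OutPoint
//	for i := range htlcRes.IncomingHTLCs {
//		htlcPoints = append(
//			htlcPoints, htlcRes.IncomingHTLCs[i].HtlcPoint(),
//		)
//	}
//	for i := range htlcRes.OutgoingHTLCs {
//		htlcPoints = append(
//			htlcPoints, htlcRes.OutgoingHTLCs[i].HtlcPoint(),
//		)
//	}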
// extractHtlcResolutions creates a series of incoming and outgoing HTLC
// resolutions, along with the local key used when generating the HTLC
// scripts. This function is to be used in two cases: force close, or a
// unilateral close.
func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight,
whoseCommit lntypes.ChannelParty, signer input.Signer,
htlcs []channeldb.HTLC, keyRing *CommitmentKeyRing,
localChanCfg, remoteChanCfg *channeldb.ChannelConfig,
commitTx *wire.MsgTx, chanType channeldb.ChannelType,
isCommitFromInitiator bool, leaseExpiry uint32,
chanState *channeldb.OpenChannel, auxLeaves fn.Option[CommitAuxLeaves],
auxResolver fn.Option[AuxContractResolver]) (*HtlcResolutions, error) {
// TODO(roasbeef): don't need to swap csv delay?
dustLimit := remoteChanCfg.DustLimit
csvDelay := remoteChanCfg.CsvDelay
if whoseCommit.IsLocal() {
dustLimit = localChanCfg.DustLimit
csvDelay = localChanCfg.CsvDelay
}
incomingResolutions := make([]IncomingHtlcResolution, 0, len(htlcs))
outgoingResolutions := make([]OutgoingHtlcResolution, 0, len(htlcs))
for _, htlc := range htlcs {
htlc := htlc
// We'll skip any HTLC's which were dust on the commitment
// transaction, as these don't have a corresponding output
// within the commitment transaction.
if HtlcIsDust(
chanType, htlc.Incoming, whoseCommit, feePerKw,
htlc.Amt.ToSatoshis(), dustLimit,
) {
continue
}
// If the HTLC is incoming, then we'll attempt to see if we
// know the pre-image to the HTLC.
if htlc.Incoming {
// We'll create an incoming HTLC resolution as we can
// satisfy the contract.
ihr, err := newIncomingHtlcResolution(
signer, localChanCfg, commitTx, &htlc,
keyRing, feePerKw, uint32(csvDelay),
leaseExpiry, whoseCommit, isCommitFromInitiator,
chanType, chanState, auxLeaves, auxResolver,
)
if err != nil {
return nil, fmt.Errorf("incoming resolution "+
"failed: %v", err)
}
incomingResolutions = append(incomingResolutions, *ihr)
continue
}
ohr, err := newOutgoingHtlcResolution(
signer, localChanCfg, commitTx, &htlc, keyRing,
feePerKw, uint32(csvDelay), leaseExpiry, whoseCommit,
isCommitFromInitiator, chanType, chanState, auxLeaves,
auxResolver,
)
if err != nil {
return nil, fmt.Errorf("outgoing resolution "+
"failed: %v", err)
}
outgoingResolutions = append(outgoingResolutions, *ohr)
}
return &HtlcResolutions{
IncomingHTLCs: incomingResolutions,
OutgoingHTLCs: outgoingResolutions,
}, nil
}
// AnchorResolution holds the information necessary to spend our commitment tx
// anchor.
type AnchorResolution struct {
// AnchorSignDescriptor is the sign descriptor for our anchor.
AnchorSignDescriptor input.SignDescriptor
// CommitAnchor is the anchor outpoint on the commit tx.
CommitAnchor wire.OutPoint
// CommitFee is the fee of the commit tx.
CommitFee btcutil.Amount
// CommitWeight is the weight of the commit tx.
CommitWeight lntypes.WeightUnit
}
// LocalForceCloseSummary describes the final commitment state before the
// channel is locked-down to initiate a force closure by broadcasting the
// latest state on-chain. If we intend to broadcast this state, the
// channel should not be used after generating this close summary. The summary
// includes all the information required to claim all rightfully owned outputs
// when the commitment gets confirmed.
type LocalForceCloseSummary struct {
// ChanPoint is the outpoint that created the channel which has been
// force closed.
ChanPoint wire.OutPoint
// CloseTx is the transaction which can be used to close the channel
// on-chain. When we initiate a force close, this will be our latest
// commitment state.
CloseTx *wire.MsgTx
// ChanSnapshot is a snapshot of the final state of the channel at the
// time the summary was created.
ChanSnapshot channeldb.ChannelSnapshot
// ContractResolutions contains all the data required for resolving the
// different output types of a commitment transaction.
ContractResolutions fn.Option[ContractResolutions]
}
// ContractResolutions contains all the data required for resolving the
// different output types of a commitment transaction.
type ContractResolutions struct {
// CommitResolution contains all the data required to sweep the output
// to ourselves. Since this is our commitment transaction, we'll need
// to wait a time delay before we can sweep the output.
//
// NOTE: If our commitment delivery output is below the dust limit,
// then this will be nil.
CommitResolution *CommitOutputResolution
// AnchorResolution contains the data required to sweep the anchor
// output. If the channel type doesn't include anchors, the value of
// this field will be nil.
AnchorResolution *AnchorResolution
// HtlcResolutions contains all the data required to sweep any outgoing
// HTLC's and incoming HTLC's we know the preimage to. For each of these
// HTLC's, we'll need to go to the second level to sweep them fully.
HtlcResolutions *HtlcResolutions
}
// ForceCloseOpt is a functional option argument for the ForceClose method.
type ForceCloseOpt func(*forceCloseConfig)
// forceCloseConfig holds the configuration options for force closing a channel.
type forceCloseConfig struct {
// skipResolution if true will skip creating the contract resolutions
// when generating the force close summary.
skipResolution bool
}
// defaultForceCloseConfig returns the default force close configuration.
func defaultForceCloseConfig() *forceCloseConfig {
return &forceCloseConfig{}
}
// WithSkipContractResolutions creates an option to skip the contract
// resolutions from the returned summary.
func WithSkipContractResolutions() ForceCloseOpt {
return func(cfg *forceCloseConfig) {
cfg.skipResolution = true
}
}
// ForceClose executes a unilateral closure of the transaction at the current
// lowest commitment height of the channel. Following a force closure, all
// state transitions, or modifications to the state update logs will be
// rejected. Additionally, this function also returns a LocalForceCloseSummary
// which includes the necessary details required to sweep all the time-locked
// outputs within the commitment transaction.
//
// TODO(roasbeef): all methods need to abort if in dispute state
func (lc *LightningChannel) ForceClose(opts ...ForceCloseOpt) (
*LocalForceCloseSummary, error) {
lc.Lock()
defer lc.Unlock()
cfg := defaultForceCloseConfig()
for _, opt := range opts {
opt(cfg)
}
// If we've detected local data loss for this channel, then we won't
// allow a force close, as it may be the case that we have a dated
// version of the commitment, or this is actually a channel shell.
if lc.channelState.HasChanStatus(channeldb.ChanStatusLocalDataLoss) {
return nil, fmt.Errorf("%w: channel_state=%v",
ErrForceCloseLocalDataLoss,
lc.channelState.ChanStatus())
}
commitTx, err := lc.getSignedCommitTx()
if err != nil {
return nil, err
}
if cfg.skipResolution {
return &LocalForceCloseSummary{
ChanPoint: lc.channelState.FundingOutpoint,
ChanSnapshot: *lc.channelState.Snapshot(),
CloseTx: commitTx,
}, nil
}
localCommitment := lc.channelState.LocalCommitment
summary, err := NewLocalForceCloseSummary(
lc.channelState, lc.Signer, commitTx,
localCommitment.CommitHeight, lc.leafStore, lc.auxResolver,
)
if err != nil {
return nil, fmt.Errorf("unable to gen force close "+
"summary: %w", err)
}
// Mark the channel as closed to block future closure requests.
lc.isClosed = true
return summary, nil
}
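// Example (hypothetical sketch): force closing a channel while skipping the
// contract resolutions, e.g. when only the raw commitment transaction is
// needed for inspection. lnChan is a placeholder for a *LightningChannel.
//
//	summary, err := lnChan.ForceClose(WithSkipContractResolutions())
//	if err != nil {
//		return fmt.Errorf("unable to force close: %w", err)
//	}
//	closeTx := summary.CloseTx
//	// summary.ContractResolutions is fn.None in this case.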
// NewLocalForceCloseSummary generates a LocalForceCloseSummary from the given
// channel state. The passed commitTx must be a fully signed commitment
// transaction corresponding to localCommit.
func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel,
signer input.Signer, commitTx *wire.MsgTx, stateNum uint64,
leafStore fn.Option[AuxLeafStore],
auxResolver fn.Option[AuxContractResolver]) (*LocalForceCloseSummary,
error) {
// Re-derive the original pkScript for the to-self output within the
// commitment transaction. We'll need this to find the corresponding
// output in the commitment transaction and potentially for creating
// the sign descriptor.
csvTimeout := uint32(chanState.LocalChanCfg.CsvDelay)
// We use the passed state num to derive our scripts, since in case
// this is after recovery, our latest channel state might not be up to
// date.
revocation, err := chanState.RevocationProducer.AtIndex(stateNum)
if err != nil {
return nil, err
}
commitPoint := input.ComputeCommitmentPoint(revocation[:])
keyRing := DeriveCommitmentKeys(
commitPoint, lntypes.Local, chanState.ChanType,
&chanState.LocalChanCfg, &chanState.RemoteChanCfg,
)
auxResult, err := fn.MapOptionZ(
leafStore, func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] {
return s.FetchLeavesFromCommit(
NewAuxChanState(chanState),
chanState.LocalCommitment, *keyRing,
lntypes.Local,
)
},
).Unpack()
if err != nil {
return nil, fmt.Errorf("unable to fetch aux leaves: %w", err)
}
var leaseExpiry uint32
if chanState.ChanType.HasLeaseExpiration() {
leaseExpiry = chanState.ThawHeight
}
localAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
},
)(auxResult.AuxLeaves)
toLocalScript, err := CommitScriptToSelf(
chanState.ChanType, chanState.IsInitiator, keyRing.ToLocalKey,
keyRing.RevocationKey, csvTimeout, leaseExpiry, localAuxLeaf,
)
if err != nil {
return nil, err
}
// Locate the output index of the delayed commitment output back to us.
// We'll return the details of this output to the caller so they can
// sweep it once it's mature.
var (
delayIndex uint32
delayOut *wire.TxOut
)
for i, txOut := range commitTx.TxOut {
if !bytes.Equal(toLocalScript.PkScript(), txOut.PkScript) {
continue
}
delayIndex = uint32(i)
delayOut = txOut
break
}
// With the necessary information gathered above, create a new sign
// descriptor which is capable of generating the signature the caller
// needs to sweep this output. The hash cache and input index are not
// set as the caller will decide these values once sweeping the output.
// If the output is non-existent (dust), have the sign descriptor be
// nil.
var commitResolution *CommitOutputResolution
if delayOut != nil {
// When attempting to sweep our own output, we only need the
// witness script for the delay path.
scriptPath := input.ScriptPathDelay
witnessScript, err := toLocalScript.WitnessScriptForPath(
scriptPath,
)
if err != nil {
return nil, err
}
localBalance := delayOut.Value
commitResolution = &CommitOutputResolution{
SelfOutPoint: wire.OutPoint{
Hash: commitTx.TxHash(),
Index: delayIndex,
},
SelfOutputSignDesc: input.SignDescriptor{
KeyDesc: chanState.LocalChanCfg.DelayBasePoint,
SingleTweak: keyRing.LocalCommitKeyTweak,
WitnessScript: witnessScript,
Output: &wire.TxOut{
PkScript: delayOut.PkScript,
Value: localBalance,
},
HashType: sweepSigHash(chanState.ChanType),
},
MaturityDelay: csvTimeout,
}
// For taproot channels, we'll need to set some additional
// fields to ensure the output can be swept.
scriptTree, ok := toLocalScript.(input.TapscriptDescriptor)
if ok {
commitResolution.SelfOutputSignDesc.SignMethod =
input.TaprootScriptSpendSignMethod
ctrlBlock, err := scriptTree.CtrlBlockForPath(
scriptPath,
)
if err != nil {
return nil, err
}
//nolint:ll
commitResolution.SelfOutputSignDesc.ControlBlock, err = ctrlBlock.ToBytes()
if err != nil {
return nil, err
}
}
// At this point, we'll check to see if we need any extra
// resolution data for this output.
resolveBlob := fn.MapOptionZ(
auxResolver,
func(a AuxContractResolver) fn.Result[tlv.Blob] {
//nolint:ll
return a.ResolveContract(ResolutionReq{
ChanPoint: chanState.FundingOutpoint,
ChanType: chanState.ChanType,
ShortChanID: chanState.ShortChanID(),
Initiator: chanState.IsInitiator,
CommitBlob: chanState.LocalCommitment.CustomBlob,
FundingBlob: chanState.CustomBlob,
Type: input.TaprootLocalCommitSpend,
CloseType: LocalForceClose,
CommitTx: commitTx,
ContractPoint: commitResolution.SelfOutPoint,
SignDesc: commitResolution.SelfOutputSignDesc,
KeyRing: keyRing,
CsvDelay: csvTimeout,
CommitFee: chanState.LocalCommitment.CommitFee,
})
},
)
if err := resolveBlob.Err(); err != nil {
return nil, fmt.Errorf("unable to aux resolve: %w", err)
}
commitResolution.ResolutionBlob = resolveBlob.OkToSome()
}
// Once the delay output has been found (if it exists), then we'll also
// need to create a series of sign descriptors for any lingering
// outgoing HTLCs that we'll need to claim as well. If this is after
// recovery there is not much we can do with HTLCs, so we'll always
// use what we have in our latest state when extracting resolutions.
localCommit := chanState.LocalCommitment
htlcResolutions, err := extractHtlcResolutions(
chainfee.SatPerKWeight(localCommit.FeePerKw), lntypes.Local,
signer, localCommit.Htlcs, keyRing, &chanState.LocalChanCfg,
&chanState.RemoteChanCfg, commitTx, chanState.ChanType,
chanState.IsInitiator, leaseExpiry, chanState,
auxResult.AuxLeaves, auxResolver,
)
if err != nil {
return nil, fmt.Errorf("unable to gen htlc resolution: %w", err)
}
anchorResolution, err := NewAnchorResolution(
chanState, commitTx, keyRing, lntypes.Local,
)
if err != nil {
return nil, fmt.Errorf("unable to gen anchor "+
"resolution: %w", err)
}
return &LocalForceCloseSummary{
ChanPoint: chanState.FundingOutpoint,
CloseTx: commitTx,
ChanSnapshot: *chanState.Snapshot(),
ContractResolutions: fn.Some(ContractResolutions{
CommitResolution: commitResolution,
HtlcResolutions: htlcResolutions,
AnchorResolution: anchorResolution,
}),
}, nil
}
// CloseOutput wraps a normal tx out with additional metadata that indicates if
// the output belongs to the initiator of the channel or not.
type CloseOutput struct {
wire.TxOut
// IsLocal indicates if the output belongs to the local party.
IsLocal bool
}
// CloseSortFunc is a function type alias for a function that sorts the closing
// transaction.
type CloseSortFunc func(*wire.MsgTx) error
// chanCloseOpt is a functional option that can be used to modify the co-op
// close process.
type chanCloseOpt struct {
musigSession *MusigSession
extraCloseOutputs []CloseOutput
// customSort is a custom function that can be used to sort the
// transaction outputs. If this isn't set, then the default BIP-69
// sorting is used.
customSort CloseSortFunc
customSequence fn.Option[uint32]
customLockTime fn.Option[uint32]
customPayer fn.Option[lntypes.ChannelParty]
}
// ChanCloseOpt is a closure type that can be used to modify the set of default
// options.
type ChanCloseOpt func(*chanCloseOpt)
// defaultCloseOpts is the default set of close options.
func defaultCloseOpts() *chanCloseOpt {
return &chanCloseOpt{}
}
// WithCoopCloseMusigSession can be used to apply an existing musig2 session to
// the cooperative close process. If specified, then a musig2 co-op close
// (single sig keyspend) will be used.
func WithCoopCloseMusigSession(session *MusigSession) ChanCloseOpt {
return func(opts *chanCloseOpt) {
opts.musigSession = session
}
}
// WithExtraCloseOutputs can be used to add extra outputs to the cooperative
// transaction.
func WithExtraCloseOutputs(extraOutputs []CloseOutput) ChanCloseOpt {
return func(opts *chanCloseOpt) {
opts.extraCloseOutputs = extraOutputs
}
}
// WithCustomCoopSort can be used to modify the way the co-op close transaction
// is sorted.
func WithCustomCoopSort(sorter CloseSortFunc) ChanCloseOpt {
return func(opts *chanCloseOpt) {
opts.customSort = sorter
}
}
// WithCustomSequence can be used to specify a custom sequence number for the
// co-op close process. Otherwise, a default non-final sequence will be used.
func WithCustomSequence(sequence uint32) ChanCloseOpt {
return func(opts *chanCloseOpt) {
opts.customSequence = fn.Some(sequence)
}
}
// WithCustomLockTime can be used to specify a custom lock time for the coop
// close transaction.
func WithCustomLockTime(lockTime uint32) ChanCloseOpt {
return func(opts *chanCloseOpt) {
opts.customLockTime = fn.Some(lockTime)
}
}
// WithCustomPayer can be used to specify a custom payer for the closing
// transaction. This overrides the default payer, which is the initiator of the
// channel.
func WithCustomPayer(payer lntypes.ChannelParty) ChanCloseOpt {
return func(opts *chanCloseOpt) {
opts.customPayer = fn.Some(payer)
}
}
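// exampleRbfCloseOpts is an illustrative sketch (hypothetical, not part of
// the package API) of how the functional options above compose for an
// RBF-style co-op close attempt: a replaceable sequence, an explicit lock
// time, and an explicit payer of the close fee.
func exampleRbfCloseOpts(payer lntypes.ChannelParty,
	lockTime uint32) []ChanCloseOpt {

	return []ChanCloseOpt{
		// Signal RBF on the funding input via a non-final sequence.
		WithCustomSequence(mempool.MaxRBFSequence),
		// Pin the close transaction to a specific lock time.
		WithCustomLockTime(lockTime),
		// Override the default rule that the channel initiator pays.
		WithCustomPayer(payer),
	}
}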
// CreateCloseProposal is used by both parties in a cooperative channel close
// workflow to generate proposed close transactions and signatures. This method
// should only be executed once all pending HTLCs (if any) on the channel have
// been cleared/removed. Upon completion, the source channel will shift into
// the "closing" state, which indicates that all incoming/outgoing HTLC
// requests should be rejected. A signature for the closing transaction is
// returned.
func (lc *LightningChannel) CreateCloseProposal(proposedFee btcutil.Amount,
localDeliveryScript []byte, remoteDeliveryScript []byte,
closeOpts ...ChanCloseOpt) (input.Signature, *wire.MsgTx,
btcutil.Amount, error) {
lc.Lock()
defer lc.Unlock()
opts := defaultCloseOpts()
for _, optFunc := range closeOpts {
optFunc(opts)
}
// Unless there's a custom payer (a sign of the RBF flow), if we're
// already closing the channel, then ignore this request.
if lc.isClosed && opts.customPayer.IsNone() {
return nil, nil, 0, ErrChanClosing
}
// Get the final balances after subtracting the proposed fee, taking
// care not to persist the adjusted balance, as the feeRate may change
// during the channel closing process.
ourBalance, theirBalance, err := CoopCloseBalance(
lc.channelState.ChanType, lc.channelState.IsInitiator,
proposedFee,
lc.channelState.LocalCommitment.LocalBalance.ToSatoshis(),
lc.channelState.LocalCommitment.RemoteBalance.ToSatoshis(),
lc.channelState.LocalCommitment.CommitFee,
opts.customPayer,
)
if err != nil {
return nil, nil, 0, err
}
var closeTxOpts []CloseTxOpt
// If this is a taproot channel, then we use an RBF'able funding input.
if lc.channelState.ChanType.IsTaproot() {
closeTxOpts = append(closeTxOpts, WithRBFCloseTx())
}
// If we have any extra outputs to pass along, then we'll map that to
// the co-op close option txn type.
if opts.extraCloseOutputs != nil {
closeTxOpts = append(closeTxOpts, WithExtraTxCloseOutputs(
opts.extraCloseOutputs,
))
}
if opts.customSort != nil {
closeTxOpts = append(
closeTxOpts, WithCustomTxSort(opts.customSort),
)
}
opts.customSequence.WhenSome(func(sequence uint32) {
closeTxOpts = append(closeTxOpts, WithCustomTxInSequence(
sequence,
))
})
opts.customLockTime.WhenSome(func(lockTime uint32) {
closeTxOpts = append(closeTxOpts, WithCustomTxLockTime(
lockTime,
))
})
closeTx, err := CreateCooperativeCloseTx(
fundingTxIn(lc.channelState), lc.channelState.LocalChanCfg.DustLimit,
lc.channelState.RemoteChanCfg.DustLimit, ourBalance, theirBalance,
localDeliveryScript, remoteDeliveryScript, closeTxOpts...,
)
if err != nil {
return nil, nil, 0, err
}
// Ensure that the transaction doesn't explicitly violate any consensus
// rules, such as being too big or having any output with a negative
// value.
tx := btcutil.NewTx(closeTx)
if err := blockchain.CheckTransactionSanity(tx); err != nil {
return nil, nil, 0, err
}
// If we have a co-op close musig session, then this is a taproot
// channel, so we'll generate a _partial_ signature.
var sig input.Signature
if opts.musigSession != nil {
sig, err = opts.musigSession.SignCommit(closeTx)
if err != nil {
return nil, nil, 0, err
}
} else {
// For regular channels, we'll sign the completed cooperative
// closure transaction. As the initiator we'll simply send our
// signature over to the remote party, using the generated txid
// to be notified once the closure transaction has been
// confirmed.
lc.signDesc.SigHashes = input.NewTxSigHashesV0Only(closeTx)
sig, err = lc.Signer.SignOutputRaw(closeTx, lc.signDesc)
if err != nil {
return nil, nil, 0, err
}
}
return sig, closeTx, ourBalance, nil
}
// CompleteCooperativeClose completes the cooperative closure of the target
// active lightning channel. A fully signed closure transaction as well as the
// signature itself are returned. Additionally, we also return our final
// settled balance, which reflects any fees we may have paid.
//
// NOTE: The passed local and remote sigs are expected to be fully complete
// signatures including the proper sighash byte.
func (lc *LightningChannel) CompleteCooperativeClose(
localSig, remoteSig input.Signature,
localDeliveryScript, remoteDeliveryScript []byte,
proposedFee btcutil.Amount,
closeOpts ...ChanCloseOpt) (*wire.MsgTx, btcutil.Amount, error) {
lc.Lock()
defer lc.Unlock()
opts := defaultCloseOpts()
for _, optFunc := range closeOpts {
optFunc(opts)
}
// Unless there's a custom payer (a sign of the RBF flow), if we're
// already closing the channel, then ignore this request.
if lc.isClosed && opts.customPayer.IsNone() {
return nil, 0, ErrChanClosing
}
// Get the final balances after subtracting the proposed fee.
ourBalance, theirBalance, err := CoopCloseBalance(
lc.channelState.ChanType, lc.channelState.IsInitiator,
proposedFee,
lc.channelState.LocalCommitment.LocalBalance.ToSatoshis(),
lc.channelState.LocalCommitment.RemoteBalance.ToSatoshis(),
lc.channelState.LocalCommitment.CommitFee,
opts.customPayer,
)
if err != nil {
return nil, 0, err
}
var closeTxOpts []CloseTxOpt
// If this is a taproot channel, then we use an RBF'able funding input.
if lc.channelState.ChanType.IsTaproot() {
closeTxOpts = append(closeTxOpts, WithRBFCloseTx())
}
// If we have any extra outputs to pass along, then we'll map that to
// the co-op close option txn type.
if opts.extraCloseOutputs != nil {
closeTxOpts = append(closeTxOpts, WithExtraTxCloseOutputs(
opts.extraCloseOutputs,
))
}
if opts.customSort != nil {
closeTxOpts = append(
closeTxOpts, WithCustomTxSort(opts.customSort),
)
}
opts.customSequence.WhenSome(func(sequence uint32) {
closeTxOpts = append(closeTxOpts, WithCustomTxInSequence(
sequence,
))
})
opts.customLockTime.WhenSome(func(lockTime uint32) {
closeTxOpts = append(closeTxOpts, WithCustomTxLockTime(
lockTime,
))
})
// Create the transaction used to return the current settled balance
// on this active channel back to both parties. In this current model,
// the initiator pays full fees for the cooperative close transaction.
closeTx, err := CreateCooperativeCloseTx(
fundingTxIn(lc.channelState), lc.channelState.LocalChanCfg.DustLimit,
lc.channelState.RemoteChanCfg.DustLimit, ourBalance, theirBalance,
localDeliveryScript, remoteDeliveryScript, closeTxOpts...,
)
if err != nil {
return nil, 0, err
}
// Ensure that the transaction doesn't explicitly violate any consensus
// rules, such as being too big or having any output with a negative
// value.
tx := btcutil.NewTx(closeTx)
prevOut := lc.signDesc.Output
if err := blockchain.CheckTransactionSanity(tx); err != nil {
return nil, 0, err
}
prevOutputFetcher := txscript.NewCannedPrevOutputFetcher(
prevOut.PkScript, prevOut.Value,
)
hashCache := txscript.NewTxSigHashes(closeTx, prevOutputFetcher)
// Next, we'll complete the co-op close transaction. Depending on the
// set of options, we'll either do a regular p2wsh spend, or construct
// the final schnorr signature from a set of partial sigs.
if opts.musigSession != nil {
// For taproot channels, we'll use the attached session to
// combine the two partial signatures into a proper schnorr
// signature.
remotePartialSig, ok := remoteSig.(*MusigPartialSig)
if !ok {
return nil, 0, fmt.Errorf("expected MusigPartialSig, "+
"got %T", remoteSig)
}
finalSchnorrSig, err := opts.musigSession.CombineSigs(
remotePartialSig.sig,
)
if err != nil {
return nil, 0, fmt.Errorf("unable to combine "+
"final co-op close sig: %w", err)
}
// The witness for a keyspend is just the signature itself.
closeTx.TxIn[0].Witness = wire.TxWitness{
finalSchnorrSig.Serialize(),
}
} else {
// For regular channels, we'll need to construct the witness stack,
// minding the order of the pubkeys+sigs on the stack.
ourKey := lc.channelState.LocalChanCfg.MultiSigKey.PubKey.
SerializeCompressed()
theirKey := lc.channelState.RemoteChanCfg.MultiSigKey.PubKey.
SerializeCompressed()
witness := input.SpendMultiSig(
lc.signDesc.WitnessScript, ourKey, localSig, theirKey,
remoteSig,
)
closeTx.TxIn[0].Witness = witness
}
// Validate the finalized transaction to ensure the output script is
// properly met, and that the remote peer supplied a valid signature.
vm, err := txscript.NewEngine(
prevOut.PkScript, closeTx, 0, txscript.StandardVerifyFlags, nil,
hashCache, prevOut.Value, prevOutputFetcher,
)
if err != nil {
return nil, 0, err
}
if err := vm.Execute(); err != nil {
return nil, 0, err
}
// As the transaction is sane and the scripts are valid, we'll mark the
// channel now as closed as the closure transaction should get into the
// chain in a timely manner and possibly be re-broadcast by the wallet.
lc.isClosed = true
return closeTx, ourBalance, nil
}
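// exampleCoopCloseFlow is an illustrative sketch (hypothetical, not part of
// the package API) of the two-step co-op close flow implemented above: first
// CreateCloseProposal produces our signature for a proposed fee, then, once
// the remote party's signature arrives, CompleteCooperativeClose assembles
// and validates the final transaction.
func exampleCoopCloseFlow(lc *LightningChannel, proposedFee btcutil.Amount,
	localScript, remoteScript []byte,
	remoteSig input.Signature) (*wire.MsgTx, btcutil.Amount, error) {

	// Step 1: generate our signature over the proposed close tx.
	localSig, _, _, err := lc.CreateCloseProposal(
		proposedFee, localScript, remoteScript,
	)
	if err != nil {
		return nil, 0, err
	}

	// Step 2: combine both signatures into the final, validated close tx.
	return lc.CompleteCooperativeClose(
		localSig, remoteSig, localScript, remoteScript, proposedFee,
	)
}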
// AnchorResolutions is a set of anchor resolutions that's being used when
// sweeping anchors during local channel force close.
type AnchorResolutions struct {
// Local is the anchor resolution for the local commitment tx.
Local *AnchorResolution
// Remote is the anchor resolution for the remote commitment tx.
Remote *AnchorResolution
// RemotePending is the anchor resolution for the remote pending
// commitment tx. The value will be non-nil iff we've created a new
// commitment tx for the remote party which they haven't ACKed yet.
RemotePending *AnchorResolution
}
// NewAnchorResolutions returns a set of anchor resolutions wrapped in the
// struct AnchorResolutions. Because we have no view on the mempool, we can
// only blindly anchor all of these txes down. The caller needs to check the
// returned values against nil to decide whether there exists an anchor
// resolution for local/remote/pending remote commitment txes.
func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions,
error) {
lc.Lock()
defer lc.Unlock()
var resolutions AnchorResolutions
// Add anchor for local commitment tx, if any.
revocation, err := lc.channelState.RevocationProducer.AtIndex(
lc.currentHeight,
)
if err != nil {
return nil, err
}
localCommitPoint := input.ComputeCommitmentPoint(revocation[:])
localKeyRing := DeriveCommitmentKeys(
localCommitPoint, lntypes.Local, lc.channelState.ChanType,
&lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg,
)
localRes, err := NewAnchorResolution(
lc.channelState, lc.channelState.LocalCommitment.CommitTx,
localKeyRing, lntypes.Local,
)
if err != nil {
return nil, err
}
resolutions.Local = localRes
// Add anchor for remote commitment tx, if any.
remoteKeyRing := DeriveCommitmentKeys(
lc.channelState.RemoteCurrentRevocation, lntypes.Remote,
lc.channelState.ChanType, &lc.channelState.LocalChanCfg,
&lc.channelState.RemoteChanCfg,
)
remoteRes, err := NewAnchorResolution(
lc.channelState, lc.channelState.RemoteCommitment.CommitTx,
remoteKeyRing, lntypes.Remote,
)
if err != nil {
return nil, err
}
resolutions.Remote = remoteRes
// Add anchor for remote pending commitment tx, if any.
remotePendingCommit, err := lc.channelState.RemoteCommitChainTip()
if err != nil && err != channeldb.ErrNoPendingCommit {
return nil, err
}
if remotePendingCommit != nil {
pendingRemoteKeyRing := DeriveCommitmentKeys(
lc.channelState.RemoteNextRevocation, lntypes.Remote,
lc.channelState.ChanType, &lc.channelState.LocalChanCfg,
&lc.channelState.RemoteChanCfg,
)
remotePendingRes, err := NewAnchorResolution(
lc.channelState,
remotePendingCommit.Commitment.CommitTx,
pendingRemoteKeyRing, lntypes.Remote,
)
if err != nil {
return nil, err
}
resolutions.RemotePending = remotePendingRes
}
return &resolutions, nil
}
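// exampleCollectAnchors is an illustrative sketch (hypothetical, not part of
// the package API) of how a caller is expected to consume
// NewAnchorResolutions: with no view on the mempool, every non-nil resolution
// is a candidate for anchor sweeping.
func exampleCollectAnchors(lc *LightningChannel) ([]*AnchorResolution, error) {
	resolutions, err := lc.NewAnchorResolutions()
	if err != nil {
		return nil, err
	}

	var candidates []*AnchorResolution
	for _, res := range []*AnchorResolution{
		resolutions.Local, resolutions.Remote,
		resolutions.RemotePending,
	} {
		// A nil entry means no anchor output exists for that
		// commitment (e.g. a non-anchor channel or a missing output).
		if res != nil {
			candidates = append(candidates, res)
		}
	}

	return candidates, nil
}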
// NewAnchorResolution returns the information that is required to sweep the
// local anchor.
func NewAnchorResolution(chanState *channeldb.OpenChannel,
commitTx *wire.MsgTx, keyRing *CommitmentKeyRing,
whoseCommit lntypes.ChannelParty) (*AnchorResolution, error) {
// Return nil resolution if the channel has no anchors.
if !chanState.ChanType.HasAnchors() {
return nil, nil
}
// Derive our local anchor script. For taproot channels, rather than
// use the same multi-sig key for both commitments, the anchor script
// will differ depending on if this is our local or remote
// commitment.
localAnchor, remoteAnchor, err := CommitScriptAnchors(
chanState.ChanType, &chanState.LocalChanCfg,
&chanState.RemoteChanCfg, keyRing,
)
if err != nil {
return nil, err
}
if chanState.ChanType.IsTaproot() && whoseCommit.IsRemote() {
//nolint:ineffassign
localAnchor, remoteAnchor = remoteAnchor, localAnchor
}
// TODO(roasbeef): remote anchor not needed above
// Look up the script on the commitment transaction. It may not be
// present if there is no output paying to us.
found, index := input.FindScriptOutputIndex(
commitTx, localAnchor.PkScript(),
)
if !found {
return nil, nil
}
// For anchor outputs, we'll only ever care about the success path script
// (sweep after a 1-block CSV delay).
anchorWitnessScript, err := localAnchor.WitnessScriptForPath(
input.ScriptPathSuccess,
)
if err != nil {
return nil, err
}
outPoint := &wire.OutPoint{
Hash: commitTx.TxHash(),
Index: index,
}
// Instantiate the sign descriptor that allows sweeping of the anchor.
signDesc := &input.SignDescriptor{
KeyDesc: chanState.LocalChanCfg.MultiSigKey,
WitnessScript: anchorWitnessScript,
Output: &wire.TxOut{
PkScript: localAnchor.PkScript(),
Value: int64(AnchorSize),
},
HashType: sweepSigHash(chanState.ChanType),
}
// For taproot outputs, we'll need to ensure that the proper sign
// method is used, and the tweak as well.
if scriptTree, ok := localAnchor.(input.TapscriptDescriptor); ok {
signDesc.SignMethod = input.TaprootKeySpendSignMethod
//nolint:ll
signDesc.PrevOutputFetcher = txscript.NewCannedPrevOutputFetcher(
localAnchor.PkScript(), int64(AnchorSize),
)
// For anchor outputs with taproot channels, the key desc is
// also different: we'll just re-use our local delay base point
// (which becomes our to local output).
if whoseCommit.IsLocal() {
// In addition to the sign method, we'll also need to
// ensure that the single tweak is set, as with the
// current formulation, we'll need to use two levels of
// tweaks: the normal LN tweak, and the tapscript
// tweak.
signDesc.SingleTweak = keyRing.LocalCommitKeyTweak
signDesc.KeyDesc = chanState.LocalChanCfg.DelayBasePoint
} else {
// When we're playing the force close of a remote
// commitment, as this is a "tweakless" channel type,
// we don't need a tweak value at all.
//
//nolint:ll
signDesc.KeyDesc = chanState.LocalChanCfg.PaymentBasePoint
}
// Finally, as this is a keyspend method, we'll need to also
// include the taptweak as well.
signDesc.TapTweak = scriptTree.TapTweak()
}
var witnessWeight int64
if chanState.ChanType.IsTaproot() {
witnessWeight = input.TaprootKeyPathWitnessSize
} else {
witnessWeight = input.WitnessCommitmentTxWeight
}
// Calculate commit tx weight. This commit tx doesn't yet include the
// witness spending the funding output, so we add the (worst case)
// weight for that too.
utx := btcutil.NewTx(commitTx)
weight := blockchain.GetTransactionWeight(utx) + witnessWeight
// Calculate commit tx fee.
fee := chanState.Capacity
for _, out := range commitTx.TxOut {
fee -= btcutil.Amount(out.Value)
}
return &AnchorResolution{
CommitAnchor: *outPoint,
AnchorSignDescriptor: *signDesc,
CommitWeight: lntypes.WeightUnit(weight),
CommitFee: fee,
}, nil
}
// AvailableBalance returns the current balance available for sending within
// the channel. By available balance, we mean the balance we would have for
// adding an additional HTLC if, at this very instant, a new commitment were
// created that evaluates all the log entries. It accounts for the fee that
// must be paid for adding this HTLC, for the channel reserve we cannot spend
// from, and for the FeeBuffer when we are the initiator of the channel. This
// method is useful when deciding if a given
// channel can accept an HTLC in the multi-hop forwarding scenario.
func (lc *LightningChannel) AvailableBalance() lnwire.MilliSatoshi {
lc.RLock()
defer lc.RUnlock()
bal, _ := lc.availableBalance(FeeBuffer)
return bal
}
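// exampleCanForward is an illustrative sketch (hypothetical, not part of the
// package API) of the multi-hop forwarding check described above: an HTLC of
// the given amount should only be attempted if it fits within the available
// balance, which already accounts for fees, the channel reserve, and the fee
// buffer.
func exampleCanForward(lc *LightningChannel, amt lnwire.MilliSatoshi) bool {
	return amt <= lc.AvailableBalance()
}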
// availableBalance is the private, non-mutexed version of AvailableBalance.
// It is provided so that methods which already hold the lock can use it.
// Additionally, the total weight of the next-to-be-created commitment is
// returned for accounting purposes.
func (lc *LightningChannel) availableBalance(
buffer BufferType) (lnwire.MilliSatoshi, lntypes.WeightUnit) {
// We'll grab the current set of log updates that the remote has
// ACKed.
remoteACKedIndex := lc.commitChains.Local.tip().messageIndices.Remote
htlcView := lc.fetchHTLCView(remoteACKedIndex,
lc.updateLogs.Local.logIndex)
// Calculate our available balance from our local commitment.
// TODO(halseth): could reuse parts validateCommitmentSanity to do this
// balance calculation, as most of the logic is the same.
//
// NOTE: This is not always accurate, since the remote node can always
// add updates concurrently, causing our balance to go down if we're
// the initiator, but this is a problem on the protocol level.
ourLocalCommitBalance, commitWeight := lc.availableCommitmentBalance(
htlcView, lntypes.Local, buffer,
)
// Do the same calculation from the remote commitment point of view.
ourRemoteCommitBalance, _ := lc.availableCommitmentBalance(
htlcView, lntypes.Remote, buffer,
)
// Return whichever balance is lowest.
if ourRemoteCommitBalance < ourLocalCommitBalance {
return ourRemoteCommitBalance, commitWeight
}
return ourLocalCommitBalance, commitWeight
}
// availableCommitmentBalance attempts to calculate the balance we have
// available for HTLCs on the local/remote commitment given the HtlcView. To
// account for sending HTLCs of different sizes, it will report the balance
// available for sending non-dust HTLCs, which will be manifested on the
// commitment, increasing the commitment fee we must pay as an initiator,
// eating into our balance. It will make sure we won't violate the channel
// reserve constraints for this amount.
func (lc *LightningChannel) availableCommitmentBalance(view *HtlcView,
whoseCommitChain lntypes.ChannelParty, buffer BufferType) (
lnwire.MilliSatoshi, lntypes.WeightUnit) {
// Compute the current balances for this commitment. This will take
// into account HTLCs to determine the commit weight, which the
// initiator must pay the fee for.
ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView(
view, whoseCommitChain, false,
fn.None[chainfee.SatPerKWeight](),
)
if err != nil {
lc.log.Errorf("Unable to fetch available balance: %v", err)
return 0, 0
}
// We can never spend from the channel reserve, so we'll subtract it
// from our available balance.
ourReserve := lnwire.NewMSatFromSatoshis(
lc.channelState.LocalChanCfg.ChanReserve,
)
if ourReserve <= ourBalance {
ourBalance -= ourReserve
} else {
ourBalance = 0
}
// Calculate the commitment fee in the case where we would add another
// HTLC to the commitment, as only the balance remaining after this fee
// has been paid is actually available for sending.
feePerKw := filteredView.FeePerKw
additionalHtlcFee := lnwire.NewMSatFromSatoshis(
feePerKw.FeeForWeight(input.HTLCWeight),
)
commitFee := lnwire.NewMSatFromSatoshis(
feePerKw.FeeForWeight(commitWeight))
if lc.channelState.IsInitiator {
// When the buffer is of type `FeeBuffer` type we know we are
// going to send or forward an htlc over this channel therefore
// we account for an additional htlc output on the commitment
// tx.
futureCommitWeight := commitWeight
if buffer == FeeBuffer {
futureCommitWeight += input.HTLCWeight
}
// Make sure we do not overwrite `ourBalance`; that's why we declare
// bufferAmt beforehand.
var bufferAmt lnwire.MilliSatoshi
ourBalance, bufferAmt, commitFee, err = lc.applyCommitFee(
ourBalance, futureCommitWeight, feePerKw, buffer,
)
if err != nil {
lc.log.Debugf("No available local balance after "+
"applying the CommitmentFee of the new "+
"CommitmentState(%v): ourBalance would drop "+
"below the reserve: "+
"ourBalance(w/o reserve)=%v, reserve=%v, "+
"current commitFee(w/o additional htlc)=%v, "+
"feeBuffer=%v (type=%v) "+
"local_chan_initiator=%v", whoseCommitChain,
ourBalance, ourReserve, commitFee, bufferAmt,
buffer, lc.channelState.IsInitiator)
return 0, commitWeight
}
return ourBalance, commitWeight
}
// If we're not the initiator, we must check whether the remote has
// enough balance to pay for the fee of our HTLC. We'll start by also
// subtracting our counterparty's reserve from their balance.
theirReserve := lnwire.NewMSatFromSatoshis(
lc.channelState.RemoteChanCfg.ChanReserve,
)
if theirReserve <= theirBalance {
theirBalance -= theirReserve
} else {
theirBalance = 0
}
// We'll use the dustlimit and htlcFee to find the largest HTLC value
// that will be considered dust on the commitment.
dustlimit := lnwire.NewMSatFromSatoshis(
lc.channelState.LocalChanCfg.DustLimit,
)
// For an extra HTLC fee to be paid on our commitment, the HTLC must be
// large enough to make a non-dust HTLC timeout transaction.
htlcFee := lnwire.NewMSatFromSatoshis(
HtlcTimeoutFee(lc.channelState.ChanType, feePerKw),
)
// If we are looking at the remote commitment, we must use the remote
// dust limit and the fee for adding an HTLC success transaction.
if whoseCommitChain.IsRemote() {
dustlimit = lnwire.NewMSatFromSatoshis(
lc.channelState.RemoteChanCfg.DustLimit,
)
htlcFee = lnwire.NewMSatFromSatoshis(
HtlcSuccessFee(lc.channelState.ChanType, feePerKw),
)
}
// The HTLC output will be manifested on the commitment if it
// is non-dust after paying the HTLC fee.
nonDustHtlcAmt := dustlimit + htlcFee
// commitFeeWithHtlc is the fee our peer has to pay in case we add
// another htlc to the commitment.
commitFeeWithHtlc := commitFee + additionalHtlcFee
// If they cannot pay the fee were we to add another non-dust HTLC, we'll
// report our available balance just below the non-dust amount, to
// avoid attempting HTLCs larger than this size.
if theirBalance < commitFeeWithHtlc && ourBalance >= nonDustHtlcAmt {
// see https://github.com/lightning/bolts/issues/728
ourReportedBalance := nonDustHtlcAmt - 1
lc.log.Infof("Reducing local (reported) balance "+
"(from %v to %v): remote side does not have enough "+
"funds (%v < %v) to pay for non-dust HTLC in case of "+
"unilateral close.", ourBalance, ourReportedBalance,
theirBalance, commitFeeWithHtlc)
ourBalance = ourReportedBalance
}
return ourBalance, commitWeight
}
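// To make the dust cap above concrete, a worked example with purely
// illustrative numbers: with a dust limit of 354 sats and an HTLC timeout fee
// of 1,657 sats, nonDustHtlcAmt = (354 + 1,657) * 1,000 = 2,011,000 msat. If
// the remote cannot afford commitFeeWithHtlc, the reported balance is capped
// at 2,010,999 msat, one msat below the smallest non-dust HTLC.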
// StateSnapshot returns a snapshot of the current fully committed state within
// the channel.
func (lc *LightningChannel) StateSnapshot() *channeldb.ChannelSnapshot {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.Snapshot()
}
// validateFeeRate ensures that if the passed fee is applied to the channel,
// and a new commitment is created (which evaluates this fee), then the
// initiator of the channel does not dip below their reserve.
func (lc *LightningChannel) validateFeeRate(feePerKw chainfee.SatPerKWeight) error {
// We'll ensure that we can accommodate this new fee change, yet still
// be above our reserve balance. Otherwise, we'll reject the fee
// update.
// We do not enforce the FeeBuffer here because it was exactly
// introduced to use this buffer for potential fee rate increases.
availableBalance, txWeight := lc.availableBalance(AdditionalHtlc)
oldFee := lnwire.NewMSatFromSatoshis(
lc.commitChains.Local.tip().feePerKw.FeeForWeight(txWeight),
)
// Our base balance is the total amount of satoshis we can commit
// towards fees before factoring in the channel reserve.
baseBalance := availableBalance + oldFee
// Using the weight of the commitment transaction if we were to create
// a commitment now, we'll compute our remaining balance if we apply
// this new fee update.
newFee := lnwire.NewMSatFromSatoshis(
feePerKw.FeeForWeight(txWeight),
)
// If the total fee exceeds our available balance (taking into account
// the fee from the last state), then we'll reject this update as it
// would mean we need to trim our entire output.
if newFee > baseBalance {
return fmt.Errorf("cannot apply fee_update=%v sat/kw, new fee "+
"of %v is greater than balance of %v", int64(feePerKw),
newFee, baseBalance)
}
// TODO(halseth): should fail if fee update is unreasonable,
// as specified in BOLT#2.
// * COMMENT(roasbeef): can cross-check with our ideal fee rate
return nil
}
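// A worked example with purely illustrative numbers: at a commitment weight
// of 724 WU, a proposed rate of 5,000 sat/kw implies
// newFee = 5,000 * 724 / 1,000 = 3,620 sats. The update is accepted only if
// baseBalance (the available balance plus the fee already committed at the
// old rate) covers those 3,620 sats.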
// UpdateFee initiates a fee update for this channel. Must only be called by
// the channel initiator, and must be called before sending update_fee to
// the remote.
func (lc *LightningChannel) UpdateFee(feePerKw chainfee.SatPerKWeight) error {
lc.Lock()
defer lc.Unlock()
// Only the initiator can send a fee update, so trying to send one as the
// non-initiator will fail.
if !lc.channelState.IsInitiator {
return fmt.Errorf("local fee update as non-initiator")
}
// Ensure that the passed fee rate meets our current requirements.
if err := lc.validateFeeRate(feePerKw); err != nil {
return err
}
pd := &paymentDescriptor{
ChanID: lc.ChannelID(),
LogIndex: lc.updateLogs.Local.logIndex,
Amount: lnwire.NewMSatFromSatoshis(btcutil.Amount(feePerKw)),
EntryType: FeeUpdate,
}
lc.updateLogs.Local.appendUpdate(pd)
return nil
}
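// exampleInitiatorFeeUpdate is an illustrative sketch (hypothetical, not part
// of the package API) of the ordering contract documented above: the
// initiator must apply UpdateFee locally before sending update_fee to the
// remote peer (the wire send itself is elided here).
func exampleInitiatorFeeUpdate(lc *LightningChannel,
	feePerKw chainfee.SatPerKWeight) error {

	// Validate the rate and append the fee update to our local log first.
	if err := lc.UpdateFee(feePerKw); err != nil {
		return err
	}

	// Only now may the corresponding lnwire.UpdateFee message be sent.
	return nil
}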
// CommitFeeTotalAt applies a proposed feerate to the channel and returns the
// commitment fee with this new feerate. It does not modify the underlying
// LightningChannel.
func (lc *LightningChannel) CommitFeeTotalAt(
feePerKw chainfee.SatPerKWeight) (btcutil.Amount, btcutil.Amount,
error) {
lc.RLock()
defer lc.RUnlock()
dryRunFee := fn.Some[chainfee.SatPerKWeight](feePerKw)
// We want to grab every update in both update logs to calculate the
// commitment fees in the worst-case with this fee-rate.
localIdx := lc.updateLogs.Local.logIndex
remoteIdx := lc.updateLogs.Remote.logIndex
localHtlcView := lc.fetchHTLCView(remoteIdx, localIdx)
var localCommitFee, remoteCommitFee btcutil.Amount
// Compute the local commitment's weight.
_, _, localWeight, _, err := lc.computeView(
localHtlcView, lntypes.Local, false, dryRunFee,
)
if err != nil {
return 0, 0, err
}
localCommitFee = feePerKw.FeeForWeight(localWeight)
// Create another view in case for some reason the prior one was
// mutated.
remoteHtlcView := lc.fetchHTLCView(remoteIdx, localIdx)
// Compute the remote commitment's weight.
_, _, remoteWeight, _, err := lc.computeView(
remoteHtlcView, lntypes.Remote, false, dryRunFee,
)
if err != nil {
return 0, 0, err
}
remoteCommitFee = feePerKw.FeeForWeight(remoteWeight)
return localCommitFee, remoteCommitFee, err
}
// ReceiveUpdateFee handles an updated fee sent from remote. This method will
// return an error if called as channel initiator.
func (lc *LightningChannel) ReceiveUpdateFee(feePerKw chainfee.SatPerKWeight) error {
lc.Lock()
defer lc.Unlock()
// Only the initiator can send fee updates, so we must fail if we receive
// a fee update as the initiator.
if lc.channelState.IsInitiator {
return fmt.Errorf("received fee update as initiator")
}
// TODO(roasbeef): or just modify to use the other balance?
pd := &paymentDescriptor{
ChanID: lc.ChannelID(),
LogIndex: lc.updateLogs.Remote.logIndex,
Amount: lnwire.NewMSatFromSatoshis(btcutil.Amount(feePerKw)),
EntryType: FeeUpdate,
}
lc.updateLogs.Remote.appendUpdate(pd)
return nil
}
// generateRevocation generates the revocation message for a given height.
func (lc *LightningChannel) generateRevocation(height uint64) (*lnwire.RevokeAndAck,
error) {
// Now that we've accepted a new state transition, we send the remote
// party the revocation for our current commitment state.
revocationMsg := &lnwire.RevokeAndAck{}
commitSecret, err := lc.channelState.RevocationProducer.AtIndex(height)
if err != nil {
return nil, err
}
copy(revocationMsg.Revocation[:], commitSecret[:])
// Along with this revocation, we'll also send the _next_ commitment
// point that the remote party should use to create our next commitment
// transaction. We use a +2 here as we already gave them a look ahead
// of size one after the ChannelReady message was sent:
//
// 0: current revocation, 1: their "next" revocation, 2: this revocation
//
// We're revoking the current revocation. Once they receive this
// message they'll set the "current" revocation for us to their stored
// "next" revocation, and this revocation will become their new "next"
// revocation.
//
// Put simply, the window slides to the left by one.
revHeight := height + 2
nextCommitSecret, err := lc.channelState.RevocationProducer.AtIndex(
revHeight,
)
if err != nil {
return nil, err
}
revocationMsg.NextRevocationKey = input.ComputeCommitmentPoint(nextCommitSecret[:])
revocationMsg.ChanID = lnwire.NewChanIDFromOutPoint(
lc.channelState.FundingOutpoint,
)
// If this is a taproot channel, then we also need to generate the
// verification nonce for this target state.
if lc.channelState.ChanType.IsTaproot() {
nextVerificationNonce, err := channeldb.NewMusigVerificationNonce( //nolint:ll
lc.channelState.LocalChanCfg.MultiSigKey.PubKey,
revHeight, lc.taprootNonceProducer,
)
if err != nil {
return nil, err
}
revocationMsg.LocalNonce = lnwire.SomeMusig2Nonce(
nextVerificationNonce.PubNonce,
)
}
return revocationMsg, nil
}
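// As a concrete reading of the +2 window above: when revoking height 5, we
// reveal the commitment secret at index 5, while NextRevocationKey is the
// commitment point derived from the secret at index 7 (5+2), since the point
// for index 6 was already handed out as the remote party's "next" revocation.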
// closeTxOpts houses the set of options that modify how the cooperative close
// tx is to be constructed.
type closeTxOpts struct {
// enableRBF indicates whether the cooperative close tx should signal
// RBF or not.
enableRBF bool
// extraCloseOutputs is a set of additional outputs that should be
// added the co-op close transaction.
extraCloseOutputs []CloseOutput
// customSort is a custom function that can be used to sort the
// transaction outputs. If this isn't set, then the default BIP-69
// sorting is used.
customSort CloseSortFunc
// customSequence is an optional custom sequence to set on the co-op
// close transaction. This gives slightly more control compared to the
// enableRBF option.
customSequence fn.Option[uint32]
customLockTime fn.Option[uint32]
}
// defaultCloseTxOpts returns a closeTxOpts struct with default values.
func defaultCloseTxOpts() closeTxOpts {
return closeTxOpts{
enableRBF: false,
}
}
// CloseTxOpt is a functional option that allows us to modify how the closing
// transaction is created.
type CloseTxOpt func(*closeTxOpts)
// WithRBFCloseTx signals that the cooperative close tx should signal RBF.
func WithRBFCloseTx() CloseTxOpt {
return func(o *closeTxOpts) {
o.enableRBF = true
}
}
// WithExtraTxCloseOutputs can be used to add extra outputs to the cooperative
// transaction.
func WithExtraTxCloseOutputs(extraOutputs []CloseOutput) CloseTxOpt {
return func(o *closeTxOpts) {
o.extraCloseOutputs = extraOutputs
}
}
// WithCustomTxSort can be used to modify the way the close transaction is
// sorted.
func WithCustomTxSort(sorter CloseSortFunc) CloseTxOpt {
return func(opts *closeTxOpts) {
opts.customSort = sorter
}
}
// WithCustomTxInSequence allows a caller to set a custom sequence on the sole
// input of the co-op close tx.
func WithCustomTxInSequence(sequence uint32) CloseTxOpt {
return func(o *closeTxOpts) {
o.customSequence = fn.Some(sequence)
}
}
// WithCustomTxLockTime can be used to set a custom lock time on the co-op
// close transaction.
func WithCustomTxLockTime(lockTime uint32) CloseTxOpt {
return func(o *closeTxOpts) {
o.customLockTime = fn.Some(lockTime)
}
}
// CreateCooperativeCloseTx creates a transaction that, once signed by both
// parties and broadcast, cooperatively closes an active channel. Whether each
// party receives an output is determined by its balance and dust limit.
// Currently it is expected that the initiator pays the transaction fees for
// the closing transaction in full.
func CreateCooperativeCloseTx(fundingTxIn wire.TxIn,
localDust, remoteDust, ourBalance, theirBalance btcutil.Amount,
ourDeliveryScript, theirDeliveryScript []byte,
closeOpts ...CloseTxOpt) (*wire.MsgTx, error) {
opts := defaultCloseTxOpts()
for _, optFunc := range closeOpts {
optFunc(&opts)
}
// If RBF is signalled, then we'll modify the sequence to permit
// replacement.
if opts.enableRBF {
fundingTxIn.Sequence = mempool.MaxRBFSequence
}
// Otherwise, a custom sequence might be specified.
opts.customSequence.WhenSome(func(sequence uint32) {
fundingTxIn.Sequence = sequence
})
// Construct the transaction to perform a cooperative closure of the
// channel. In the event that one side doesn't have any settled funds
// within the channel then a refund output for that particular side can
// be omitted.
closeTx := wire.NewMsgTx(2)
closeTx.AddTxIn(&fundingTxIn)
opts.customLockTime.WhenSome(func(lockTime uint32) {
closeTx.LockTime = lockTime
})
// Create both cooperative closure outputs, properly respecting the dust
// limits of both parties.
var localOutputIdx fn.Option[int]
haveLocalOutput := ourBalance >= localDust
if haveLocalOutput {
// If our script is an OP_RETURN, then we set our balance to
// zero.
if opts.customSequence.IsSome() &&
input.ScriptIsOpReturn(ourDeliveryScript) {
ourBalance = 0
}
closeTx.AddTxOut(&wire.TxOut{
PkScript: ourDeliveryScript,
Value: int64(ourBalance),
})
localOutputIdx = fn.Some(len(closeTx.TxOut) - 1)
}
var remoteOutputIdx fn.Option[int]
haveRemoteOutput := theirBalance >= remoteDust
if haveRemoteOutput {
// If a party's script is an OP_RETURN, then we set their
// balance to zero.
if opts.customSequence.IsSome() &&
input.ScriptIsOpReturn(theirDeliveryScript) {
theirBalance = 0
}
closeTx.AddTxOut(&wire.TxOut{
PkScript: theirDeliveryScript,
Value: int64(theirBalance),
})
remoteOutputIdx = fn.Some(len(closeTx.TxOut) - 1)
}
// If we have extra outputs to add to the co-op close transaction, then
// we'll examine them now. We'll deduct the output's value from the
// owning party. In the case that a party can't pay for the output, then
// their normal output will be omitted.
for _, extraTxOut := range opts.extraCloseOutputs {
switch {
// For additional local outputs, add the output, then deduct
// the balance from our local balance.
case extraTxOut.IsLocal:
// The extraCloseOutputs in the options just indicate if
// an extra output should be added in general. But we
// only add one if we actually _need_ one, based on the
// balance. If we don't have enough local balance to
// cover the extra output, then localOutputIdx is None.
localOutputIdx.WhenSome(func(idx int) {
// The output that currently represents the
// local balance, which means:
// txOut.Value == ourBalance.
txOut := closeTx.TxOut[idx]
// The extra output (if one exists) is the more
// important one, as in custom channels it might
// carry some additional values. The normal
// output is just an address that sends the
// local balance back to our wallet. The extra
// one also goes to our wallet, but might also
// carry other values, so it has higher
// priority. Do we have enough balance to have
// both the extra output with the given value
// (which is subtracted from our balance) and
// still an above-dust normal output? If not, we
// skip the extra output and just overwrite the
// existing output script with the one from the
// extra output.
amtAfterOutput := btcutil.Amount(
txOut.Value - extraTxOut.Value,
)
if amtAfterOutput <= localDust {
txOut.PkScript = extraTxOut.PkScript
return
}
txOut.Value -= extraTxOut.Value
closeTx.AddTxOut(&extraTxOut.TxOut)
})
// For extra remote outputs, we'll do the opposite.
case !extraTxOut.IsLocal:
// The extraCloseOutputs in the options just indicate if
// an extra output should be added in general. But we
// only add one if we actually _need_ one, based on the
// balance. If we don't have enough remote balance to
// cover the extra output, then remoteOutputIdx is None.
remoteOutputIdx.WhenSome(func(idx int) {
// The output that currently represents the
// remote balance, which means:
// txOut.Value == theirBalance.
txOut := closeTx.TxOut[idx]
// The extra output (if one exists) is the more
// important one, as in custom channels it might
// carry some additional values. The normal
// output is just an address that sends the
// remote balance back to their wallet. The
// extra one also goes to their wallet, but
// might also carry other values, so it has
// higher priority. Do they have enough balance
// to have both the extra output with the given
// value (which is subtracted from their
// balance) and still an above-dust normal
// output? If not, we skip the extra output and
// just overwrite the existing output script
// with the one from the extra output.
amtAfterOutput := btcutil.Amount(
txOut.Value - extraTxOut.Value,
)
if amtAfterOutput <= remoteDust {
txOut.PkScript = extraTxOut.PkScript
return
}
txOut.Value -= extraTxOut.Value
closeTx.AddTxOut(&extraTxOut.TxOut)
})
}
}
if opts.customSort != nil {
if err := opts.customSort(closeTx); err != nil {
return nil, err
}
} else {
txsort.InPlaceSort(closeTx)
}
return closeTx, nil
}
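// exampleBuildCloseTx is an illustrative sketch (hypothetical, not part of
// the package API) of driving CreateCooperativeCloseTx directly with one of
// the functional options above: an RBF-signalling close transaction for the
// given balances, dust limits, and delivery scripts.
func exampleBuildCloseTx(fundingIn wire.TxIn, localDust, remoteDust,
	ourBalance, theirBalance btcutil.Amount,
	ourScript, theirScript []byte) (*wire.MsgTx, error) {

	// Any party whose balance falls below its dust limit simply has its
	// output omitted from the resulting transaction.
	return CreateCooperativeCloseTx(
		fundingIn, localDust, remoteDust, ourBalance, theirBalance,
		ourScript, theirScript, WithRBFCloseTx(),
	)
}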
// LocalBalanceDust returns true if, when creating a co-op close transaction,
// the balance of the local party will be dust after accounting for any anchor
// outputs.
func (lc *LightningChannel) LocalBalanceDust() (bool, btcutil.Amount) {
lc.RLock()
defer lc.RUnlock()
chanState := lc.channelState
localBalance := chanState.LocalCommitment.LocalBalance.ToSatoshis()
// If this is an anchor channel, and we're the initiator, then we'll
// regain the sats allocated to the anchor outputs with the co-op
// close transaction.
if chanState.ChanType.HasAnchors() && chanState.IsInitiator {
localBalance += 2 * AnchorSize
}
localDust := chanState.LocalChanCfg.DustLimit
return localBalance <= localDust, localDust
}
// RemoteBalanceDust returns true if, when creating a co-op close transaction,
// the balance of the remote party will be dust after accounting for any anchor
// outputs.
func (lc *LightningChannel) RemoteBalanceDust() (bool, btcutil.Amount) {
lc.RLock()
defer lc.RUnlock()
chanState := lc.channelState
remoteBalance := chanState.RemoteCommitment.RemoteBalance.ToSatoshis()
// If this is an anchor channel, and they're the initiator, then we'll
// regain the sats allocated to the anchor outputs with the co-op
// close transaction.
if chanState.ChanType.HasAnchors() && !chanState.IsInitiator {
remoteBalance += 2 * AnchorSize
}
remoteDust := chanState.RemoteChanCfg.DustLimit
return remoteBalance <= remoteDust, remoteDust
}
// CommitBalances returns the local and remote balances in the current
// commitment state.
func (lc *LightningChannel) CommitBalances() (btcutil.Amount, btcutil.Amount) {
lc.RLock()
defer lc.RUnlock()
chanState := lc.channelState
localCommit := lc.channelState.LocalCommitment
localBalance := localCommit.LocalBalance.ToSatoshis()
remoteBalance := localCommit.RemoteBalance.ToSatoshis()
if chanState.ChanType.HasAnchors() {
if chanState.IsInitiator {
localBalance += 2 * AnchorSize
} else {
remoteBalance += 2 * AnchorSize
}
}
return localBalance, remoteBalance
}
// CommitFee returns the commitment fee for the current commitment state.
func (lc *LightningChannel) CommitFee() btcutil.Amount {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.LocalCommitment.CommitFee
}
// CalcFee returns the commitment fee to use for the given fee rate
// (fee-per-kw).
func (lc *LightningChannel) CalcFee(feeRate chainfee.SatPerKWeight) btcutil.Amount {
return feeRate.FeeForWeight(CommitWeight(lc.channelState.ChanType))
}
// MaxFeeRate returns the maximum fee rate given an allocation of the channel
// initiator's spendable balance along with the local reserve amount. This can
// be useful to determine when we should stop proposing fee updates that exceed
// our maximum allocation.
// Moreover it returns the share of the total balance in the range of [0,1]
// which can be allocated to fees. When our desired fee allocation would lead to
// a maximum fee rate below the current commitment fee rate we floor the maximum
// at the current fee rate which leads to different fee allocations than
// initially requested via `maxAllocation`.
//
// NOTE: This should only be used for channels in which the local commitment is
// the initiator.
func (lc *LightningChannel) MaxFeeRate(
maxAllocation float64) (chainfee.SatPerKWeight, float64) {
lc.RLock()
defer lc.RUnlock()
// The maximum fee depends on the available balance that can be
// committed towards fees. It takes into account our local reserve
// balance. We do not account for a FeeBuffer here because that is
// exactly why it was introduced to react for sharp fee changes.
availableBalance, weight := lc.availableBalance(AdditionalHtlc)
currentFee := lc.commitChains.Local.tip().feePerKw.FeeForWeight(weight)
// baseBalance is the maximum amount available for us to spend on fees.
baseBalance := availableBalance.ToSatoshis() + currentFee
// In case our local channel balance is drained, we make sure we do not
// decrease the fee rate below the current fee rate. This could lead to
// a scenario where we lower the commitment fee rate as low as the fee
// floor although current fee rates are way higher. The maximum fee
// we allow should not be smaller than the current fee. The decrease
// in fee rate should happen when the mempool reports lower fee levels
// rather than us decreasing in local balance. The max fee rate is
// always floored by the current fee rate of the channel.
idealMaxFee := float64(baseBalance) * maxAllocation
maxFee := max(float64(currentFee), idealMaxFee)
maxFeeAllocation := maxFee / float64(baseBalance)
maxFeeRate := chainfee.SatPerKWeight(maxFee / (float64(weight) / 1000))
return maxFeeRate, maxFeeAllocation
}
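// A worked example with purely illustrative numbers: with a baseBalance of
// 1,000,000 sats, a maxAllocation of 0.5, and a current fee of 10,000 sats,
// idealMaxFee = 500,000 sats, which exceeds the current fee, so
// maxFee = 500,000 sats. At a commitment weight of 724 WU, the resulting
// maxFeeRate = 500,000 / 0.724 ≈ 690,607 sat/kw, and the final allocation
// remains 0.5.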
// IdealCommitFeeRate uses the current network fee, the minimum relay fee,
// maximum fee allocation and anchor channel commitment fee rate to determine
// the ideal fee to be used for the commitments of the channel.
func (lc *LightningChannel) IdealCommitFeeRate(netFeeRate, minRelayFeeRate,
maxAnchorCommitFeeRate chainfee.SatPerKWeight,
maxFeeAlloc float64) chainfee.SatPerKWeight {
// Get the maximum fee rate that we can use given our max fee allocation
// and given the local reserve balance that we must preserve.
maxFeeRate, _ := lc.MaxFeeRate(maxFeeAlloc)
var commitFeeRate chainfee.SatPerKWeight
// If the channel has anchor outputs then cap the fee rate at the
// max anchor fee rate if that maximum is less than our max fee rate.
// Otherwise, cap the fee rate at the max fee rate.
switch lc.channelState.ChanType.HasAnchors() &&
maxFeeRate > maxAnchorCommitFeeRate {
case true:
commitFeeRate = min(netFeeRate, maxAnchorCommitFeeRate)
case false:
commitFeeRate = min(netFeeRate, maxFeeRate)
}
if commitFeeRate >= minRelayFeeRate {
return commitFeeRate
}
// The commitment fee rate is below the minimum relay fee rate.
// If the min relay fee rate is still below the maximum fee, then use
// the minimum relay fee rate.
if minRelayFeeRate <= maxFeeRate {
return minRelayFeeRate
}
// The minimum relay fee rate is more than the ideal maximum fee rate.
// Check if it is smaller than the absolute maximum fee rate we can
// use. If it is, then we use the minimum relay fee rate and we log a
// warning to indicate that the max channel fee allocation option was
// ignored.
absoluteMaxFee, _ := lc.MaxFeeRate(1)
if minRelayFeeRate <= absoluteMaxFee {
lc.log.Warn("Ignoring max channel fee allocation to " +
"ensure that the commitment fee is above the " +
"minimum relay fee.")
return minRelayFeeRate
}
// The absolute maximum fee rate we can pay is below the minimum
// relay fee rate. The commitment tx will not be able to propagate.
// To give the transaction the best chance, we use the absolute
// maximum fee we have available and we log an error.
lc.log.Errorf("The commitment fee rate of %s is below the current "+
"minimum relay fee rate of %s. The max fee rate of %s will be "+
"used.", commitFeeRate, minRelayFeeRate, absoluteMaxFee)
return absoluteMaxFee
}
// RemoteNextRevocation returns the channelState's RemoteNextRevocation. For
// musig2 channels, until a nonce pair is processed by the remote party, a nil
// public key is returned.
//
// TODO(roasbeef): revisit, maybe just make a more general method instead?
func (lc *LightningChannel) RemoteNextRevocation() *btcec.PublicKey {
lc.RLock()
defer lc.RUnlock()
if !lc.channelState.ChanType.IsTaproot() {
return lc.channelState.RemoteNextRevocation
}
if lc.musigSessions == nil {
return nil
}
return lc.channelState.RemoteNextRevocation
}
// IsInitiator returns true if we were the ones that initiated the funding
// workflow which led to the creation of this channel. Otherwise, it returns
// false.
func (lc *LightningChannel) IsInitiator() bool {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.IsInitiator
}
// CommitFeeRate returns the current fee rate of the commitment transaction in
// units of sat-per-kw.
func (lc *LightningChannel) CommitFeeRate() chainfee.SatPerKWeight {
lc.RLock()
defer lc.RUnlock()
return chainfee.SatPerKWeight(lc.channelState.LocalCommitment.FeePerKw)
}
// WorstCaseFeeRate returns the higher feerate from either the local commitment
// or the remote commitment.
func (lc *LightningChannel) WorstCaseFeeRate() chainfee.SatPerKWeight {
lc.RLock()
defer lc.RUnlock()
localFeeRate := lc.channelState.LocalCommitment.FeePerKw
remoteFeeRate := lc.channelState.RemoteCommitment.FeePerKw
if localFeeRate > remoteFeeRate {
return chainfee.SatPerKWeight(localFeeRate)
}
return chainfee.SatPerKWeight(remoteFeeRate)
}
// IsPending returns true if the channel's funding transaction has not yet
// been fully confirmed, and false otherwise.
func (lc *LightningChannel) IsPending() bool {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.IsPending
}
// State provides access to the channel's internal state.
func (lc *LightningChannel) State() *channeldb.OpenChannel {
return lc.channelState
}
// MarkBorked marks the channel as having reached an irreconcilable state,
// such as a channel breach or state desynchronization. Borked channels
// should never be added to the switch.
func (lc *LightningChannel) MarkBorked() error {
lc.Lock()
defer lc.Unlock()
return lc.channelState.MarkBorked()
}
// MarkCommitmentBroadcasted marks the channel to indicate that a commitment
// transaction has been broadcast, either our own or the remote's, and that
// we should watch the chain for it to confirm before taking any further
// action. It takes a ChannelParty that indicates which party initiated the
// close.
func (lc *LightningChannel) MarkCommitmentBroadcasted(tx *wire.MsgTx,
closer lntypes.ChannelParty) error {
lc.Lock()
defer lc.Unlock()
return lc.channelState.MarkCommitmentBroadcasted(tx, closer)
}
// MarkCoopBroadcasted marks the channel to indicate that a cooperative close
// transaction has been broadcast, and that we should watch the chain for it
// to confirm before taking any further action. It takes a ChannelParty that
// indicates which party initiated the cooperative close.
func (lc *LightningChannel) MarkCoopBroadcasted(tx *wire.MsgTx,
closer lntypes.ChannelParty) error {
lc.Lock()
defer lc.Unlock()
return lc.channelState.MarkCoopBroadcasted(tx, closer)
}
// MarkShutdownSent persists the given ShutdownInfo. The existence of the
// ShutdownInfo represents the fact that the Shutdown message has been sent by
// us and so should be re-sent on re-establish.
func (lc *LightningChannel) MarkShutdownSent(
info *channeldb.ShutdownInfo) error {
lc.Lock()
defer lc.Unlock()
return lc.channelState.MarkShutdownSent(info)
}
// MarkDataLoss sets the channel status to LocalDataLoss and stores the
// passed commitPoint for use in retrieving funds in case the remote force
// closes the channel.
func (lc *LightningChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error {
lc.Lock()
defer lc.Unlock()
return lc.channelState.MarkDataLoss(commitPoint)
}
// ActiveHtlcs returns a slice of HTLCs which are currently active on *both*
// commitment transactions.
func (lc *LightningChannel) ActiveHtlcs() []channeldb.HTLC {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.ActiveHtlcs()
}
// LocalChanReserve returns our local ChanReserve requirement, i.e. the
// reserve amount that our settled balance must always remain above.
func (lc *LightningChannel) LocalChanReserve() btcutil.Amount {
return lc.channelState.LocalChanCfg.ChanReserve
}
// NextLocalHtlcIndex returns the next unallocated local htlc index. To ensure
// this always returns the next index that has not yet been allocated, this
// will first try to examine any pending commitments, before falling back to the
// last locked-in local commitment.
func (lc *LightningChannel) NextLocalHtlcIndex() (uint64, error) {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.NextLocalHtlcIndex()
}
// FwdMinHtlc returns the minimum HTLC value required by the remote node, i.e.
// the minimum value HTLC we can forward on this channel.
func (lc *LightningChannel) FwdMinHtlc() lnwire.MilliSatoshi {
return lc.channelState.LocalChanCfg.MinHTLC
}
// unsignedLocalUpdates retrieves the unsigned local updates that we should
// store upon receiving a revocation. This function is called from
// ReceiveRevocation. remoteMessageIndex is the height into the local update
// log that the remote commitment chain tip includes. localMessageIndex
// is the height into the local update log that the local commitment tail
// includes. Our local updates that are unsigned by the remote should
// have height greater than or equal to localMessageIndex (not on our commit),
// and height less than remoteMessageIndex (on the remote commit).
//
// NOTE: remoteMessageIndex is the height on the tip because this is called
// before the tail is advanced to the tip during ReceiveRevocation.
func (lc *LightningChannel) unsignedLocalUpdates(remoteMessageIndex,
localMessageIndex uint64) []channeldb.LogUpdate {
var localPeerUpdates []channeldb.LogUpdate
for e := lc.updateLogs.Local.Front(); e != nil; e = e.Next() {
pd := e.Value
// We don't save add updates as they are restored from the
// remote commitment in restoreStateLogs.
if pd.EntryType == Add {
continue
}
// This is a settle/fail that is on the remote commitment, but
// not on the local commitment. We expect this update to be
// covered in the next commitment signature that the remote
// sends.
if pd.LogIndex < remoteMessageIndex && pd.LogIndex >= localMessageIndex {
localPeerUpdates = append(
localPeerUpdates, pd.toLogUpdate(),
)
}
}
return localPeerUpdates
}
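// As a worked illustration of the window above (a sketch with made-up
// indexes): if the local commitment tail covers updates up to
// localMessageIndex=5 and the remote commitment tip covers updates up to
// remoteMessageIndex=8, then a settle/fail with LogIndex 5, 6, or 7 is on
// the remote commitment but not yet on ours, so it is returned here. Adds
// are always skipped, and updates with LogIndex >= 8 are not yet covered by
// a remote signature.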
// GenMusigNonces generates the verification nonce to start off a new musig2
// channel session.
func (lc *LightningChannel) GenMusigNonces() (*musig2.Nonces, error) {
lc.Lock()
defer lc.Unlock()
var err error
// We pass in the current height+1 as this'll be the set of
// verification nonces we'll send to the party to create our _next_
// state.
lc.pendingVerificationNonce, err = channeldb.NewMusigVerificationNonce(
lc.channelState.LocalChanCfg.MultiSigKey.PubKey,
lc.currentHeight+1, lc.taprootNonceProducer,
)
if err != nil {
return nil, err
}
return lc.pendingVerificationNonce, nil
}
// HasRemoteNonces returns true if the channel has a remote nonce pair.
func (lc *LightningChannel) HasRemoteNonces() bool {
return lc.musigSessions != nil
}
// InitRemoteMusigNonces processes the remote musig nonces sent by the remote
// party. This should be called upon connection re-establishment, after we've
// generated our own nonces. Once this method returns a nil error, then the
// channel can be used to sign commitment states.
func (lc *LightningChannel) InitRemoteMusigNonces(remoteNonce *musig2.Nonces,
) error {
lc.Lock()
defer lc.Unlock()
if lc.pendingVerificationNonce == nil {
return fmt.Errorf("pending verification nonce is not set")
}
// Now that we have the set of local and remote nonces, we can generate
// a new pair of musig sessions for our local commitment and the
// commitment of the remote party.
localNonce := lc.pendingVerificationNonce
localChanCfg := lc.channelState.LocalChanCfg
remoteChanCfg := lc.channelState.RemoteChanCfg
// TODO(roasbeef): propagate rename of signing and verification nonces
sessionCfg := &MusigSessionCfg{
LocalKey: localChanCfg.MultiSigKey,
RemoteKey: remoteChanCfg.MultiSigKey,
LocalNonce: *localNonce,
RemoteNonce: *remoteNonce,
Signer: lc.Signer,
InputTxOut: &lc.fundingOutput,
TapscriptTweak: lc.channelState.TapscriptRoot,
}
lc.musigSessions = NewMusigPairSession(
sessionCfg,
)
lc.pendingVerificationNonce = nil
lc.opts.localNonce = nil
lc.opts.remoteNonce = nil
return nil
}
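// Illustrative re-establish flow (a sketch; sendNonce and recvNonce are
// hypothetical transport helpers, not part of this package). We must
// generate our own verification nonce before processing the remote one:
//
//	localNonce, err := lc.GenMusigNonces()
//	if err != nil {
//		return err
//	}
//	sendNonce(localNonce.PubNonce)
//
//	remoteNonce := recvNonce()
//	if err := lc.InitRemoteMusigNonces(remoteNonce); err != nil {
//		return err
//	}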
// ChanType returns the channel type.
func (lc *LightningChannel) ChanType() channeldb.ChannelType {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.ChanType
}
// Initiator returns the ChannelParty that originally opened this channel.
func (lc *LightningChannel) Initiator() lntypes.ChannelParty {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.Initiator()
}
// FundingTxOut returns the funding output of the channel.
func (lc *LightningChannel) FundingTxOut() *wire.TxOut {
lc.RLock()
defer lc.RUnlock()
return &lc.fundingOutput
}
// DeriveHeightHint derives the block height for the channel opening.
func (lc *LightningChannel) DeriveHeightHint() uint32 {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.DeriveHeightHint()
}
// MultiSigKeys returns the set of multi-sig keys for a channel.
func (lc *LightningChannel) MultiSigKeys() (keychain.KeyDescriptor,
keychain.KeyDescriptor) {
lc.RLock()
defer lc.RUnlock()
return lc.channelState.LocalChanCfg.MultiSigKey,
lc.channelState.RemoteChanCfg.MultiSigKey
}
// LocalCommitmentBlob returns the custom blob of the local commitment.
func (lc *LightningChannel) LocalCommitmentBlob() fn.Option[tlv.Blob] {
lc.RLock()
defer lc.RUnlock()
chanState := lc.channelState
localBlob := chanState.LocalCommitment.CustomBlob
return fn.MapOption(func(b tlv.Blob) tlv.Blob {
newBlob := make([]byte, len(b))
copy(newBlob, b)
return newBlob
})(localBlob)
}
// FundingBlob returns the funding custom blob.
func (lc *LightningChannel) FundingBlob() fn.Option[tlv.Blob] {
lc.RLock()
defer lc.RUnlock()
return fn.MapOption(func(b tlv.Blob) tlv.Blob {
newBlob := make([]byte, len(b))
copy(newBlob, b)
return newBlob
})(lc.channelState.CustomBlob)
}
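// Both accessors above hand back a defensive copy: fn.MapOption runs the
// copying closure only when a blob is actually present, so a caller either
// receives fn.None or a fresh slice it may mutate freely. A minimal sketch
// of the same pattern, assuming some optional blob value optBlob:
//
//	copied := fn.MapOption(func(b tlv.Blob) tlv.Blob {
//		newBlob := make([]byte, len(b))
//		copy(newBlob, b)
//		return newBlob
//	})(optBlob)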
// ZeroConfRealScid returns an optional real scid for the channel. If this
// returns None, then this isn't a zero conf channel. Otherwise, the real scid
// value will be returned.
//
//nolint:ll
func (lc *LightningChannel) ZeroConfRealScid() fn.Option[lnwire.ShortChannelID] {
if lc.channelState.IsZeroConf() {
return fn.Some(lc.channelState.ZeroConfRealScid())
}
return fn.None[lnwire.ShortChannelID]()
}
package chanvalidate
import (
"bytes"
"fmt"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrInvalidOutPoint is returned when the ChanLocator is unable to
// find the target outpoint.
ErrInvalidOutPoint = fmt.Errorf("output meant to create channel cannot " +
"be found")
// ErrWrongPkScript is returned when the alleged funding transaction is
// found to have an incorrect pkScript.
ErrWrongPkScript = fmt.Errorf("wrong pk script")
// ErrInvalidSize is returned when the alleged funding transaction
// output has the wrong size (channel capacity).
ErrInvalidSize = fmt.Errorf("channel has wrong size")
)
// ErrScriptValidateError is returned when Script VM validation fails for an
// alleged channel output.
type ErrScriptValidateError struct {
err error
}
// Error returns a human readable string describing the error.
func (e *ErrScriptValidateError) Error() string {
return fmt.Sprintf("script validation failed: %v", e.err)
}
// Unwrap returns the underlying wrapped VM execution failure error.
func (e *ErrScriptValidateError) Unwrap() error {
return e.err
}
// ChanLocator abstracts away obtaining the output that created the channel, as
// well as validating its existence given the funding transaction. We need
// this as there are several ways (outpoint, short chan ID) to identify the
// output of a channel given the funding transaction.
type ChanLocator interface {
// Locate attempts to locate the funding output within the funding
// transaction. It also returns the final out point of the channel
// which uniquely identifies the output which creates the channel. If
// the target output cannot be found, or cannot exist on the funding
// transaction, then an error is to be returned.
Locate(*wire.MsgTx) (*wire.TxOut, *wire.OutPoint, error)
}
// OutPointChanLocator is an implementation of the ChanLocator that can be used
// when one already knows the expected chan point.
type OutPointChanLocator struct {
// ChanPoint is the expected chan point.
ChanPoint wire.OutPoint
}
// Locate attempts to locate the funding output within the passed funding
// transaction.
//
// NOTE: Part of the ChanLocator interface.
func (o *OutPointChanLocator) Locate(fundingTx *wire.MsgTx) (
*wire.TxOut, *wire.OutPoint, error) {
// If the expected index is greater than the number of outputs in the
// transaction, then we'll reject this channel as it's invalid.
if int(o.ChanPoint.Index) >= len(fundingTx.TxOut) {
return nil, nil, ErrInvalidOutPoint
}
// As an extra sanity check, we'll also ensure the txid hash matches.
fundingHash := fundingTx.TxHash()
if !bytes.Equal(fundingHash[:], o.ChanPoint.Hash[:]) {
return nil, nil, ErrInvalidOutPoint
}
return fundingTx.TxOut[o.ChanPoint.Index], &o.ChanPoint, nil
}
// ShortChanIDChanLocator is an implementation of the ChanLocator that can be
// used when one only knows the short channel ID of a channel. This should be
// used in contexts when one is verifying a 3rd party channel.
type ShortChanIDChanLocator struct {
// ID is the short channel ID of the target channel.
ID lnwire.ShortChannelID
}
// Locate attempts to locate the funding output within the passed funding
// transaction.
//
// NOTE: Part of the ChanLocator interface.
func (s *ShortChanIDChanLocator) Locate(fundingTx *wire.MsgTx) (
*wire.TxOut, *wire.OutPoint, error) {
// If the expected index is greater than the number of outputs in the
// transaction, then we'll reject this channel as it's invalid.
outputIndex := s.ID.TxPosition
if int(outputIndex) >= len(fundingTx.TxOut) {
return nil, nil, ErrInvalidOutPoint
}
chanPoint := wire.OutPoint{
Hash: fundingTx.TxHash(),
Index: uint32(outputIndex),
}
return fundingTx.TxOut[outputIndex], &chanPoint, nil
}
// CommitmentContext is optional validation context that can be passed into
// the main Validate for a self-owned channel. The information in this context
// allows us to fully verify our initial commitment spend based on the
// on-chain state of the funding output.
type CommitmentContext struct {
// Value is the known size of the channel.
Value btcutil.Amount
// FullySignedCommitTx is the fully signed commitment transaction. This
// should include a valid witness.
FullySignedCommitTx *wire.MsgTx
}
// Context is the main validation context. For a given channel, all fields but
// the optional CommitCtx should be populated based on existing
// known-to-be-valid parameters.
type Context struct {
// Locator is a concrete implementation of the ChanLocator interface.
Locator ChanLocator
// MultiSigPkScript is the fully serialized witness script of the
// multi-sig output. This is the final witness program that should be
// found in the funding output.
MultiSigPkScript []byte
// FundingTx is the channel funding transaction as found confirmed in the
// chain.
FundingTx *wire.MsgTx
// CommitCtx is an optional additional set of validation context
// required to validate a self-owned channel. If present, then a full
// Script VM validation will be performed.
CommitCtx *CommitmentContext
}
// Validate checks, given the specified context, that the alleged channel is
// well formed, and spendable (if the optional CommitCtx is specified). If
// this method returns an error, then the alleged channel is invalid and
// should be abandoned immediately.
func Validate(ctx *Context) (*wire.OutPoint, error) {
// First, we'll attempt to locate the target outpoint in the funding
// transaction. If this returns an error, then we know that the
// outpoint doesn't actually exist, so we'll exit early.
fundingOutput, chanPoint, err := ctx.Locator.Locate(
ctx.FundingTx,
)
if err != nil {
return nil, err
}
// The scripts should match up exactly, otherwise the channel is
// invalid.
fundingScript := fundingOutput.PkScript
if !bytes.Equal(ctx.MultiSigPkScript, fundingScript) {
return nil, ErrWrongPkScript
}
// If there's no commitment context, then we're done here as this is a
// 3rd party channel.
if ctx.CommitCtx == nil {
return chanPoint, nil
}
// Now that we know this is our channel, we'll verify the amount of the
// created output against our expected size of the channel.
fundingValue := fundingOutput.Value
if btcutil.Amount(fundingValue) != ctx.CommitCtx.Value {
return nil, ErrInvalidSize
}
// If we reach this point, then all other checks have succeeded, so
// we'll now attempt a full Script VM execution to ensure that we're
// able to close the channel using this initial state.
prevFetcher := txscript.NewCannedPrevOutputFetcher(
ctx.MultiSigPkScript, fundingValue,
)
commitTx := ctx.CommitCtx.FullySignedCommitTx
hashCache := txscript.NewTxSigHashes(commitTx, prevFetcher)
vm, err := txscript.NewEngine(
ctx.MultiSigPkScript, commitTx, 0, txscript.StandardVerifyFlags,
nil, hashCache, fundingValue, prevFetcher,
)
if err != nil {
return nil, err
}
// Finally, we'll attempt to verify our full spend, if this fails then
// the channel is definitely invalid.
err = vm.Execute()
if err != nil {
return nil, &ErrScriptValidateError{err: err}
}
return chanPoint, nil
}
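// Illustrative usage (a sketch; chanPoint, expectedPkScript and fundingTx
// are assumed to already be in the caller's hands):
//
//	ctx := &Context{
//		Locator: &OutPointChanLocator{
//			ChanPoint: chanPoint,
//		},
//		MultiSigPkScript: expectedPkScript,
//		FundingTx:        fundingTx,
//	}
//	// CommitCtx is nil here, so only the outpoint, txid, and pkScript are
//	// checked, as is appropriate for a 3rd party channel.
//	verifiedPoint, err := Validate(ctx)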
package lnwallet
import (
"bytes"
"sort"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
)
// InPlaceCommitSort performs an in-place sort of a commitment transaction,
// given an unsorted transaction and a list of CLTV values for the HTLCs.
//
// The sort applied is a modified BIP69 sort, that uses the CLTV values of HTLCs
// as a tie breaker in case two HTLC outputs have an identical amount and
// pkscript. The pkscripts can be the same if they share the same payment hash,
// but since the CLTV is enforced via the nLockTime of the second-layer
// transactions, the script does not directly commit to them. Instead, the CLTVs
// must be supplied separately to act as a tie-breaker, otherwise we may produce
// invalid HTLC signatures if the receiver produces an alternative ordering
// during verification.
//
// NOTE: Commitment outputs should have a 0 CLTV corresponding to their index on
// the unsorted commitment transaction.
func InPlaceCommitSort(tx *wire.MsgTx, cltvs []uint32) {
if len(tx.TxOut) != len(cltvs) {
panic("output and cltv list size mismatch")
}
sort.Sort(sortableInputSlice(tx.TxIn))
sort.Sort(sortableCommitOutputSlice{tx.TxOut, cltvs})
}
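// Illustrative usage (a sketch): cltvs must carry one entry per output of
// the unsorted transaction, with 0 for the non-HTLC commitment outputs:
//
//	// Outputs 0 and 1 are to_local/to_remote, outputs 2 and 3 are HTLCs
//	// expiring at absolute heights 500000 and 500144.
//	cltvs := []uint32{0, 0, 500000, 500144}
//	InPlaceCommitSort(commitTx, cltvs)
//	// After sorting, cltvs[i] still corresponds to commitTx.TxOut[i].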
// sortableInputSlice is a slice of transaction inputs that supports sorting via
// BIP69.
type sortableInputSlice []*wire.TxIn
// Len returns the length of the sortableInputSlice.
//
// NOTE: Part of the sort.Interface interface.
func (s sortableInputSlice) Len() int { return len(s) }
// Swap exchanges the position of inputs i and j.
//
// NOTE: Part of the sort.Interface interface.
func (s sortableInputSlice) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
// Less is the BIP69 input comparison function. The sort is first applied on
// input hash (reversed / rpc-style), then index. This logic is copied from
// btcutil/txsort.
//
// NOTE: Part of the sort.Interface interface.
func (s sortableInputSlice) Less(i, j int) bool {
// If the previous outpoint hashes are equal, compare the output indexes.
ihash := s[i].PreviousOutPoint.Hash
jhash := s[j].PreviousOutPoint.Hash
if ihash == jhash {
return s[i].PreviousOutPoint.Index < s[j].PreviousOutPoint.Index
}
// At this point, the hashes are not equal, so reverse them to
// big-endian and return the result of the comparison.
const hashSize = chainhash.HashSize
for b := 0; b < hashSize/2; b++ {
ihash[b], ihash[hashSize-1-b] = ihash[hashSize-1-b], ihash[b]
jhash[b], jhash[hashSize-1-b] = jhash[hashSize-1-b], jhash[b]
}
return bytes.Compare(ihash[:], jhash[:]) == -1
}
// sortableCommitOutputSlice is a slice of transaction outputs on a commitment
// transaction and the corresponding CLTV values of any HTLCs. Commitment
// outputs should have a CLTV of 0 and the same index in cltvs.
type sortableCommitOutputSlice struct {
txouts []*wire.TxOut
cltvs []uint32
}
// Len returns the length of the sortableCommitOutputSlice.
//
// NOTE: Part of the sort.Interface interface.
func (s sortableCommitOutputSlice) Len() int {
return len(s.txouts)
}
// Swap exchanges the position of outputs i and j, as well as their
// corresponding CLTV values.
//
// NOTE: Part of the sort.Interface interface.
func (s sortableCommitOutputSlice) Swap(i, j int) {
s.txouts[i], s.txouts[j] = s.txouts[j], s.txouts[i]
s.cltvs[i], s.cltvs[j] = s.cltvs[j], s.cltvs[i]
}
// Less is a modified BIP69 output comparison, that sorts based on value, then
// pkscript, then CLTV value.
//
// NOTE: Part of the sort.Interface interface.
func (s sortableCommitOutputSlice) Less(i, j int) bool {
outi, outj := s.txouts[i], s.txouts[j]
if outi.Value != outj.Value {
return outi.Value < outj.Value
}
pkScriptCmp := bytes.Compare(outi.PkScript, outj.PkScript)
if pkScriptCmp != 0 {
return pkScriptCmp < 0
}
return s.cltvs[i] < s.cltvs[j]
}
package lnwallet
import (
"bytes"
"fmt"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
)
// AnchorSize is the constant anchor output size.
const AnchorSize = btcutil.Amount(330)
// DefaultAnchorsCommitMaxFeeRateSatPerVByte is the default max fee rate in
// sat/vbyte the initiator will use for anchor channels. This should be enough
// to ensure propagation before anchoring down the commitment transaction.
const DefaultAnchorsCommitMaxFeeRateSatPerVByte = 10
// CommitmentKeyRing holds all derived keys needed to construct commitment and
// HTLC transactions. The keys are derived differently depending on whether
// commitment transaction is ours or the remote peer's. Private keys associated
// with each key may belong to the commitment owner or the "other party" which
// is referred to in the field comments, regardless of which is local and which
// is remote.
type CommitmentKeyRing struct {
// CommitPoint is the "per commitment point" used to derive the tweak
// for each base point.
CommitPoint *btcec.PublicKey
// LocalCommitKeyTweak is the tweak used to derive the local public key
// from the local payment base point or the local private key from the
// base point secret. This may be included in a SignDescriptor to
// generate signatures for the local payment key.
//
// NOTE: This will always refer to "our" local key, regardless of
// whether this is our commit or not.
LocalCommitKeyTweak []byte
// TODO(roasbeef): need delay tweak as well?
// LocalHtlcKeyTweak is the tweak used to derive the local HTLC key
// from the local HTLC base point. This value is needed in order to
// derive the final key used within the HTLC scripts in the commitment
// transaction.
//
// NOTE: This will always refer to "our" local HTLC key, regardless of
// whether this is our commit or not.
LocalHtlcKeyTweak []byte
// LocalHtlcKey is the key that will be used in any clause paying to
// our node of any HTLC scripts within the commitment transaction for
// this key ring set.
//
// NOTE: This will always refer to "our" local HTLC key, regardless of
// whether this is our commit or not.
LocalHtlcKey *btcec.PublicKey
// RemoteHtlcKey is the key that will be used in clauses within the
// HTLC script that send money to the remote party.
//
// NOTE: This will always refer to "their" remote HTLC key, regardless
// of whether this is our commit or not.
RemoteHtlcKey *btcec.PublicKey
// ToLocalKey is the commitment transaction owner's key which is
// included in HTLC success and timeout transaction scripts. This is
// the public key used for the to_local output of the commitment
// transaction.
//
// NOTE: Whose key this is depends on the current perspective. If this
// is our commitment this will be our key.
ToLocalKey *btcec.PublicKey
// ToRemoteKey is the non-owner's payment key in the commitment tx.
// This is the key used to generate the to_remote output within the
// commitment transaction.
//
// NOTE: Whose key this is depends on the current perspective. If this
// is our commitment this will be their key.
ToRemoteKey *btcec.PublicKey
// RevocationKey is the key that can be used by the other party to
// redeem outputs from a revoked commitment transaction if it were to
// be published.
//
// NOTE: Who can sign for this key depends on the current perspective.
// If this is our commitment, it means the remote node can sign for
// this key in case of a breach.
RevocationKey *btcec.PublicKey
}
// DeriveCommitmentKeys generates a new commitment key set using the base points
// and commitment point. The keys are derived differently depending on the type
// of channel, and whether the commitment transaction is ours or the remote
// peer's.
func DeriveCommitmentKeys(commitPoint *btcec.PublicKey,
whoseCommit lntypes.ChannelParty, chanType channeldb.ChannelType,
localChanCfg, remoteChanCfg *channeldb.ChannelConfig) *CommitmentKeyRing {
tweaklessCommit := chanType.IsTweakless()
// Depending on if this is our commit or not, we'll choose the correct
// base point.
localBasePoint := localChanCfg.PaymentBasePoint
if whoseCommit.IsLocal() {
localBasePoint = localChanCfg.DelayBasePoint
}
// First, we'll derive all the keys that don't depend on the context of
// whose commitment transaction this is.
keyRing := &CommitmentKeyRing{
CommitPoint: commitPoint,
LocalCommitKeyTweak: input.SingleTweakBytes(
commitPoint, localBasePoint.PubKey,
),
LocalHtlcKeyTweak: input.SingleTweakBytes(
commitPoint, localChanCfg.HtlcBasePoint.PubKey,
),
LocalHtlcKey: input.TweakPubKey(
localChanCfg.HtlcBasePoint.PubKey, commitPoint,
),
RemoteHtlcKey: input.TweakPubKey(
remoteChanCfg.HtlcBasePoint.PubKey, commitPoint,
),
}
// We'll now compute the to_local, to_remote, and revocation key based
// on the current commitment point. All keys are tweaked at each state in
// order to ensure the keys from each state are unlinkable. To create
// the revocation key, we take the opposite party's revocation base
// point and combine that with the current commitment point.
var (
toLocalBasePoint *btcec.PublicKey
toRemoteBasePoint *btcec.PublicKey
revocationBasePoint *btcec.PublicKey
)
if whoseCommit.IsLocal() {
toLocalBasePoint = localChanCfg.DelayBasePoint.PubKey
toRemoteBasePoint = remoteChanCfg.PaymentBasePoint.PubKey
revocationBasePoint = remoteChanCfg.RevocationBasePoint.PubKey
} else {
toLocalBasePoint = remoteChanCfg.DelayBasePoint.PubKey
toRemoteBasePoint = localChanCfg.PaymentBasePoint.PubKey
revocationBasePoint = localChanCfg.RevocationBasePoint.PubKey
}
// With the base points assigned, we can now derive the actual keys
// using the base point, and the current commitment tweak.
keyRing.ToLocalKey = input.TweakPubKey(toLocalBasePoint, commitPoint)
keyRing.RevocationKey = input.DeriveRevocationPubkey(
revocationBasePoint, commitPoint,
)
// If this commitment should omit the tweak for the remote point, then
// we'll use that directly, and ignore the commitPoint tweak.
if tweaklessCommit {
keyRing.ToRemoteKey = toRemoteBasePoint
// If this is not our commitment, the above ToRemoteKey will be
// ours, and we blank out the local commitment tweak to
// indicate that the key should not be tweaked when signing.
if whoseCommit.IsRemote() {
keyRing.LocalCommitKeyTweak = nil
}
} else {
keyRing.ToRemoteKey = input.TweakPubKey(
toRemoteBasePoint, commitPoint,
)
}
return keyRing
}
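// Illustrative usage (a sketch): deriving the key ring for our own
// commitment at the current commitment point:
//
//	keyRing := DeriveCommitmentKeys(
//		commitPoint, lntypes.Local, chanState.ChanType,
//		&chanState.LocalChanCfg, &chanState.RemoteChanCfg,
//	)
//	// keyRing.ToLocalKey pays to us behind the CSV delay, while
//	// keyRing.RevocationKey lets the remote party sweep a revoked state.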
// WitnessScriptDesc holds the output script and the witness script for p2wsh
// outputs.
type WitnessScriptDesc struct {
// OutputScript is the output's PkScript.
OutputScript []byte
// WitnessScript is the full script required to properly redeem the
// output. This field should be set to the full script if a p2wsh
// output is being signed. For p2wkh it should be set equal to the
// PkScript.
WitnessScript []byte
}
// PkScript is the public key script that commits to the final
// contract.
func (w *WitnessScriptDesc) PkScript() []byte {
return w.OutputScript
}
// WitnessScriptToSign returns the witness script that we'll use when signing
// for the remote party, and also verifying signatures on our transactions. As
// an example, when we create an outgoing HTLC for the remote party, we want to
// sign their success path.
func (w *WitnessScriptDesc) WitnessScriptToSign() []byte {
return w.WitnessScript
}
// WitnessScriptForPath returns the witness script for the given spending path.
// An error is returned if the path is unknown. This is useful as when
// constructing a control block for a given path, one also needs the witness
// script being signed.
func (w *WitnessScriptDesc) WitnessScriptForPath(
_ input.ScriptPath) ([]byte, error) {
return w.WitnessScript, nil
}
// CommitScriptToSelf constructs the public key script for the output on the
// commitment transaction paying to the "owner" of said commitment transaction.
// The `initiator` argument should correspond to the owner of the commitment
// transaction which we are generating the to_local script for. If the other
// party learns of the preimage to the revocation hash, then they can claim all
// the settled funds in the channel, plus the unsettled funds.
func CommitScriptToSelf(chanType channeldb.ChannelType, initiator bool,
selfKey, revokeKey *btcec.PublicKey, csvDelay, leaseExpiry uint32,
auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) {
switch {
// For taproot scripts, we'll need to make a slightly modified script
// where a NUMS key is used to force a script path reveal of either the
// revocation or the CSV timeout.
//
// Our "redeem" script here is just the taproot witness program.
case chanType.IsTaproot():
return input.NewLocalCommitScriptTree(
csvDelay, selfKey, revokeKey, auxLeaf,
)
// If we are the initiator of a leased channel, then we have an
// additional CLTV requirement in addition to the usual CSV
// requirement.
case initiator && chanType.HasLeaseExpiration():
toLocalRedeemScript, err := input.LeaseCommitScriptToSelf(
selfKey, revokeKey, csvDelay, leaseExpiry,
)
if err != nil {
return nil, err
}
toLocalScriptHash, err := input.WitnessScriptHash(
toLocalRedeemScript,
)
if err != nil {
return nil, err
}
return &WitnessScriptDesc{
OutputScript: toLocalScriptHash,
WitnessScript: toLocalRedeemScript,
}, nil
default:
toLocalRedeemScript, err := input.CommitScriptToSelf(
csvDelay, selfKey, revokeKey,
)
if err != nil {
return nil, err
}
toLocalScriptHash, err := input.WitnessScriptHash(
toLocalRedeemScript,
)
if err != nil {
return nil, err
}
return &WitnessScriptDesc{
OutputScript: toLocalScriptHash,
WitnessScript: toLocalRedeemScript,
}, nil
}
}
// CommitScriptToRemote derives the appropriate to_remote script based on the
// channel's commitment type. The `initiator` argument should correspond to the
// owner of the commitment transaction which we are generating the to_remote
// script for. The second return value is the CSV delay of the output script,
// which must be satisfied in order to spend the output.
func CommitScriptToRemote(chanType channeldb.ChannelType, initiator bool,
remoteKey *btcec.PublicKey, leaseExpiry uint32,
auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, uint32, error) {
switch {
// If we are not the initiator of a leased channel, then the remote
// party has an additional CLTV requirement in addition to the 1 block
// CSV requirement.
case chanType.HasLeaseExpiration() && !initiator:
script, err := input.LeaseCommitScriptToRemoteConfirmed(
remoteKey, leaseExpiry,
)
if err != nil {
return nil, 0, err
}
p2wsh, err := input.WitnessScriptHash(script)
if err != nil {
return nil, 0, err
}
return &WitnessScriptDesc{
OutputScript: p2wsh,
WitnessScript: script,
}, 1, nil
// For taproot channels, we'll use a slightly different format, where
// we use a NUMS key to force the remote party to take a script path,
// with the sole tap leaf enforcing the 1 CSV delay.
case chanType.IsTaproot():
toRemoteScriptTree, err := input.NewRemoteCommitScriptTree(
remoteKey, auxLeaf,
)
if err != nil {
return nil, 0, err
}
return toRemoteScriptTree, 1, nil
// If this channel type has anchors, we derive the delayed to_remote
// script.
case chanType.HasAnchors():
script, err := input.CommitScriptToRemoteConfirmed(remoteKey)
if err != nil {
return nil, 0, err
}
p2wsh, err := input.WitnessScriptHash(script)
if err != nil {
return nil, 0, err
}
return &WitnessScriptDesc{
OutputScript: p2wsh,
WitnessScript: script,
}, 1, nil
default:
// Otherwise the to_remote will be a simple p2wkh.
p2wkh, err := input.CommitScriptUnencumbered(remoteKey)
if err != nil {
return nil, 0, err
}
// Since this is a regular P2WKH, the WitnessScript and PkScript
// should both be set to the same p2wkh output script.
return &WitnessScriptDesc{
OutputScript: p2wkh,
WitnessScript: p2wkh,
}, 0, nil
}
}
// HtlcSigHashType returns the sighash type to use for HTLC success and timeout
// transactions given the channel type.
func HtlcSigHashType(chanType channeldb.ChannelType) txscript.SigHashType {
if chanType.HasAnchors() {
return txscript.SigHashSingle | txscript.SigHashAnyOneCanPay
}
return txscript.SigHashAll
}
// HtlcSignDetails converts the passed parameters to a SignDetails valid for
// this channel type. For non-anchor channels this will return nil.
func HtlcSignDetails(chanType channeldb.ChannelType, signDesc input.SignDescriptor,
sigHash txscript.SigHashType, peerSig input.Signature) *input.SignDetails {
// Non-anchor channels don't need sign details, as the HTLC second
// level cannot be altered.
if !chanType.HasAnchors() {
return nil
}
return &input.SignDetails{
SignDesc: signDesc,
SigHashType: sigHash,
PeerSig: peerSig,
}
}
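// The SignDetails returned above is what later allows an anchor channel's
// second-level HTLC transaction to be modified for fee bumping: because the
// peer signed with SigHashSingle|SigHashAnyOneCanPay, their signature stays
// valid even after additional inputs and outputs are attached to the
// transaction.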
// HtlcSecondLevelInputSequence dictates the sequence number we must use on the
// input to a second level HTLC transaction.
func HtlcSecondLevelInputSequence(chanType channeldb.ChannelType) uint32 {
if chanType.HasAnchors() {
return 1
}
return 0
}
// sweepSigHash returns the sign descriptor to use when signing a sweep
// transaction. For taproot channels, we'll use this to always sweep with
// sighash default.
func sweepSigHash(chanType channeldb.ChannelType) txscript.SigHashType {
if chanType.IsTaproot() {
return txscript.SigHashDefault
}
return txscript.SigHashAll
}
// SecondLevelHtlcScript derives the appropriate second level HTLC script based
// on the channel's commitment type. It is the uniform script that's used as the
// output for the second-level HTLC transactions. The second level transactions
// act as a sort of covenant, ensuring that a 2-of-2 multi-sig output can only
// be spent in a particular way, and to a particular output. The `initiator`
// argument should correspond to the owner of the commitment transaction which
// we are generating the to_local script for.
func SecondLevelHtlcScript(chanType channeldb.ChannelType, initiator bool,
revocationKey, delayKey *btcec.PublicKey, csvDelay, leaseExpiry uint32,
auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) {
switch {
// For taproot channels, the pkScript is a segwit v1 p2tr output.
case chanType.IsTaproot():
return input.TaprootSecondLevelScriptTree(
revocationKey, delayKey, csvDelay, auxLeaf,
)
// If we are the initiator of a leased channel, then we have an
// additional CLTV requirement in addition to the usual CSV
// requirement.
case initiator && chanType.HasLeaseExpiration():
witnessScript, err := input.LeaseSecondLevelHtlcScript(
revocationKey, delayKey, csvDelay, leaseExpiry,
)
if err != nil {
return nil, err
}
pkScript, err := input.WitnessScriptHash(witnessScript)
if err != nil {
return nil, err
}
return &WitnessScriptDesc{
OutputScript: pkScript,
WitnessScript: witnessScript,
}, nil
default:
witnessScript, err := input.SecondLevelHtlcScript(
revocationKey, delayKey, csvDelay,
)
if err != nil {
return nil, err
}
pkScript, err := input.WitnessScriptHash(witnessScript)
if err != nil {
return nil, err
}
return &WitnessScriptDesc{
OutputScript: pkScript,
WitnessScript: witnessScript,
}, nil
}
}
// CommitWeight returns the base commitment weight before adding HTLCs.
func CommitWeight(chanType channeldb.ChannelType) lntypes.WeightUnit {
switch {
case chanType.IsTaproot():
return input.TaprootCommitWeight
// If this commitment has anchors, it will be slightly heavier.
case chanType.HasAnchors():
return input.AnchorCommitWeight
default:
return input.CommitWeight
}
}
// HtlcTimeoutFee returns the fee in satoshis required for an HTLC timeout
// transaction based on the current fee rate.
func HtlcTimeoutFee(chanType channeldb.ChannelType,
feePerKw chainfee.SatPerKWeight) btcutil.Amount {
switch {
// For zero-fee HTLC channels, this will always be zero, regardless of
// feerate.
case chanType.ZeroHtlcTxFee() || chanType.IsTaproot():
return 0
case chanType.HasAnchors():
return feePerKw.FeeForWeight(input.HtlcTimeoutWeightConfirmed)
default:
return feePerKw.FeeForWeight(input.HtlcTimeoutWeight)
}
}
// HtlcSuccessFee returns the fee in satoshis required for an HTLC success
// transaction based on the current fee rate.
func HtlcSuccessFee(chanType channeldb.ChannelType,
feePerKw chainfee.SatPerKWeight) btcutil.Amount {
switch {
// For zero-fee HTLC channels, this will always be zero, regardless of
// feerate.
case chanType.ZeroHtlcTxFee() || chanType.IsTaproot():
return 0
case chanType.HasAnchors():
return feePerKw.FeeForWeight(input.HtlcSuccessWeightConfirmed)
default:
return feePerKw.FeeForWeight(input.HtlcSuccessWeight)
}
}
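// Worked example (a sketch; legacyChanType stands for any channel type where
// both ZeroHtlcTxFee and HasAnchors are false): both helpers simply scale
// the second-level transaction weight by the fee rate:
//
//	feePerKw := chainfee.SatPerKWeight(2500)
//	timeoutFee := HtlcTimeoutFee(legacyChanType, feePerKw)
//	// == feePerKw.FeeForWeight(input.HtlcTimeoutWeight)
//	successFee := HtlcSuccessFee(legacyChanType, feePerKw)
//	// == feePerKw.FeeForWeight(input.HtlcSuccessWeight)
//
// For zero-fee HTLC and taproot channels both helpers return 0, as the
// second-level transactions are expected to be fee bumped by other means.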
// CommitScriptAnchors returns the scripts to use for the local and remote
// anchor.
func CommitScriptAnchors(chanType channeldb.ChannelType,
localChanCfg, remoteChanCfg *channeldb.ChannelConfig,
keyRing *CommitmentKeyRing) (
input.ScriptDescriptor, input.ScriptDescriptor, error) {
var (
anchorScript func(
key *btcec.PublicKey) (input.ScriptDescriptor, error)
keySelector func(*channeldb.ChannelConfig,
bool) *btcec.PublicKey
)
switch {
// For taproot channels, the anchor is slightly different: the top
// level key is now the (relative) local delay and remote public key,
// since these are fully revealed once the commitment hits the chain.
case chanType.IsTaproot():
anchorScript = func(
key *btcec.PublicKey) (input.ScriptDescriptor, error) {
return input.NewAnchorScriptTree(key)
}
keySelector = func(cfg *channeldb.ChannelConfig,
local bool) *btcec.PublicKey {
if local {
return keyRing.ToLocalKey
}
return keyRing.ToRemoteKey
}
// For normal channels we'll use the multi-sig keys since those are
// revealed when the channel closes.
default:
// For normal channels, we'll create a p2wsh script based on
// the target key.
anchorScript = func(
key *btcec.PublicKey) (input.ScriptDescriptor, error) {
script, err := input.CommitScriptAnchor(key)
if err != nil {
return nil, err
}
scriptHash, err := input.WitnessScriptHash(script)
if err != nil {
return nil, err
}
return &WitnessScriptDesc{
OutputScript: scriptHash,
WitnessScript: script,
}, nil
}
// For the existing channels, we'll always select the multi-sig
// key from the party's channel config.
keySelector = func(cfg *channeldb.ChannelConfig,
_ bool) *btcec.PublicKey {
return cfg.MultiSigKey.PubKey
}
}
// Get the script used for the anchor output spendable by the local
// node.
localAnchor, err := anchorScript(keySelector(localChanCfg, true))
if err != nil {
return nil, nil, err
}
// And the anchor spendable by the remote node.
remoteAnchor, err := anchorScript(keySelector(remoteChanCfg, false))
if err != nil {
return nil, nil, err
}
return localAnchor, remoteAnchor, nil
}
// CommitmentBuilder is a type that wraps the type of channel we are dealing
// with, and abstracts the various ways of constructing commitment
// transactions.
type CommitmentBuilder struct {
// chanState is the underlying channel's state struct, used to
// determine the type of channel we are dealing with, and relevant
// parameters.
chanState *channeldb.OpenChannel
// obfuscator is a 48-bit state hint that's used to obfuscate the
// current state number on the commitment transactions.
obfuscator [StateHintSize]byte
// auxLeafStore is an interface that allows us to fetch auxiliary
// tapscript leaves for the commitment output.
auxLeafStore fn.Option[AuxLeafStore]
}
// NewCommitmentBuilder creates a new CommitmentBuilder from chanState.
func NewCommitmentBuilder(chanState *channeldb.OpenChannel,
leafStore fn.Option[AuxLeafStore]) *CommitmentBuilder {
// The anchor channel type MUST be tweakless.
if chanState.ChanType.HasAnchors() && !chanState.ChanType.IsTweakless() {
panic("invalid channel type combination")
}
return &CommitmentBuilder{
chanState: chanState,
obfuscator: createStateHintObfuscator(chanState),
auxLeafStore: leafStore,
}
}
// createStateHintObfuscator derives the state hint obfuscator for the
// channel, which is used to encode the commitment height in the sequence
// number of commitment transaction inputs.
func createStateHintObfuscator(state *channeldb.OpenChannel) [StateHintSize]byte {
if state.IsInitiator {
return DeriveStateHintObfuscator(
state.LocalChanCfg.PaymentBasePoint.PubKey,
state.RemoteChanCfg.PaymentBasePoint.PubKey,
)
}
return DeriveStateHintObfuscator(
state.RemoteChanCfg.PaymentBasePoint.PubKey,
state.LocalChanCfg.PaymentBasePoint.PubKey,
)
}
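// Illustrative usage (a sketch): the obfuscator pairs with SetStateNumHint
// and GetStateNumHint to hide, and later recover, the commitment height
// from a broadcast commitment transaction:
//
//	obfuscator := createStateHintObfuscator(chanState)
//	if err := SetStateNumHint(commitTx, height, obfuscator); err != nil {
//		return err
//	}
//	recovered := GetStateNumHint(commitTx, obfuscator)
//	// recovered == height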
// unsignedCommitmentTx is the final commitment created from evaluating an HTLC
// view at a given height, along with some metadata.
type unsignedCommitmentTx struct {
// txn is the final, unsigned commitment transaction for this view.
txn *wire.MsgTx
// fee is the total fee of the commitment transaction.
fee btcutil.Amount
// ourBalance is our balance on this commitment *after* subtracting
// commitment fees and anchor outputs. This can be different than the
// balances before creating the commitment transaction as one party must
// pay the commitment fee.
ourBalance lnwire.MilliSatoshi
// theirBalance is their balance of this commitment *after* subtracting
// commitment fees and anchor outputs. This can be different than the
// balances before creating the commitment transaction as one party must
// pay the commitment fee.
theirBalance lnwire.MilliSatoshi
// cltvs is a sorted list of CLTV timeouts for each HTLC on the commitment
// transaction. Any non-htlc outputs will have a CLTV value of zero.
cltvs []uint32
}
// createUnsignedCommitmentTx generates the unsigned commitment transaction for
// a commitment view and returns it as part of the unsignedCommitmentTx. The
// passed in balances should be balances *before* subtracting any commitment
// fees, but after anchor outputs.
func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
theirBalance lnwire.MilliSatoshi, whoseCommit lntypes.ChannelParty,
feePerKw chainfee.SatPerKWeight, height uint64, originalHtlcView,
filteredHTLCView *HtlcView, keyRing *CommitmentKeyRing,
prevCommit *commitment) (*unsignedCommitmentTx, error) {
dustLimit := cb.chanState.LocalChanCfg.DustLimit
if whoseCommit.IsRemote() {
dustLimit = cb.chanState.RemoteChanCfg.DustLimit
}
numHTLCs := int64(0)
for _, htlc := range filteredHTLCView.Updates.Local {
if HtlcIsDust(
cb.chanState.ChanType, false, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,
) {
continue
}
numHTLCs++
}
for _, htlc := range filteredHTLCView.Updates.Remote {
if HtlcIsDust(
cb.chanState.ChanType, true, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,
) {
continue
}
numHTLCs++
}
// Next, we'll calculate the fee for the commitment transaction based
// on its total weight. Once we have the total weight, we'll multiply
// by the current fee-per-kw, then divide by 1000 to get the proper
// fee.
totalCommitWeight := CommitWeight(cb.chanState.ChanType) +
lntypes.WeightUnit(input.HTLCWeight*numHTLCs)
// With the weight known, we can now calculate the commitment fee,
// ensuring that we account for any dust outputs trimmed above.
commitFee := feePerKw.FeeForWeight(totalCommitWeight)
commitFeeMSat := lnwire.NewMSatFromSatoshis(commitFee)
// Currently, within the protocol, the initiator always pays the fees.
// So we'll subtract the fee amount from the balance of the current
// initiator. If the initiator is unable to pay the fee fully, then
// their entire output is consumed.
switch {
case cb.chanState.IsInitiator && commitFee > ourBalance.ToSatoshis():
ourBalance = 0
case cb.chanState.IsInitiator:
ourBalance -= commitFeeMSat
case !cb.chanState.IsInitiator && commitFee > theirBalance.ToSatoshis():
theirBalance = 0
case !cb.chanState.IsInitiator:
theirBalance -= commitFeeMSat
}
var commitTx *wire.MsgTx
// Before we create the commitment transaction below, we'll try to see
// if there are any aux leaves that need to be a part of the tapscript
// tree. We'll only do this if we have a custom blob defined though.
auxResult, err := fn.MapOptionZ(
cb.auxLeafStore,
func(s AuxLeafStore) fn.Result[CommitDiffAuxResult] {
return auxLeavesFromView(
s, cb.chanState, prevCommit.customBlob,
originalHtlcView, whoseCommit, ourBalance,
theirBalance, *keyRing,
)
},
).Unpack()
if err != nil {
return nil, fmt.Errorf("unable to fetch aux leaves: %w", err)
}
// Depending on whether the transaction is ours or not, we call
// CreateCommitTx with parameters matching the perspective, to generate
// a new commitment transaction with all the latest unsettled/un-timed
// out HTLCs.
var leaseExpiry uint32
if cb.chanState.ChanType.HasLeaseExpiration() {
leaseExpiry = cb.chanState.ThawHeight
}
if whoseCommit.IsLocal() {
commitTx, err = CreateCommitTx(
cb.chanState.ChanType, fundingTxIn(cb.chanState), keyRing,
&cb.chanState.LocalChanCfg, &cb.chanState.RemoteChanCfg,
ourBalance.ToSatoshis(), theirBalance.ToSatoshis(),
numHTLCs, cb.chanState.IsInitiator, leaseExpiry,
auxResult.AuxLeaves,
)
} else {
commitTx, err = CreateCommitTx(
cb.chanState.ChanType, fundingTxIn(cb.chanState), keyRing,
&cb.chanState.RemoteChanCfg, &cb.chanState.LocalChanCfg,
theirBalance.ToSatoshis(), ourBalance.ToSatoshis(),
numHTLCs, !cb.chanState.IsInitiator, leaseExpiry,
auxResult.AuxLeaves,
)
}
if err != nil {
return nil, err
}
// Similarly, we'll now attempt to extract the set of aux leaves for
// the set of incoming and outgoing HTLCs.
incomingAuxLeaves := fn.MapOption(
func(leaves CommitAuxLeaves) input.HtlcAuxLeaves {
return leaves.IncomingHtlcLeaves
},
)(auxResult.AuxLeaves)
outgoingAuxLeaves := fn.MapOption(
func(leaves CommitAuxLeaves) input.HtlcAuxLeaves {
return leaves.OutgoingHtlcLeaves
},
)(auxResult.AuxLeaves)
// We'll now add all the HTLC outputs to the commitment transaction.
// Each output includes an off-chain 2-of-2 covenant clause, so we'll
// need the objective local/remote keys for this particular commitment
// as well. For any non-dust HTLCs that are manifested on the commitment
// transaction, we'll also record its CLTV which is required to sort the
// commitment transaction below. The slice is initially sized to the
// number of existing outputs, since any outputs already added are
// commitment outputs and should correspond to zero values for the
// purposes of sorting.
cltvs := make([]uint32, len(commitTx.TxOut))
htlcIndexes := make([]input.HtlcIndex, len(commitTx.TxOut))
for _, htlc := range filteredHTLCView.Updates.Local {
if HtlcIsDust(
cb.chanState.ChanType, false, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,
) {
continue
}
auxLeaf := fn.FlatMapOption(
func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf {
return leaves[htlc.HtlcIndex].AuxTapLeaf
},
)(outgoingAuxLeaves)
err := addHTLC(
commitTx, whoseCommit, false, htlc, keyRing,
cb.chanState.ChanType, auxLeaf,
)
if err != nil {
return nil, err
}
// We want to add the CLTV and HTLC index to their respective
// slices, even if we already pre-allocated them.
cltvs = append(cltvs, htlc.Timeout) //nolint
htlcIndexes = append(htlcIndexes, htlc.HtlcIndex) //nolint
}
for _, htlc := range filteredHTLCView.Updates.Remote {
if HtlcIsDust(
cb.chanState.ChanType, true, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,
) {
continue
}
auxLeaf := fn.FlatMapOption(
func(leaves input.HtlcAuxLeaves) input.AuxTapLeaf {
return leaves[htlc.HtlcIndex].AuxTapLeaf
},
)(incomingAuxLeaves)
err := addHTLC(
commitTx, whoseCommit, true, htlc, keyRing,
cb.chanState.ChanType, auxLeaf,
)
if err != nil {
return nil, err
}
// We want to add the CLTV and HTLC index to their respective
// slices, even if we already pre-allocated them.
cltvs = append(cltvs, htlc.Timeout) //nolint
htlcIndexes = append(htlcIndexes, htlc.HtlcIndex) //nolint
}
// Set the state hint of the commitment transaction to facilitate
// quickly recovering the necessary penalty state in the case of an
// uncooperative broadcast.
err = SetStateNumHint(commitTx, height, cb.obfuscator)
if err != nil {
return nil, err
}
// Sort the transactions according to the agreed upon canonical
// ordering (which might be customized for custom channel types, but
// deterministic and both parties will arrive at the same result). This
// lets us skip sending the entire transaction over, instead we'll just
// send signatures.
commitSort := auxResult.CommitSortFunc.UnwrapOr(DefaultCommitSort)
err = commitSort(commitTx, cltvs, htlcIndexes)
if err != nil {
return nil, fmt.Errorf("unable to sort commitment "+
"transaction: %w", err)
}
// Next, we'll ensure that we don't accidentally create a commitment
// transaction which would be invalid by consensus.
uTx := btcutil.NewTx(commitTx)
if err := blockchain.CheckTransactionSanity(uTx); err != nil {
return nil, err
}
// Finally, we'll assert that we're not attempting to draw more out of
// the channel than was originally placed within it.
var totalOut btcutil.Amount
for _, txOut := range commitTx.TxOut {
totalOut += btcutil.Amount(txOut.Value)
}
if totalOut+commitFee > cb.chanState.Capacity {
return nil, fmt.Errorf("height=%v, for ChannelPoint(%v) "+
"attempts to consume %v while channel capacity is %v",
height, cb.chanState.FundingOutpoint,
totalOut+commitFee, cb.chanState.Capacity)
}
return &unsignedCommitmentTx{
txn: commitTx,
fee: commitFee,
ourBalance: ourBalance,
theirBalance: theirBalance,
cltvs: cltvs,
}, nil
}
// CreateCommitTx creates a commitment transaction, spending from specified
// funding output. The commitment transaction contains two outputs: one local
// output paying to the "owner" of the commitment transaction which can be
// spent after a relative block delay or revocation event, and a remote output
// paying the counterparty within the channel, which can be spent immediately
// or after a delay depending on the commitment type. The `initiator` argument
// should correspond to the owner of the commitment transaction we are creating.
func CreateCommitTx(chanType channeldb.ChannelType,
fundingOutput wire.TxIn, keyRing *CommitmentKeyRing,
localChanCfg, remoteChanCfg *channeldb.ChannelConfig,
amountToLocal, amountToRemote btcutil.Amount,
numHTLCs int64, initiator bool, leaseExpiry uint32,
auxLeaves fn.Option[CommitAuxLeaves]) (*wire.MsgTx, error) {
// First, we create the script for the delayed "pay-to-self" output.
// This output has 2 main redemption clauses: either we can redeem the
// output after a relative block delay, or the remote node can claim
// the funds with the revocation key if we broadcast a revoked
// commitment transaction.
localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
})(auxLeaves)
toLocalScript, err := CommitScriptToSelf(
chanType, initiator, keyRing.ToLocalKey, keyRing.RevocationKey,
uint32(localChanCfg.CsvDelay), leaseExpiry,
fn.FlattenOption(localAuxLeaf),
)
if err != nil {
return nil, err
}
// Next, we create the script paying to the remote.
remoteAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
})(auxLeaves)
toRemoteScript, _, err := CommitScriptToRemote(
chanType, initiator, keyRing.ToRemoteKey, leaseExpiry,
fn.FlattenOption(remoteAuxLeaf),
)
if err != nil {
return nil, err
}
// Now that both output scripts have been created, we can finally create
// the transaction itself. We use a transaction version of 2 since CSV
// will fail unless the tx version is >= 2.
commitTx := wire.NewMsgTx(2)
commitTx.AddTxIn(&fundingOutput)
// Avoid creating dust outputs within the commitment transaction.
localOutput := amountToLocal >= localChanCfg.DustLimit
if localOutput {
commitTx.AddTxOut(&wire.TxOut{
PkScript: toLocalScript.PkScript(),
Value: int64(amountToLocal),
})
}
remoteOutput := amountToRemote >= localChanCfg.DustLimit
if remoteOutput {
commitTx.AddTxOut(&wire.TxOut{
PkScript: toRemoteScript.PkScript(),
Value: int64(amountToRemote),
})
}
// If this channel type has anchors, we'll also add those.
if chanType.HasAnchors() {
localAnchor, remoteAnchor, err := CommitScriptAnchors(
chanType, localChanCfg, remoteChanCfg, keyRing,
)
if err != nil {
return nil, err
}
// Add local anchor output only if we have a commitment output
// or there are HTLCs.
if localOutput || numHTLCs > 0 {
commitTx.AddTxOut(&wire.TxOut{
PkScript: localAnchor.PkScript(),
Value: int64(AnchorSize),
})
}
// Add anchor output to remote only if they have a commitment
// output or there are HTLCs.
if remoteOutput || numHTLCs > 0 {
commitTx.AddTxOut(&wire.TxOut{
PkScript: remoteAnchor.PkScript(),
Value: int64(AnchorSize),
})
}
}
return commitTx, nil
}
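// Illustrative usage (a sketch): building our own commitment mirrors the
// local branch of createUnsignedCommitmentTx above, with our config and
// balance in the "local" position:
//
//	commitTx, err := CreateCommitTx(
//		chanState.ChanType, fundingTxIn(chanState), keyRing,
//		&chanState.LocalChanCfg, &chanState.RemoteChanCfg,
//		ourBalance.ToSatoshis(), theirBalance.ToSatoshis(),
//		numHTLCs, chanState.IsInitiator, leaseExpiry, auxLeaves,
//	)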
// CoopCloseBalance returns the final balances that should be used to create
// the cooperative close tx, given the channel type and transaction fee.
func CoopCloseBalance(chanType channeldb.ChannelType, isInitiator bool,
coopCloseFee, ourBalance, theirBalance, commitFee btcutil.Amount,
feePayer fn.Option[lntypes.ChannelParty],
) (btcutil.Amount, btcutil.Amount, error) {
// We'll make sure we account for the complete balance by adding the
// current dangling commitment fee to the balance of the initiator.
initiatorDelta := commitFee
// Since the initiator's balance also is stored after subtracting the
// anchor values, add that back in case this was an anchor commitment.
if chanType.HasAnchors() {
initiatorDelta += 2 * AnchorSize
}
// To start with, we'll add the anchor and/or commitment fee to the
// balance of the initiator.
if isInitiator {
ourBalance += initiatorDelta
} else {
theirBalance += initiatorDelta
}
// With the initiator's balance credited, we'll now subtract the closing
// fee from the party paying it. By default, the initiator pays the full
// amount, but this can be overridden by the feePayer option.
defaultPayer := func() lntypes.ChannelParty {
if isInitiator {
return lntypes.Local
}
return lntypes.Remote
}()
payer := feePayer.UnwrapOr(defaultPayer)
// Based on the payer computed above, we'll subtract the closing fee.
switch payer {
case lntypes.Local:
ourBalance -= coopCloseFee
case lntypes.Remote:
theirBalance -= coopCloseFee
}
// During fee negotiation it should always be verified that the
// initiator can pay the proposed fee, but we do a sanity check just to
// be sure here.
if ourBalance < 0 || theirBalance < 0 {
return 0, 0, fmt.Errorf("initiator cannot afford proposed " +
"coop close fee")
}
return ourBalance, theirBalance, nil
}
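// Worked example (a sketch with made-up amounts): take an anchor channel
// where we are the initiator, our stored balance is 96_540 sats, theirs is
// 100_000 sats, the dangling commitment fee is 2_800 sats, and the
// negotiated coop close fee is 1_500 sats. The initiator delta is
// 2_800 + 2*330 = 3_460 sats, so our credited balance becomes 100_000 sats;
// with no feePayer override we pay the close fee, ending at 98_500 sats,
// while the remote party keeps 100_000 sats.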
// genSegwitV0HtlcScript generates the HTLC scripts for a normal segwit v0
// channel.
func genSegwitV0HtlcScript(chanType channeldb.ChannelType,
isIncoming bool, whoseCommit lntypes.ChannelParty, timeout uint32,
rHash [32]byte, keyRing *CommitmentKeyRing,
) (*WitnessScriptDesc, error) {
var (
witnessScript []byte
err error
)
// Choose scripts based on channel type.
confirmedHtlcSpends := false
if chanType.HasAnchors() {
confirmedHtlcSpends = true
}
// Generate the proper redeem scripts for the HTLC output modified by
// two-bits denoting if this is an incoming HTLC, and if the HTLC is
// being applied to their commitment transaction or ours.
switch {
// The HTLC is paying to us, and being applied to our commitment
// transaction. So we need to use the receiver's version of the HTLC
// script.
case isIncoming && whoseCommit.IsLocal():
witnessScript, err = input.ReceiverHTLCScript(
timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], confirmedHtlcSpends,
)
// We're being paid via an HTLC by the remote party, and the HTLC is
// being added to their commitment transaction, so we use the sender's
// version of the HTLC script.
case isIncoming && whoseCommit.IsRemote():
witnessScript, err = input.SenderHTLCScript(
keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], confirmedHtlcSpends,
)
// We're sending an HTLC which is being added to our commitment
// transaction. Therefore, we need to use the sender's version of the
// HTLC script.
case !isIncoming && whoseCommit.IsLocal():
witnessScript, err = input.SenderHTLCScript(
keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], confirmedHtlcSpends,
)
// Finally, we're paying the remote party via an HTLC, which is being
// added to their commitment transaction. Therefore, we use the
// receiver's version of the HTLC script.
case !isIncoming && whoseCommit.IsRemote():
witnessScript, err = input.ReceiverHTLCScript(
timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], confirmedHtlcSpends,
)
}
if err != nil {
return nil, err
}
// Now that we have the redeem scripts, create the P2WSH public key
// script for the output itself.
htlcP2WSH, err := input.WitnessScriptHash(witnessScript)
if err != nil {
return nil, err
}
return &WitnessScriptDesc{
OutputScript: htlcP2WSH,
WitnessScript: witnessScript,
}, nil
}
// GenTaprootHtlcScript generates the HTLC scripts for a taproot+musig2
// channel.
func GenTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty,
timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing,
auxLeaf input.AuxTapLeaf) (*input.HtlcScriptTree, error) {
var (
htlcScriptTree *input.HtlcScriptTree
err error
)
// Generate the proper redeem scripts for the HTLC output modified by
// two-bits denoting if this is an incoming HTLC, and if the HTLC is
// being applied to their commitment transaction or ours.
switch {
// The HTLC is paying to us, and being applied to our commitment
// transaction. So we need to use the receiver's version of the HTLC
// script.
case isIncoming && whoseCommit.IsLocal():
htlcScriptTree, err = input.ReceiverHTLCScriptTaproot(
timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf,
)
// We're being paid via an HTLC by the remote party, and the HTLC is
// being added to their commitment transaction, so we use the sender's
// version of the HTLC script.
case isIncoming && whoseCommit.IsRemote():
htlcScriptTree, err = input.SenderHTLCScriptTaproot(
keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey,
keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf,
)
// We're sending an HTLC which is being added to our commitment
// transaction. Therefore, we need to use the sender's version of the
// HTLC script.
case !isIncoming && whoseCommit.IsLocal():
htlcScriptTree, err = input.SenderHTLCScriptTaproot(
keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf,
)
// Finally, we're paying the remote party via an HTLC, which is being
// added to their commitment transaction. Therefore, we use the
// receiver's version of the HTLC script.
case !isIncoming && whoseCommit.IsRemote():
htlcScriptTree, err = input.ReceiverHTLCScriptTaproot(
timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey,
keyRing.RevocationKey, rHash[:], whoseCommit, auxLeaf,
)
}
return htlcScriptTree, err
}
// genHtlcScript generates the proper P2WSH public key scripts for the HTLC
// output modified by two-bits denoting if this is an incoming HTLC, and if the
// HTLC is being applied to their commitment transaction or ours. A script
// multiplexer for the various spending paths is returned. The script path that
// we need to sign for the remote party (2nd level HTLCs) is also returned
// alongside the multiplexer.
func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool,
whoseCommit lntypes.ChannelParty, timeout uint32, rHash [32]byte,
keyRing *CommitmentKeyRing,
auxLeaf input.AuxTapLeaf) (input.ScriptDescriptor, error) {
if !chanType.IsTaproot() {
return genSegwitV0HtlcScript(
chanType, isIncoming, whoseCommit, timeout, rHash,
keyRing,
)
}
return GenTaprootHtlcScript(
isIncoming, whoseCommit, timeout, rHash, keyRing, auxLeaf,
)
}
// addHTLC adds a new HTLC to the passed commitment transaction. One of four
// full scripts will be generated for the HTLC output depending on if the HTLC
// is incoming and if it's being applied to our commitment transaction or that
// of the remote node's. Additionally, in order to be able to efficiently
// locate the added HTLC on the commitment transaction from the
// paymentDescriptor that generated it, the generated script is stored within
// the descriptor itself.
func addHTLC(commitTx *wire.MsgTx, whoseCommit lntypes.ChannelParty,
isIncoming bool, paymentDesc *paymentDescriptor,
keyRing *CommitmentKeyRing, chanType channeldb.ChannelType,
auxLeaf input.AuxTapLeaf) error {
timeout := paymentDesc.Timeout
rHash := paymentDesc.RHash
scriptInfo, err := genHtlcScript(
chanType, isIncoming, whoseCommit, timeout, rHash, keyRing,
auxLeaf,
)
if err != nil {
return err
}
pkScript := scriptInfo.PkScript()
// Add the new HTLC outputs to the respective commitment transactions.
amountPending := int64(paymentDesc.Amount.ToSatoshis())
commitTx.AddTxOut(wire.NewTxOut(amountPending, pkScript))
// Store the pkScript of this particular paymentDescriptor so we can
// quickly locate it within the commitment transaction later.
if whoseCommit.IsLocal() {
paymentDesc.ourPkScript = pkScript
paymentDesc.ourWitnessScript = scriptInfo.WitnessScriptToSign()
} else {
paymentDesc.theirPkScript = pkScript
//nolint:ll
paymentDesc.theirWitnessScript = scriptInfo.WitnessScriptToSign()
}
return nil
}
// findOutputIndexesFromRemote finds the index of our and their outputs from
// the remote commitment transaction. It derives the key ring to compute the
// output scripts and compares them against the outputs inside the commitment
// to find the match.
func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash,
chanState *channeldb.OpenChannel,
leafStore fn.Option[AuxLeafStore]) (uint32, uint32, error) {
// Init the output indexes as empty.
ourIndex := uint32(channeldb.OutputIndexEmpty)
theirIndex := uint32(channeldb.OutputIndexEmpty)
chanCommit := chanState.RemoteCommitment
_, commitmentPoint := btcec.PrivKeyFromBytes(revocationPreimage[:])
// With the commitment point generated, we can now derive the key ring
// which will be used to generate the output scripts.
keyRing := DeriveCommitmentKeys(
commitmentPoint, lntypes.Remote, chanState.ChanType,
&chanState.LocalChanCfg, &chanState.RemoteChanCfg,
)
// Since it's the remote commitment chain, we use the mirrored values.
//
// We use the remote's channel config for the csv delay.
theirDelay := uint32(chanState.RemoteChanCfg.CsvDelay)
// If we are the initiator of this channel, then it will be false from
// the remote's PoV.
isRemoteInitiator := !chanState.IsInitiator
var leaseExpiry uint32
if chanState.ChanType.HasLeaseExpiration() {
leaseExpiry = chanState.ThawHeight
}
// If we have a custom blob, then we'll attempt to fetch the aux leaves
// for this state.
auxResult, err := fn.MapOptionZ(
leafStore, func(a AuxLeafStore) fn.Result[CommitDiffAuxResult] {
return a.FetchLeavesFromCommit(
NewAuxChanState(chanState), chanCommit,
*keyRing, lntypes.Remote,
)
},
).Unpack()
if err != nil {
return ourIndex, theirIndex, fmt.Errorf("unable to fetch aux "+
"leaves: %w", err)
}
// Map the scripts from our PoV. When facing a local commitment, the
// to_local output belongs to us and the to_remote output belongs to
// them. When facing a remote commitment, the to_local output belongs to
// them and the to_remote output belongs to us.
// Compute the to_local script. From our PoV, when facing a remote
// commitment, the to_local output belongs to them.
localAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.LocalAuxLeaf
},
)(auxResult.AuxLeaves)
theirScript, err := CommitScriptToSelf(
chanState.ChanType, isRemoteInitiator, keyRing.ToLocalKey,
keyRing.RevocationKey, theirDelay, leaseExpiry, localAuxLeaf,
)
if err != nil {
return ourIndex, theirIndex, err
}
// Compute the to_remote script. From our PoV, when facing a remote
// commitment, the to_remote output belongs to us.
remoteAuxLeaf := fn.FlatMapOption(
func(l CommitAuxLeaves) input.AuxTapLeaf {
return l.RemoteAuxLeaf
},
)(auxResult.AuxLeaves)
ourScript, _, err := CommitScriptToRemote(
chanState.ChanType, isRemoteInitiator, keyRing.ToRemoteKey,
leaseExpiry, remoteAuxLeaf,
)
if err != nil {
return ourIndex, theirIndex, err
}
// Now compare the scripts to find our/their output index.
for i, txOut := range chanCommit.CommitTx.TxOut {
switch {
case bytes.Equal(txOut.PkScript, ourScript.PkScript()):
ourIndex = uint32(i)
case bytes.Equal(txOut.PkScript, theirScript.PkScript()):
theirIndex = uint32(i)
}
}
return ourIndex, theirIndex, nil
}
package lnwallet
import (
"github.com/lightningnetwork/lnd/fn/v2"
)
// commitmentChain represents a chain of unrevoked commitments. The tail of the
// chain is the latest fully signed, yet unrevoked commitment. Two chains are
// tracked, one for the local node, and another for the remote node. New
// commitments we create locally extend the remote node's chain, and vice
// versa. Commitment chains are allowed to grow to a bounded length, after
// which the tail needs to be "dropped" before new commitments can be received.
// The tail is "dropped" when the owner of the chain sends a revocation for the
// previous tail.
type commitmentChain struct {
// commitments is a linked list of commitments to new states. New
// commitments are added to the end of the chain with increasing height.
// Once a commitment transaction is revoked, the tail is incremented,
// freeing up the revocation window for new commitments.
commitments *fn.List[*commitment]
}
// newCommitmentChain creates a new commitment chain.
func newCommitmentChain() *commitmentChain {
return &commitmentChain{
commitments: fn.NewList[*commitment](),
}
}
// addCommitment extends the commitment chain by a single commitment. This
// added commitment represents a state update proposed by either party. Once
// the commitment prior to this commitment is revoked, the commitment becomes
// the new de facto state within the channel.
func (s *commitmentChain) addCommitment(c *commitment) {
s.commitments.PushBack(c)
}
// advanceTail reduces the length of the commitment chain by one. The tail of
// the chain should be advanced once a revocation for the lowest unrevoked
// commitment in the chain is received.
func (s *commitmentChain) advanceTail() {
s.commitments.Remove(s.commitments.Front())
}
// tip returns the latest commitment added to the chain.
func (s *commitmentChain) tip() *commitment {
return s.commitments.Back().Value
}
// tail returns the lowest unrevoked commitment transaction in the chain.
func (s *commitmentChain) tail() *commitment {
return s.commitments.Front().Value
}
// hasUnackedCommitment returns true if the commitment chain has more than one
// entry. The tail of the commitment chain has been ACKed by revoking all prior
// commitments, but any subsequent commitments have not yet been ACKed.
func (s *commitmentChain) hasUnackedCommitment() bool {
return s.commitments.Front() != s.commitments.Back()
}
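
// The function below is a minimal, package-internal sketch of the chain's
// lifecycle, assuming c1 and c2 are fully signed *commitment values; it only
// illustrates how the methods above compose.
func exampleCommitmentChainFlow(c1, c2 *commitment) {
	chain := newCommitmentChain()

	// New states extend the chain; the first commitment is both tip and
	// tail.
	chain.addCommitment(c1)
	chain.addCommitment(c2)

	// c2 has been signed but not yet ACKed via a revocation of c1.
	_ = chain.hasUnackedCommitment() // true

	// Receiving the revocation for c1 advances the tail, making c2 the
	// new de facto state.
	chain.advanceTail()
	_ = chain.tail() // c2
}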
package lnwallet
import (
"errors"
"fmt"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/lnwire"
)
// ReservationError wraps certain errors returned during channel reservation
// that can be sent across the wire to the remote peer. Errors not being
// ReservationErrors will not be sent to the remote in case of a failed channel
// reservation, as they may contain private information.
type ReservationError struct {
error
}
// A compile time check to ensure ReservationError implements the error
// interface.
var _ error = (*ReservationError)(nil)
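
// The helper below is a minimal sketch of how a caller might decide whether
// a funding failure is safe to forward to the remote peer: only
// ReservationErrors are forwarded verbatim, everything else is replaced by
// an opaque error. The sendToPeer callback is a hypothetical stand-in for
// the actual wire machinery.
func sanitizeReservationFailure(err error, sendToPeer func(error)) {
	var resErr ReservationError
	if errors.As(err, &resErr) {
		// ReservationErrors are designed to not leak private
		// information, so they can cross the wire as-is.
		sendToPeer(resErr)
		return
	}

	// Everything else may contain private details; send a generic
	// failure instead.
	sendToPeer(errors.New("internal error"))
}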
// ErrZeroCapacity returns an error indicating the funder attempted to put zero
// funds into the channel.
func ErrZeroCapacity() ReservationError {
return ReservationError{
errors.New("zero channel funds"),
}
}
// ErrChainMismatch returns an error indicating that the initiator tried to
// open a channel for an unknown chain.
func ErrChainMismatch(knownChain,
unknownChain *chainhash.Hash) ReservationError {
return ReservationError{
fmt.Errorf("unknown chain=%v, supported chain=%v",
unknownChain, knownChain),
}
}
// ErrFunderBalanceDust returns an error indicating the initial balance of the
// funder is considered dust at the current commitment fee.
func ErrFunderBalanceDust(commitFee, funderBalance,
minBalance int64) ReservationError {
return ReservationError{
fmt.Errorf("funder balance too small (%v) with fee=%v sat, "+
"minimum=%v sat required", funderBalance,
commitFee, minBalance),
}
}
// ErrCsvDelayTooLarge returns an error indicating that the CSV delay was too
// large to be accepted, along with the current max.
func ErrCsvDelayTooLarge(remoteDelay, maxDelay uint16) ReservationError {
return ReservationError{
fmt.Errorf("CSV delay too large: %v, max is %v",
remoteDelay, maxDelay),
}
}
// ErrChanReserveTooSmall returns an error indicating that the channel reserve
// the remote is requiring is too small to be accepted.
func ErrChanReserveTooSmall(reserve, dustLimit btcutil.Amount) ReservationError {
return ReservationError{
fmt.Errorf("channel reserve of %v sat is too small, min is %v "+
"sat", int64(reserve), int64(dustLimit)),
}
}
// ErrChanReserveTooLarge returns an error indicating that the chan reserve the
// remote is requiring, is too large to be accepted.
func ErrChanReserveTooLarge(reserve,
maxReserve btcutil.Amount) ReservationError {
return ReservationError{
fmt.Errorf("channel reserve is too large: %v sat, max "+
"is %v sat", int64(reserve), int64(maxReserve)),
}
}
// ErrNonZeroPushAmount is returned by a remote peer that receives a
// FundingOpen request for a channel with non-zero push amount while
// they have 'rejectpush' enabled.
func ErrNonZeroPushAmount() ReservationError {
return ReservationError{errors.New("non-zero push amounts are disabled")}
}
// ErrMinHtlcTooLarge returns an error indicating that the MinHTLC value the
// remote required is too large to be accepted.
func ErrMinHtlcTooLarge(minHtlc,
maxMinHtlc lnwire.MilliSatoshi) ReservationError {
return ReservationError{
fmt.Errorf("minimum HTLC value is too large: %v, max is %v",
minHtlc, maxMinHtlc),
}
}
// ErrMaxHtlcNumTooLarge returns an error indicating that the 'max HTLCs in
// flight' value the remote required is too large to be accepted.
func ErrMaxHtlcNumTooLarge(maxHtlc, maxMaxHtlc uint16) ReservationError {
return ReservationError{
fmt.Errorf("maxHtlcs is too large: %d, max is %d",
maxHtlc, maxMaxHtlc),
}
}
// ErrMaxHtlcNumTooSmall returns an error indicating that the 'max HTLCs in
// flight' value the remote required is too small to be accepted.
func ErrMaxHtlcNumTooSmall(maxHtlc, minMaxHtlc uint16) ReservationError {
return ReservationError{
fmt.Errorf("maxHtlcs is too small: %d, min is %d",
maxHtlc, minMaxHtlc),
}
}
// ErrMaxValueInFlightTooSmall returns an error indicating that the 'max HTLC
// value in flight' the remote required is too small to be accepted.
func ErrMaxValueInFlightTooSmall(maxValInFlight,
minMaxValInFlight lnwire.MilliSatoshi) ReservationError {
return ReservationError{
fmt.Errorf("maxValueInFlight too small: %v, min is %v",
maxValInFlight, minMaxValInFlight),
}
}
// ErrNumConfsTooLarge returns an error indicating that the number of
// confirmations required for a channel is too large.
func ErrNumConfsTooLarge(numConfs, maxNumConfs uint32) ReservationError {
return ReservationError{
fmt.Errorf("minimum depth of %d is too large, max is %d",
numConfs, maxNumConfs),
}
}
// ErrChanTooSmall returns an error indicating that an incoming channel request
// was too small. We'll reject any incoming channels if they're below our
// configured value for the min channel size we'll accept.
func ErrChanTooSmall(chanSize, minChanSize btcutil.Amount) ReservationError {
return ReservationError{
fmt.Errorf("chan size of %v is below min chan size of %v",
chanSize, minChanSize),
}
}
// ErrChanTooLarge returns an error indicating that an incoming channel request
// was too large. We'll reject any incoming channels if they're above our
// configured value for the max channel size we'll accept.
func ErrChanTooLarge(chanSize, maxChanSize btcutil.Amount) ReservationError {
return ReservationError{
fmt.Errorf("chan size of %v exceeds maximum chan size of %v",
chanSize, maxChanSize),
}
}
// ErrInvalidDustLimit returns an error indicating that a proposed DustLimit
// was rejected.
func ErrInvalidDustLimit(dustLimit btcutil.Amount) ReservationError {
return ReservationError{
fmt.Errorf("dust limit %v is invalid", dustLimit),
}
}
// ErrHtlcIndexAlreadyFailed is returned when the HTLC index has already been
// failed, but has not been committed by our commitment state.
type ErrHtlcIndexAlreadyFailed uint64
// Error returns a message indicating the index that had already been failed.
func (e ErrHtlcIndexAlreadyFailed) Error() string {
return fmt.Sprintf("HTLC with ID %d has already been failed", e)
}
// ErrHtlcIndexAlreadySettled is returned when the HTLC index has already been
// settled, but has not been committed by our commitment state.
type ErrHtlcIndexAlreadySettled uint64
// Error returns a message indicating the index that had already been settled.
func (e ErrHtlcIndexAlreadySettled) Error() string {
return fmt.Sprintf("HTLC with ID %d has already been settled", e)
}
// ErrInvalidSettlePreimage is returned when trying to settle an HTLC, but the
// preimage does not correspond to the payment hash.
type ErrInvalidSettlePreimage struct {
preimage []byte
rhash []byte
}
// Error returns an error message with the offending preimage and intended
// payment hash.
func (e ErrInvalidSettlePreimage) Error() string {
return fmt.Sprintf("Invalid payment preimage %x for hash %x",
e.preimage, e.rhash)
}
// ErrUnknownHtlcIndex is returned when locally settling or failing an HTLC, but
// the HTLC index is not known to the channel. This typically indicates that the
// HTLC was already settled in a prior commitment.
type ErrUnknownHtlcIndex struct {
chanID lnwire.ShortChannelID
index uint64
}
// Error returns an error logging the channel and HTLC index that was unknown.
func (e ErrUnknownHtlcIndex) Error() string {
return fmt.Sprintf("No HTLC with ID %d in channel %v",
e.index, e.chanID)
}
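
// As a sketch of how these typed errors might be consumed, the helper below
// treats duplicate settles/fails and unknown indexes as benign; err is
// assumed to come from a settle or fail attempt on a channel in this package.
func isBenignHtlcError(err error) bool {
	switch {
	case errors.As(err, new(ErrHtlcIndexAlreadySettled)):
		// The HTLC was already settled; retrying is harmless.
		return true

	case errors.As(err, new(ErrHtlcIndexAlreadyFailed)):
		// The HTLC was already failed; retrying is harmless.
		return true

	case errors.As(err, new(ErrUnknownHtlcIndex)):
		// Typically the HTLC was settled in a prior commitment.
		return true

	default:
		return false
	}
}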
package lnwallet
import (
"errors"
"fmt"
"sync"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/hdkeychain"
"github.com/btcsuite/btcd/btcutil/psbt"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
base "github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/wallet/txauthor"
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwallet/chanvalidate"
"github.com/lightningnetwork/lnd/lnwire"
)
const (
// DefaultAccountName is the name for the default account used to manage
// on-chain funds within the wallet.
DefaultAccountName = "default"
)
// AddressType is an enum-like type which denotes the possible address types
// WalletController supports.
type AddressType uint8
// AccountAddressMap maps the account properties to an array of
// address properties.
type AccountAddressMap map[*waddrmgr.AccountProperties][]AddressProperty
const (
// UnknownAddressType represents an output with an unknown or non-standard
// script.
UnknownAddressType AddressType = iota
// WitnessPubKey represents a p2wkh address.
WitnessPubKey
// NestedWitnessPubKey represents a p2sh output which is itself a
// nested p2wkh output.
NestedWitnessPubKey
// TaprootPubkey represents a p2tr key path spending address.
TaprootPubkey
)
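
// For orientation only, the mapping below sketches the btcwallet key scopes
// commonly associated with each known address type; the association is an
// illustrative assumption rather than part of any interface contract.
func keyScopeForAddrType(t AddressType) (waddrmgr.KeyScope, bool) {
	switch t {
	case WitnessPubKey:
		return waddrmgr.KeyScopeBIP0084, true

	case NestedWitnessPubKey:
		return waddrmgr.KeyScopeBIP0049Plus, true

	case TaprootPubkey:
		return waddrmgr.KeyScopeBIP0086, true

	default:
		// Unknown or non-standard scripts have no associated scope.
		return waddrmgr.KeyScope{}, false
	}
}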
var (
// DefaultPublicPassphrase is the default public passphrase used for the
// wallet.
DefaultPublicPassphrase = []byte("public")
// DefaultPrivatePassphrase is the default private passphrase used for
// the wallet.
DefaultPrivatePassphrase = []byte("hello")
// ErrDoubleSpend is returned from PublishTransaction in case the
// tx being published is spending an output spent by a conflicting
// transaction.
ErrDoubleSpend = errors.New("transaction rejected: output already spent")
// ErrNotMine is an error denoting that a WalletController instance is
// unable to spend a specified output.
ErrNotMine = errors.New("the passed output doesn't belong to the wallet")
// ErrMempoolFee is returned from PublishTransaction in case the tx
// being published is not accepted into mempool because the fee
// requirements of the mempool backend are not met.
ErrMempoolFee = errors.New("transaction rejected by the mempool " +
"because of low fees")
)
// ErrNoOutputs is returned if we try to create a transaction with no outputs
// or send coins to a set of outputs that is empty.
var ErrNoOutputs = errors.New("no outputs")
// ErrInvalidMinconf is returned if we try to create a transaction with
// invalid minConfs value.
var ErrInvalidMinconf = errors.New("minimum number of confirmations must " +
"be a non-negative number")
// AddressProperty contains wallet related information of an address.
type AddressProperty struct {
// Address is the address of an account.
Address string
// Internal denotes if the address is a change address.
Internal bool
// Balance is the total balance of the address.
Balance btcutil.Amount
// DerivationPath is the derivation path of the address.
DerivationPath string
// PublicKey is the public key of the address.
PublicKey *btcec.PublicKey
}
// AccountIdentifier contains information to uniquely identify an account.
type AccountIdentifier struct {
// Name is the name of the account.
Name string
// AddressType is the type of addresses supported by the account.
AddressType AddressType
// DerivationPath is the derivation path corresponding to the account
// public key.
DerivationPath string
}
// Utxo is an unspent output denoted by its outpoint, and the value of the
// original output.
type Utxo struct {
AddressType AddressType
Value btcutil.Amount
Confirmations int64
PkScript []byte
wire.OutPoint
PrevTx *wire.MsgTx
}
// OutputDetail contains additional information on a destination address.
type OutputDetail struct {
OutputType txscript.ScriptClass
Addresses []btcutil.Address
PkScript []byte
OutputIndex int
Value btcutil.Amount
IsOurAddress bool
}
// PreviousOutPoint contains information about the previous outpoint.
type PreviousOutPoint struct {
// OutPoint is the transaction out point in the format txid:n.
OutPoint string
// IsOurOutput denotes if the previous output is controlled by the
// internal wallet. The flag will only detect p2wkh, np2wkh and p2tr
// inputs as the wallet's own.
IsOurOutput bool
}
// TransactionDetail describes a transaction that either spends inputs which
// belong to the wallet, or has outputs that pay to the wallet.
type TransactionDetail struct {
// Hash is the transaction hash of the transaction.
Hash chainhash.Hash
// Value is the net value of this transaction (in satoshis) from the
// PoV of the wallet. If this transaction purely spends from the
// wallet's funds, then this value will be negative. Similarly, if this
// transaction credits the wallet, then this value will be positive.
Value btcutil.Amount
// NumConfirmations is the number of confirmations this transaction
// has. If the transaction is unconfirmed, then this value will be
// zero.
NumConfirmations int32
// BlockHash is the hash of the block which includes this
// transaction. Unconfirmed transactions will have a nil value for this
// field.
BlockHash *chainhash.Hash
// BlockHeight is the height of the block including this transaction.
// Unconfirmed transactions will show a height of zero.
BlockHeight int32
// Timestamp is the unix timestamp of the block including this
// transaction. If the transaction is unconfirmed, then this will be a
// timestamp of txn creation.
Timestamp int64
// TotalFees is the total fee in satoshis paid by this transaction.
TotalFees int64
// OutputDetails contains output data for each destination address, such
// as the output script and amount.
OutputDetails []OutputDetail
// RawTx returns the raw serialized transaction.
RawTx []byte
// Label is an optional transaction label.
Label string
// PreviousOutpoints are the inputs for a transaction.
PreviousOutpoints []PreviousOutPoint
}
// TransactionSubscription is an interface which describes an object capable of
// receiving notifications of new transactions related to the underlying
// wallet.
// TODO(roasbeef): add balance updates?
type TransactionSubscription interface {
// ConfirmedTransactions returns a channel which will be sent on as new
// relevant transactions are confirmed.
ConfirmedTransactions() chan *TransactionDetail
// UnconfirmedTransactions returns a channel which will be sent on as
// new relevant transactions are seen within the network.
UnconfirmedTransactions() chan *TransactionDetail
// Cancel finalizes the subscription, cleaning up any resources
// allocated.
Cancel()
}
// WalletController defines an abstract interface for controlling a local Pure
// Go wallet, a local or remote wallet via an RPC mechanism, or possibly even
// a daemon assisted hardware wallet. This interface serves the purpose of
// allowing LightningWallet to be seamlessly compatible with several wallets
// such as: uspv, btcwallet, Bitcoin Core, Electrum, etc. This interface then
// serves as a "base wallet", with Lightning Network awareness taking place at
// a "higher" level of abstraction. Essentially, an overlay wallet.
// Implementors of this interface must closely adhere to the documented
// behavior of all interface methods in order to ensure identical behavior
// across all concrete implementations.
type WalletController interface {
// FetchOutpointInfo queries for the WalletController's knowledge of
// the passed outpoint. If the base wallet determines this output is
// under its control, then the original txout should be returned.
// Otherwise, a non-nil error value of ErrNotMine should be returned
// instead.
FetchOutpointInfo(prevOut *wire.OutPoint) (*Utxo, error)
// FetchDerivationInfo queries for the wallet's knowledge of the passed
// pkScript and constructs the derivation info and returns it.
FetchDerivationInfo(pkScript []byte) (*psbt.Bip32Derivation, error)
// ScriptForOutput returns the address, witness program and redeem
// script for a given UTXO. An error is returned if the UTXO does not
// belong to our wallet or it is not a managed pubKey address.
ScriptForOutput(output *wire.TxOut) (waddrmgr.ManagedPubKeyAddress,
[]byte, []byte, error)
// ConfirmedBalance returns the sum of all the wallet's unspent outputs
// that have at least confs confirmations. If confs is set to zero,
// then all unspent outputs, including those currently in the mempool
// will be included in the final sum. The account parameter serves as a
// filter to retrieve the balance for a specific account. When empty,
// the confirmed balance of all wallet accounts is returned.
//
// NOTE: Only witness outputs should be included in the computation of
// the total spendable balance of the wallet. We require this as only
// witness inputs can be used for funding channels.
ConfirmedBalance(confs int32, accountFilter string) (btcutil.Amount,
error)
// NewAddress returns the next external or internal address for the
// wallet dictated by the value of the `change` parameter. If change is
// true, then an internal address should be used, otherwise an external
// address should be returned. The type of address returned is dictated
// by the wallet's capabilities, and may be of type: p2sh, p2wkh,
// p2wsh, etc. The account parameter must be non-empty as it determines
// which account the address should be generated from.
NewAddress(addrType AddressType, change bool,
account string) (btcutil.Address, error)
// LastUnusedAddress returns the last *unused* address known by the
// wallet. An address is unused if it hasn't received any payments.
// This can be useful in UIs in order to continually show the
// "freshest" address without having to worry about "address inflation"
// caused by continual refreshing. Similar to NewAddress it can derive
// a specified address type. By default, this is a non-change address.
// The account parameter must be non-empty as it determines which
// account the address should be generated from.
LastUnusedAddress(addrType AddressType,
account string) (btcutil.Address, error)
// IsOurAddress checks if the passed address belongs to this wallet.
IsOurAddress(a btcutil.Address) bool
// AddressInfo returns the information about an address, if it's known
// to this wallet.
AddressInfo(a btcutil.Address) (waddrmgr.ManagedAddress, error)
// ListAccounts retrieves all accounts belonging to the wallet by
// default. A name and key scope filter can be provided to filter
// through all of the wallet accounts and return only those matching.
ListAccounts(string, *waddrmgr.KeyScope) ([]*waddrmgr.AccountProperties,
error)
// RequiredReserve returns the minimum amount of satoshis that should be
// kept in the wallet in order to fee bump anchor channels if necessary.
// The value scales with the number of public anchor channels but is
// capped at a maximum.
RequiredReserve(uint32) btcutil.Amount
// ListAddresses retrieves all the addresses along with their balance. An
// account name filter can be provided to filter through all of the
// wallet accounts and return the addresses of only those matching.
ListAddresses(string, bool) (AccountAddressMap, error)
// ImportAccount imports an account backed by an account extended public
// key. The master key fingerprint denotes the fingerprint of the root
// key corresponding to the account public key (also known as the key
// with derivation path m/). This may be required by some hardware
// wallets for proper identification and signing.
//
// The address type can usually be inferred from the key's version, but
// may be required for certain keys to map them into the proper scope.
//
// For BIP-0044 keys, an address type must be specified as we intend to
// not support importing BIP-0044 keys into the wallet using the legacy
// pay-to-pubkey-hash (P2PKH) scheme. A nested witness address type will
// force the standard BIP-0049 derivation scheme, while a witness
// address type will force the standard BIP-0084 derivation scheme.
//
// For BIP-0049 keys, an address type must also be specified to make a
// distinction between the standard BIP-0049 address schema (nested
// witness pubkeys everywhere) and our own BIP-0049Plus address schema
// (nested pubkeys externally, witness pubkeys internally).
ImportAccount(name string, accountPubKey *hdkeychain.ExtendedKey,
masterKeyFingerprint uint32, addrType *waddrmgr.AddressType,
dryRun bool) (*waddrmgr.AccountProperties, []btcutil.Address,
[]btcutil.Address, error)
// ImportPublicKey imports a single derived public key into the wallet.
// The address type can usually be inferred from the key's version, but
// in the case of legacy versions (xpub, tpub), an address type must be
// specified as we intend to not support importing BIP-44 keys into the
// wallet using the legacy pay-to-pubkey-hash (P2PKH) scheme.
ImportPublicKey(pubKey *btcec.PublicKey,
addrType waddrmgr.AddressType) error
// ImportTaprootScript imports a user-provided taproot script into the
// wallet. The imported script will act as a pay-to-taproot address.
//
// NOTE: Taproot keys imported through this RPC currently _cannot_ be
// used for funding PSBTs. Only tracking the balance and UTXOs is
// currently supported.
ImportTaprootScript(scope waddrmgr.KeyScope,
tapscript *waddrmgr.Tapscript) (waddrmgr.ManagedAddress, error)
// SendOutputs funds, signs, and broadcasts a Bitcoin transaction paying
// out to the specified outputs. In the case the wallet has insufficient
// funds, or the outputs are non-standard, an error should be returned.
// This method also takes the target fee expressed in sat/kw that should
// be used when crafting the transaction.
//
// NOTE: This method requires the global coin selection lock to be held.
SendOutputs(inputs fn.Set[wire.OutPoint], outputs []*wire.TxOut,
feeRate chainfee.SatPerKWeight, minConfs int32, label string,
strategy base.CoinSelectionStrategy) (*wire.MsgTx, error)
// CreateSimpleTx creates a Bitcoin transaction paying to the specified
// outputs. The transaction is not broadcasted to the network. In the
// case the wallet has insufficient funds, or the outputs are
// non-standard, an error should be returned. This method also takes
// the target fee expressed in sat/kw that should be used when crafting
// the transaction.
//
// NOTE: The dryRun argument can be set true to create a tx that
// doesn't alter the database. A tx created with this set to true
// SHOULD NOT be broadcasted.
//
// NOTE: This method requires the global coin selection lock to be held.
CreateSimpleTx(inputs fn.Set[wire.OutPoint], outputs []*wire.TxOut,
feeRate chainfee.SatPerKWeight, minConfs int32,
strategy base.CoinSelectionStrategy, dryRun bool) (
*txauthor.AuthoredTx, error)
// GetTransactionDetails returns a detailed description of a transaction
// given its transaction hash.
GetTransactionDetails(txHash *chainhash.Hash) (
*TransactionDetail, error)
// ListUnspentWitness returns all unspent outputs which are version 0
// witness programs. The 'minConfs' and 'maxConfs' parameters
// indicate the minimum and maximum number of confirmations an output
// needs in order to be returned by this method. Passing -1 as
// 'minConfs' indicates that even unconfirmed outputs should be
// returned. Using MaxInt32 as 'maxConfs' implies returning all
// outputs with at least 'minConfs'. The account parameter serves as
// a filter to retrieve the unspent outputs for a specific account.
// When empty, the unspent outputs of all wallet accounts are returned.
//
// NOTE: This method requires the global coin selection lock to be held.
ListUnspentWitness(minConfs, maxConfs int32,
accountFilter string) ([]*Utxo, error)
// ListTransactionDetails returns a list of all transactions which are
// relevant to the wallet over [startHeight;endHeight]. If start height
// is greater than end height, the transactions will be retrieved in
// reverse order. To include unconfirmed transactions, endHeight should
// be set to the special value -1. This will return transactions from
// the tip of the chain until the start height (inclusive) and
// unconfirmed transactions. The account parameter serves as a filter to
// retrieve the transactions relevant to a specific account. When
// empty, transactions of all wallet accounts are returned.
ListTransactionDetails(startHeight, endHeight int32,
accountFilter string, indexOffset uint32,
maxTransactions uint32) ([]*TransactionDetail, uint64, uint64,
error)
// LeaseOutput locks an output to the given ID, preventing it from being
// available for any future coin selection attempts. The absolute time
// of the lock's expiration is returned. The expiration of the lock can
// be extended by successive invocations of this call. Outputs can be
// unlocked before their expiration through `ReleaseOutput`.
//
// If the output is not known, wtxmgr.ErrUnknownOutput is returned. If
// the output has already been locked to a different ID, then
// wtxmgr.ErrOutputAlreadyLocked is returned.
//
// NOTE: This method requires the global coin selection lock to be held.
LeaseOutput(id wtxmgr.LockID, op wire.OutPoint,
duration time.Duration) (time.Time, error)
// ReleaseOutput unlocks an output, allowing it to be available for coin
// selection if it remains unspent. The ID should match the one used to
// originally lock the output.
//
// NOTE: This method requires the global coin selection lock to be held.
ReleaseOutput(id wtxmgr.LockID, op wire.OutPoint) error
// ListLeasedOutputs returns a list of all currently locked outputs.
ListLeasedOutputs() ([]*base.ListLeasedOutputResult, error)
// PublishTransaction performs cursory validation (dust checks, etc),
// then finally broadcasts the passed transaction to the Bitcoin network.
// If the transaction is rejected because it is conflicting with an
// already known transaction, ErrDoubleSpend is returned. If the
// transaction is already known (published already), no error will be
// returned. Any other error returned depends on the currently active chain
// backend. It takes an optional label which will save a label with the
// published transaction.
PublishTransaction(tx *wire.MsgTx, label string) error
// LabelTransaction adds a label to a transaction. If the tx already
// has a label, this call will fail unless the overwrite parameter
// is set. Labels must not be empty, and they are limited to 500 chars.
LabelTransaction(hash chainhash.Hash, label string, overwrite bool) error
// FetchTx attempts to fetch a transaction in the wallet's database
// identified by the passed transaction hash. If the transaction can't
// be found, then a nil pointer is returned.
FetchTx(chainhash.Hash) (*wire.MsgTx, error)
// RemoveDescendants attempts to remove any transaction from the
// wallet's tx store (that may be unconfirmed) that spends outputs
// created by the passed transaction. This removal propagates
// recursively down the chain of descendant transactions.
RemoveDescendants(*wire.MsgTx) error
// FundPsbt creates a fully populated PSBT packet that contains enough
// inputs to fund the outputs specified in the passed in packet with the
// specified fee rate. If there is change left, a change output from the
// internal wallet is added and the index of the change output is
// returned. Otherwise no additional output is created and the index -1
// is returned. If no custom change scope is specified, the BIP0084 key
// scope will be used for default accounts and single imported public
// keys. For custom accounts, no key scope should be provided as the
// coin selection key scope will always be used to generate the change
// address.
//
// NOTE: If the packet doesn't contain any inputs, coin selection is
// performed automatically. The account parameter must be non-empty as
// it determines which set of coins are eligible for coin selection. If
// the packet does contain any inputs, it is assumed that full coin
// selection happened externally and no additional inputs are added. If
// the specified inputs aren't enough to fund the outputs with the given
// fee rate, an error is returned. No lock lease is acquired for any of
// the selected/validated inputs. It is the caller's responsibility
// to lock the inputs before handing them out.
FundPsbt(packet *psbt.Packet, minConfs int32,
feeRate chainfee.SatPerKWeight, account string,
changeScope *waddrmgr.KeyScope,
strategy base.CoinSelectionStrategy,
allowUtxo func(wtxmgr.Credit) bool) (int32, error)
// SignPsbt expects a partial transaction with all inputs and outputs
// fully declared and tries to sign all unsigned inputs that have all
// required fields (UTXO information, BIP32 derivation information,
// witness or sig scripts) set.
// If no error is returned, the PSBT is ready to be given to the next
// signer or to be finalized if lnd was the last signer.
//
// NOTE: This method only signs inputs (and only those it can sign), it
// does not perform any other tasks (such as coin selection, UTXO
// locking or input/output/fee value validation, PSBT finalization). Any
// input that is incomplete will be skipped.
SignPsbt(packet *psbt.Packet) ([]uint32, error)
// FinalizePsbt expects a partial transaction with all inputs and
// outputs fully declared and tries to sign all inputs that belong to
// the specified account. Lnd must be the last signer of the
// transaction. That means, if there are any unsigned non-witness inputs
// or inputs without UTXO information attached or inputs without witness
// data that do not belong to lnd's wallet, this method will fail. If no
// error is returned, the PSBT is ready to be extracted and the final TX
// within to be broadcast.
//
// NOTE: This method does NOT publish the transaction after it's been
// finalized successfully.
FinalizePsbt(packet *psbt.Packet, account string) error
// DecorateInputs fetches the UTXO information of all inputs it can
// identify and adds the required information to the package's inputs.
// The failOnUnknown boolean controls whether the method should return
// an error if it cannot identify an input or if it should just skip it.
DecorateInputs(packet *psbt.Packet, failOnUnknown bool) error
// SubscribeTransactions returns a TransactionSubscription client which
// is capable of receiving async notifications as new transactions
// related to the wallet are seen within the network, or found in
// blocks.
//
// NOTE: a non-nil error should be returned if notifications aren't
// supported.
//
// TODO(roasbeef): make distinct interface?
SubscribeTransactions() (TransactionSubscription, error)
// IsSynced returns a boolean indicating if, from the PoV of the wallet,
// it has fully synced to the current best block in the main chain.
// It also returns an int64 indicating the timestamp of the best block
// known to the wallet, expressed in Unix epoch time.
IsSynced() (bool, int64, error)
// GetRecoveryInfo returns a boolean indicating whether the wallet is
// started in recovery mode. It also returns a float64 indicating the
// recovery progress made so far.
GetRecoveryInfo() (bool, float64, error)
// Start initializes the wallet, making any necessary connections,
// starting up required goroutines etc.
Start() error
// Stop signals the wallet for shutdown. Shutdown may entail closing
// any active sockets, database handles, stopping goroutines, etc.
Stop() error
// BackEnd returns a name for the wallet's backing chain service,
// which could be e.g. btcd, bitcoind, neutrino, or another consensus
// service.
BackEnd() string
// CheckMempoolAcceptance checks whether a transaction follows mempool
// policies and returns an error if it cannot be accepted into the
// mempool.
CheckMempoolAcceptance(tx *wire.MsgTx) error
}
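
// To make the intended call sequence concrete, here is a minimal sketch of a
// fund/sign/finalize/broadcast flow against any WalletController
// implementation. The fee rate, coin selection strategy and label are
// illustrative assumptions, and error handling is reduced to early returns.
func fundSignAndPublish(w WalletController, packet *psbt.Packet) error {
	// Let the wallet select coins from the default account and attach a
	// change output if needed.
	_, err := w.FundPsbt(
		packet, 1, chainfee.FeePerKwFloor, DefaultAccountName, nil,
		base.CoinSelectionLargest, nil,
	)
	if err != nil {
		return err
	}

	// Sign every input the wallet has the keys for.
	if _, err := w.SignPsbt(packet); err != nil {
		return err
	}

	// Finalizing assumes lnd is the last signer.
	if err := w.FinalizePsbt(packet, DefaultAccountName); err != nil {
		return err
	}

	// Extract the final transaction and hand it to the chain backend.
	tx, err := psbt.Extract(packet)
	if err != nil {
		return err
	}

	return w.PublishTransaction(tx, "example PSBT flow")
}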
// BlockChainIO is a dedicated source which will be used to obtain queries
// related to the current state of the blockchain. The data returned by each of
// the defined methods within this interface should always return the most up
// to date data possible.
//
// TODO(roasbeef): move to diff package perhaps?
// TODO(roasbeef): move publish txn here?
type BlockChainIO interface {
// GetBestBlock returns the current height and block hash of the valid
// most-work chain the implementation is aware of.
GetBestBlock() (*chainhash.Hash, int32, error)
// GetUtxo attempts to return the passed outpoint if it's still a
// member of the utxo set. The passed height hint should be the "birth
// height" of the passed outpoint. The script passed should be the
// script that the outpoint creates. In the case that the output is in
// the UTXO set, then the output corresponding to that output is
// returned. Otherwise, a non-nil error will be returned.
// Because this call can initiate a rescan on some backends, the passed
// cancel channel can be closed to abort the call.
GetUtxo(op *wire.OutPoint, pkScript []byte, heightHint uint32,
cancel <-chan struct{}) (*wire.TxOut, error)
// GetBlockHash returns the hash of the block in the best blockchain
// at the given height.
GetBlockHash(blockHeight int64) (*chainhash.Hash, error)
// GetBlock returns the block in the main chain identified by the given
// hash.
GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error)
// GetBlockHeader returns the block header for the given block hash.
GetBlockHeader(blockHash *chainhash.Hash) (*wire.BlockHeader, error)
}
// MessageSigner represents an abstract object capable of signing arbitrary
// messages. The capabilities of this interface are used to sign announcements
// to the network, or just arbitrary messages that leverage the wallet's keys
// to attest to some message.
type MessageSigner interface {
// SignMessage attempts to sign a target message with the private key
// described in the key locator. If the target private key is unable to
// be found, then an error will be returned. The actual digest signed is
// the single or double SHA-256 of the passed message.
SignMessage(keyLoc keychain.KeyLocator, msg []byte,
doubleHash bool) (*ecdsa.Signature, error)
}
// AddrWithKey wraps a normal addr, but also includes the internal key for the
// delivery addr if known.
type AddrWithKey struct {
lnwire.DeliveryAddress
InternalKey fn.Option[keychain.KeyDescriptor]
// TODO(roasbeef): consolidate w/ instance in chan closer
}
// InternalKeyForAddr returns the internal key associated with a taproot
// address.
func InternalKeyForAddr(wallet WalletController, netParams *chaincfg.Params,
deliveryScript []byte) (fn.Option[keychain.KeyDescriptor], error) {
none := fn.None[keychain.KeyDescriptor]()
pkScript, err := txscript.ParsePkScript(deliveryScript)
if err != nil {
return none, err
}
addr, err := pkScript.Address(netParams)
if err != nil {
return none, err
}
// If it's not a taproot address, we don't need to know the internal
// key in the first place. So we don't return an error here, but also no
// internal key.
_, isTaproot := addr.(*btcutil.AddressTaproot)
if !isTaproot {
return none, nil
}
walletAddr, err := wallet.AddressInfo(addr)
if err != nil {
// If the error is that the address can't be found, it is not
// an error. This happens when any channel which is not a custom
// taproot channel is cooperatively closed to an external P2TR
// address. In this case there is no internal key associated
// with the address. Callers can use the .Option() method to get
// an option value.
var managerErr waddrmgr.ManagerError
if errors.As(err, &managerErr) &&
managerErr.ErrorCode == waddrmgr.ErrAddressNotFound {
return none, nil
}
return none, err
}
// No wallet addr. We return the none option along with a nil error
// here, as callers can simply handle the empty option value.
if walletAddr == nil {
return none, nil
}
pubKeyAddr, ok := walletAddr.(waddrmgr.ManagedPubKeyAddress)
if !ok {
return none, fmt.Errorf("expected pubkey addr, got %T",
walletAddr)
}
_, derivationPath, _ := pubKeyAddr.DerivationInfo()
return fn.Some[keychain.KeyDescriptor](keychain.KeyDescriptor{
KeyLocator: keychain.KeyLocator{
Family: keychain.KeyFamily(derivationPath.Account),
Index: derivationPath.Index,
},
PubKey: pubKeyAddr.PubKey(),
}), nil
}
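
// A short usage sketch: resolve the internal key for a delivery script and
// act only when one exists. The wallet and deliveryScript parameters are
// assumed to be supplied by the caller.
func logInternalKey(wallet WalletController, deliveryScript []byte) error {
	keyOpt, err := InternalKeyForAddr(
		wallet, &chaincfg.MainNetParams, deliveryScript,
	)
	if err != nil {
		return err
	}

	// An empty option simply means this isn't one of our taproot
	// addresses, which is not an error.
	keyOpt.WhenSome(func(desc keychain.KeyDescriptor) {
		walletLog.Debugf("internal key family=%v, index=%v",
			desc.KeyLocator.Family, desc.KeyLocator.Index)
	})

	return nil
}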
// WalletDriver represents a "driver" for a particular concrete
// WalletController implementation. A driver is identified by a globally unique
// string identifier along with a 'New()' method which is responsible for
// initializing a particular WalletController concrete implementation.
type WalletDriver struct {
// WalletType is a string which uniquely identifies the
// WalletController that this driver drives.
WalletType string
// New creates a new instance of a concrete WalletController
// implementation given a variadic set of arguments. The function takes
// a variadic number of interface parameters in order to provide
// initialization flexibility, thereby accommodating several potential
// WalletController implementations.
New func(args ...interface{}) (WalletController, error)
// BackEnds returns a list of available chain service drivers for the
// wallet driver. This could be e.g. bitcoind, btcd, neutrino, etc.
BackEnds func() []string
}
var (
wallets = make(map[string]*WalletDriver)
registerMtx sync.Mutex
)
// RegisteredWallets returns a slice of all currently registered wallet
// drivers.
//
// NOTE: This function is safe for concurrent access.
func RegisteredWallets() []*WalletDriver {
registerMtx.Lock()
defer registerMtx.Unlock()
registeredWallets := make([]*WalletDriver, 0, len(wallets))
for _, wallet := range wallets {
registeredWallets = append(registeredWallets, wallet)
}
return registeredWallets
}
// RegisterWallet registers a WalletDriver which is capable of driving a
// concrete WalletController interface. In the case that this driver has
// already been registered, an error is returned.
//
// NOTE: This function is safe for concurrent access.
func RegisterWallet(driver *WalletDriver) error {
registerMtx.Lock()
defer registerMtx.Unlock()
if _, ok := wallets[driver.WalletType]; ok {
return fmt.Errorf("wallet already registered")
}
wallets[driver.WalletType] = driver
return nil
}
// SupportedWallets returns a slice of strings that represents the wallet
// drivers that have been registered and are therefore supported.
//
// NOTE: This function is safe for concurrent access.
func SupportedWallets() []string {
registerMtx.Lock()
defer registerMtx.Unlock()
supportedWallets := make([]string, 0, len(wallets))
for walletName := range wallets {
supportedWallets = append(supportedWallets, walletName)
}
return supportedWallets
}
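
// The function below sketches how a concrete implementation would typically
// register itself, usually from its own package's init function; the driver
// name, constructor and backend list are illustrative assumptions.
func registerExampleWalletDriver() {
	driver := &WalletDriver{
		WalletType: "example",
		New: func(args ...interface{}) (WalletController, error) {
			// A real driver would parse args into its config.
			return &mockWalletController{}, nil
		},
		BackEnds: func() []string {
			return []string{"btcd", "bitcoind", "neutrino"}
		},
	}
	if err := RegisterWallet(driver); err != nil {
		panic(fmt.Sprintf("failed to register wallet driver: %v",
			err))
	}
}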
// FetchFundingTxWrapper is a wrapper around FetchFundingTx, except that it will
// exit when the supplied quit channel is closed.
func FetchFundingTxWrapper(chain BlockChainIO, chanID *lnwire.ShortChannelID,
quit chan struct{}) (*wire.MsgTx, error) {
txChan := make(chan *wire.MsgTx, 1)
errChan := make(chan error, 1)
go func() {
tx, err := FetchFundingTx(chain, chanID)
if err != nil {
errChan <- err
return
}
txChan <- tx
}()
select {
case tx := <-txChan:
return tx, nil
case err := <-errChan:
return nil, err
case <-quit:
return nil, fmt.Errorf("quit channel passed to " +
"lnwallet.FetchFundingTxWrapper has been closed")
}
}
// FetchFundingTx uses the given BlockChainIO to fetch and return the funding
// transaction identified by the passed short channel ID.
//
// TODO(roasbeef): replace with call to GetBlockTransaction? (would allow to
// later use getblocktxn).
func FetchFundingTx(chain BlockChainIO,
chanID *lnwire.ShortChannelID) (*wire.MsgTx, error) {
// First fetch the block hash by the block number encoded, then use
// that hash to fetch the block itself.
blockNum := int64(chanID.BlockHeight)
blockHash, err := chain.GetBlockHash(blockNum)
if err != nil {
return nil, err
}
fundingBlock, err := chain.GetBlock(blockHash)
if err != nil {
return nil, err
}
// As a sanity check, ensure that the advertised transaction index is
// within the bounds of the total number of transactions within a
// block.
numTxns := uint32(len(fundingBlock.Transactions))
if chanID.TxIndex > numTxns-1 {
return nil, fmt.Errorf("tx_index=#%v "+
"is out of range (max_index=%v), network_chan_id=%v",
chanID.TxIndex, numTxns-1, chanID)
}
return fundingBlock.Transactions[chanID.TxIndex].Copy(), nil
}
// FetchPKScriptWithQuit fetches the output script for the given SCID and exits
// early with an error if the provided quit channel is closed before
// completion.
func FetchPKScriptWithQuit(chain BlockChainIO, chanID *lnwire.ShortChannelID,
quit chan struct{}) ([]byte, error) {
tx, err := FetchFundingTxWrapper(chain, chanID, quit)
if err != nil {
return nil, err
}
outputLocator := chanvalidate.ShortChanIDChanLocator{
ID: *chanID,
}
output, _, err := outputLocator.Locate(tx)
if err != nil {
return nil, err
}
return output.PkScript, nil
}
package lnwallet
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
// walletLog is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var walletLog btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("LNWL", nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
walletLog = logger
chainfee.UseLogger(logger)
}
package lnwallet
import (
"encoding/hex"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/hdkeychain"
"github.com/btcsuite/btcd/btcutil/psbt"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/waddrmgr"
base "github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/wallet/txauthor"
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/mock"
)
var (
CoinPkScript, _ = hex.DecodeString(
"001431df1bde03c074d0cf21ea2529427e1499b8f1de",
)
)
// mockWalletController is a mock implementation of the WalletController
// interface. It lets us mock the interaction with the Bitcoin network.
type mockWalletController struct {
RootKey *btcec.PrivateKey
PublishedTransactions chan *wire.MsgTx
index uint32
Utxos []*Utxo
}
// A compile time check to ensure that mockWalletController implements the
// WalletController.
var _ WalletController = (*mockWalletController)(nil)
// BackEnd returns "mock" to signify a mock wallet controller.
func (w *mockWalletController) BackEnd() string {
return "mock"
}
// FetchOutpointInfo will be called to get info about the inputs to the funding
// transaction.
func (w *mockWalletController) FetchOutpointInfo(
prevOut *wire.OutPoint) (*Utxo, error) {
utxo := &Utxo{
AddressType: WitnessPubKey,
Value: 10 * btcutil.SatoshiPerBitcoin,
PkScript: []byte("dummy"),
Confirmations: 1,
OutPoint: *prevOut,
}
return utxo, nil
}
// ScriptForOutput returns the address, witness program and redeem script for a
// given UTXO. An error is returned if the UTXO does not belong to our wallet or
// it is not a managed pubKey address.
func (w *mockWalletController) ScriptForOutput(*wire.TxOut) (
waddrmgr.ManagedPubKeyAddress, []byte, []byte, error) {
return nil, nil, nil, nil
}
// ConfirmedBalance currently returns dummy values.
func (w *mockWalletController) ConfirmedBalance(int32, string) (btcutil.Amount,
error) {
return 0, nil
}
// NewAddress is called to get new addresses for delivery, change etc.
func (w *mockWalletController) NewAddress(AddressType, bool,
string) (btcutil.Address, error) {
addr, _ := btcutil.NewAddressPubKey(
w.RootKey.PubKey().SerializeCompressed(),
&chaincfg.MainNetParams,
)
return addr, nil
}
// LastUnusedAddress currently returns dummy values.
func (w *mockWalletController) LastUnusedAddress(AddressType,
string) (btcutil.Address, error) {
return nil, nil
}
// IsOurAddress currently returns a dummy value.
func (w *mockWalletController) IsOurAddress(btcutil.Address) bool {
return false
}
// AddressInfo currently returns a dummy value.
func (w *mockWalletController) AddressInfo(
btcutil.Address) (waddrmgr.ManagedAddress, error) {
return nil, nil
}
// ListAccounts currently returns a dummy value.
func (w *mockWalletController) ListAccounts(string,
*waddrmgr.KeyScope) ([]*waddrmgr.AccountProperties, error) {
return nil, nil
}
// RequiredReserve currently returns a dummy value.
func (w *mockWalletController) RequiredReserve(uint32) btcutil.Amount {
return 0
}
// ListAddresses currently returns a dummy value.
func (w *mockWalletController) ListAddresses(string,
bool) (AccountAddressMap, error) {
return nil, nil
}
// ImportAccount currently returns a dummy value.
func (w *mockWalletController) ImportAccount(string, *hdkeychain.ExtendedKey,
uint32, *waddrmgr.AddressType, bool) (*waddrmgr.AccountProperties,
[]btcutil.Address, []btcutil.Address, error) {
return nil, nil, nil, nil
}
// ImportPublicKey currently returns a dummy value.
func (w *mockWalletController) ImportPublicKey(*btcec.PublicKey,
waddrmgr.AddressType) error {
return nil
}
// ImportTaprootScript currently returns a dummy value.
func (w *mockWalletController) ImportTaprootScript(waddrmgr.KeyScope,
*waddrmgr.Tapscript) (waddrmgr.ManagedAddress, error) {
return nil, nil
}
// SendOutputs currently returns dummy values.
func (w *mockWalletController) SendOutputs(fn.Set[wire.OutPoint], []*wire.TxOut,
chainfee.SatPerKWeight, int32, string,
base.CoinSelectionStrategy) (*wire.MsgTx, error) {
return nil, nil
}
// CreateSimpleTx currently returns dummy values.
func (w *mockWalletController) CreateSimpleTx(fn.Set[wire.OutPoint],
[]*wire.TxOut, chainfee.SatPerKWeight, int32,
base.CoinSelectionStrategy, bool) (*txauthor.AuthoredTx, error) {
return nil, nil
}
// ListUnspentWitness is called by the wallet when doing coin selection. We just
// need one unspent for the funding transaction.
func (w *mockWalletController) ListUnspentWitness(int32, int32,
string) ([]*Utxo, error) {
// If the mock already has a list of utxos, return it.
if w.Utxos != nil {
return w.Utxos, nil
}
// Otherwise create one to return.
utxo := &Utxo{
AddressType: WitnessPubKey,
Value: btcutil.Amount(10 * btcutil.SatoshiPerBitcoin),
PkScript: CoinPkScript,
OutPoint: wire.OutPoint{
Hash: chainhash.Hash{},
Index: w.index,
},
}
atomic.AddUint32(&w.index, 1)
var ret []*Utxo
ret = append(ret, utxo)
return ret, nil
}
// ListTransactionDetails currently returns dummy values.
func (w *mockWalletController) ListTransactionDetails(int32, int32,
string, uint32, uint32) ([]*TransactionDetail, uint64, uint64, error) {
return nil, 0, 0, nil
}
// LeaseOutput returns the current time and a nil error.
func (w *mockWalletController) LeaseOutput(wtxmgr.LockID, wire.OutPoint,
time.Duration) (time.Time, error) {
return time.Now(), nil
}
// ReleaseOutput currently does nothing.
func (w *mockWalletController) ReleaseOutput(wtxmgr.LockID,
wire.OutPoint) error {
return nil
}
// ListLeasedOutputs currently returns dummy values.
func (w *mockWalletController) ListLeasedOutputs() (
[]*base.ListLeasedOutputResult, error) {
return nil, nil
}
// FundPsbt currently does nothing.
func (w *mockWalletController) FundPsbt(*psbt.Packet, int32,
chainfee.SatPerKWeight, string, *waddrmgr.KeyScope,
base.CoinSelectionStrategy, func(utxo wtxmgr.Credit) bool) (int32,
error) {
return 0, nil
}
// SignPsbt currently does nothing.
func (w *mockWalletController) SignPsbt(*psbt.Packet) ([]uint32, error) {
return nil, nil
}
// FinalizePsbt currently does nothing.
func (w *mockWalletController) FinalizePsbt(_ *psbt.Packet, _ string) error {
return nil
}
// DecorateInputs currently does nothing.
func (w *mockWalletController) DecorateInputs(*psbt.Packet, bool) error {
return nil
}
// PublishTransaction sends a transaction to the PublishedTransactions chan.
func (w *mockWalletController) PublishTransaction(tx *wire.MsgTx,
_ string) error {
w.PublishedTransactions <- tx
return nil
}
// GetTransactionDetails currently does nothing.
func (w *mockWalletController) GetTransactionDetails(*chainhash.Hash) (
*TransactionDetail, error) {
return nil, nil
}
// LabelTransaction currently does nothing.
func (w *mockWalletController) LabelTransaction(chainhash.Hash, string,
bool) error {
return nil
}
// SubscribeTransactions currently does nothing.
func (w *mockWalletController) SubscribeTransactions() (TransactionSubscription,
error) {
return nil, nil
}
// IsSynced currently returns dummy values.
func (w *mockWalletController) IsSynced() (bool, int64, error) {
return true, int64(0), nil
}
// GetRecoveryInfo currently returns dummy values.
func (w *mockWalletController) GetRecoveryInfo() (bool, float64, error) {
return true, float64(1), nil
}
// Start currently does nothing.
func (w *mockWalletController) Start() error {
return nil
}
// Stop currently does nothing.
func (w *mockWalletController) Stop() error {
return nil
}
// FetchTx currently returns dummy values.
func (w *mockWalletController) FetchTx(chainhash.Hash) (*wire.MsgTx, error) {
return nil, nil
}
// RemoveDescendants currently does nothing.
func (w *mockWalletController) RemoveDescendants(*wire.MsgTx) error {
return nil
}
// FetchDerivationInfo queries for the wallet's knowledge of the passed
// pkScript and constructs the derivation info and returns it.
func (w *mockWalletController) FetchDerivationInfo(
pkScript []byte) (*psbt.Bip32Derivation, error) {
return nil, nil
}
// CheckMempoolAcceptance currently does nothing.
func (w *mockWalletController) CheckMempoolAcceptance(tx *wire.MsgTx) error {
return nil
}
// mockChainNotifier is a mock implementation of the ChainNotifier interface.
type mockChainNotifier struct {
SpendChan chan *chainntnfs.SpendDetail
EpochChan chan *chainntnfs.BlockEpoch
ConfChan chan *chainntnfs.TxConfirmation
}
// RegisterConfirmationsNtfn returns a ConfirmationEvent that contains a channel
// that the tx confirmation will go over.
func (c *mockChainNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash,
pkScript []byte, numConfs, heightHint uint32,
opts ...chainntnfs.NotifierOption) (*chainntnfs.ConfirmationEvent,
error) {
return &chainntnfs.ConfirmationEvent{
Confirmed: c.ConfChan,
Cancel: func() {},
}, nil
}
// RegisterSpendNtfn returns a SpendEvent that contains a channel that the spend
// details will go over.
func (c *mockChainNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
pkScript []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) {
return &chainntnfs.SpendEvent{
Spend: c.SpendChan,
Cancel: func() {},
}, nil
}
// RegisterBlockEpochNtfn returns a BlockEpochEvent that contains a channel that
// block epochs will go over.
func (c *mockChainNotifier) RegisterBlockEpochNtfn(
blockEpoch *chainntnfs.BlockEpoch) (*chainntnfs.BlockEpochEvent,
error) {
return &chainntnfs.BlockEpochEvent{
Epochs: c.EpochChan,
Cancel: func() {},
}, nil
}
// Start currently returns a dummy value.
func (c *mockChainNotifier) Start() error {
return nil
}
// Started currently returns a dummy value.
func (c *mockChainNotifier) Started() bool {
return true
}
// Stop currently returns a dummy value.
func (c *mockChainNotifier) Stop() error {
return nil
}
// mockChainIO is a mock implementation of the BlockChainIO interface.
type mockChainIO struct{}
func (*mockChainIO) GetBestBlock() (*chainhash.Hash, int32, error) {
return nil, 0, nil
}
func (*mockChainIO) GetUtxo(op *wire.OutPoint, _ []byte,
heightHint uint32, _ <-chan struct{}) (*wire.TxOut, error) {
return nil, nil
}
func (*mockChainIO) GetBlockHash(blockHeight int64) (*chainhash.Hash, error) {
return nil, nil
}
func (*mockChainIO) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock,
error) {
return nil, nil
}
func (*mockChainIO) GetBlockHeader(
blockHash *chainhash.Hash) (*wire.BlockHeader, error) {
return nil, nil
}
type MockAuxLeafStore struct{}
// A compile time check to ensure that MockAuxLeafStore implements the
// AuxLeafStore interface.
var _ AuxLeafStore = (*MockAuxLeafStore)(nil)
// FetchLeavesFromView attempts to fetch the auxiliary leaves that
// correspond to the passed aux blob, and pending original (unfiltered)
// HTLC view.
func (*MockAuxLeafStore) FetchLeavesFromView(
_ CommitDiffAuxInput) fn.Result[CommitDiffAuxResult] {
return fn.Ok(CommitDiffAuxResult{})
}
// FetchLeavesFromCommit attempts to fetch the auxiliary leaves that
// correspond to the passed aux blob, and an existing channel
// commitment.
func (*MockAuxLeafStore) FetchLeavesFromCommit(_ AuxChanState,
_ channeldb.ChannelCommitment, _ CommitmentKeyRing,
_ lntypes.ChannelParty) fn.Result[CommitDiffAuxResult] {
return fn.Ok(CommitDiffAuxResult{})
}
// FetchLeavesFromRevocation attempts to fetch the auxiliary leaves
// from a channel revocation that stores balance + blob information.
func (*MockAuxLeafStore) FetchLeavesFromRevocation(
_ *channeldb.RevocationLog) fn.Result[CommitDiffAuxResult] {
return fn.Ok(CommitDiffAuxResult{})
}
// ApplyHtlcView serves as the state transition function for the custom
// channel's blob. Given the old blob, and an HTLC view, then a new
// blob should be returned that reflects the pending updates.
func (*MockAuxLeafStore) ApplyHtlcView(
_ CommitDiffAuxInput) fn.Result[fn.Option[tlv.Blob]] {
return fn.Ok(fn.None[tlv.Blob]())
}
// EmptyMockJobHandler is a mock job handler that just sends an empty response
// to all jobs.
func EmptyMockJobHandler(jobs []AuxSigJob) {
for _, sigJob := range jobs {
sigJob.Resp <- AuxSigJobResp{}
}
}
// MockAuxSigner is a mock implementation of the AuxSigner interface.
type MockAuxSigner struct {
mock.Mock
jobHandlerFunc func([]AuxSigJob)
}
// NewAuxSignerMock creates a new mock aux signer with the given job handler.
func NewAuxSignerMock(jobHandler func([]AuxSigJob)) *MockAuxSigner {
return &MockAuxSigner{
jobHandlerFunc: jobHandler,
}
}
// SubmitSecondLevelSigBatch takes a batch of aux sign jobs and
// processes them asynchronously.
func (a *MockAuxSigner) SubmitSecondLevelSigBatch(chanState AuxChanState,
tx *wire.MsgTx, jobs []AuxSigJob) error {
args := a.Called(chanState, tx, jobs)
if a.jobHandlerFunc != nil {
a.jobHandlerFunc(jobs)
}
return args.Error(0)
}
// PackSigs takes a series of aux signatures and packs them into a
// single blob that can be sent alongside the CommitSig messages.
func (a *MockAuxSigner) PackSigs(
sigs []fn.Option[tlv.Blob]) fn.Result[fn.Option[tlv.Blob]] {
args := a.Called(sigs)
return args.Get(0).(fn.Result[fn.Option[tlv.Blob]])
}
// UnpackSigs takes a packed blob of signatures and returns the
// original signatures for each HTLC, keyed by HTLC index.
func (a *MockAuxSigner) UnpackSigs(
sigs fn.Option[tlv.Blob]) fn.Result[[]fn.Option[tlv.Blob]] {
args := a.Called(sigs)
return args.Get(0).(fn.Result[[]fn.Option[tlv.Blob]])
}
// VerifySecondLevelSigs attempts to synchronously verify a batch of aux
// sig jobs.
func (a *MockAuxSigner) VerifySecondLevelSigs(chanState AuxChanState,
tx *wire.MsgTx, jobs []AuxVerifyJob) error {
args := a.Called(chanState, tx, jobs)
return args.Error(0)
}
type MockAuxContractResolver struct{}
// ResolveContract is called to resolve a contract that needs
// additional information to resolve properly. If no extra information
// is required, a nil Result error is returned.
func (*MockAuxContractResolver) ResolveContract(
ResolutionReq) fn.Result[tlv.Blob] {
return fn.Ok[tlv.Blob](nil)
}
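// The mocks above are typically driven through testify expectations. A
// minimal sketch of wiring up MockAuxSigner in a test (illustrative only;
// exact expectations vary by test):
//
//	signer := NewAuxSignerMock(EmptyMockJobHandler)
//	signer.On(
//		"SubmitSecondLevelSigBatch", mock.Anything, mock.Anything,
//		mock.Anything,
//	).Return(nil)
//	signer.On("PackSigs", mock.Anything).Return(
//		fn.Ok(fn.None[tlv.Blob]()),
//	)
//
// With these expectations in place, every submitted AuxSigJob receives an
// empty AuxSigJobResp from EmptyMockJobHandler, so tests can proceed without
// producing real signatures.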
package lnwallet
import (
"bytes"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
)
// MusigCommitType is an enum that denotes if this is the local or remote
// commitment.
type MusigCommitType uint8
const (
// LocalMusigCommit denotes that this is a session for the local
// commitment.
LocalMusigCommit MusigCommitType = iota
// RemoteMusigCommit denotes that this is a session for the remote
// commitment.
RemoteMusigCommit
)
var (
// ErrSessionNotFinalized is returned when the SignCommit method is
// called for a local commitment, without the session being finalized
// (missing nonce).
ErrSessionNotFinalized = fmt.Errorf("musig2 session not finalized")
)
// tapscriptRootToSignOpt is a function that takes a tapscript root and returns
// a MuSig2 sign opt that'll apply the tweak when signing+verifying.
func tapscriptRootToSignOpt(root chainhash.Hash) musig2.SignOption {
return musig2.WithTaprootSignTweak(root[:])
}
// TapscriptRootToTweak is a helper function that converts a tapscript root
// into a tweak that can be used with the MuSig2 API.
func TapscriptRootToTweak(root chainhash.Hash) input.MuSig2Tweaks {
return input.MuSig2Tweaks{
TaprootTweak: root[:],
}
}
// MusigPartialSig is a wrapper around the base musig2.PartialSignature type
// that also includes information about the set of nonces used, as well as
// the signer. This allows us to implement the input.Signature interface, as that
// requires the ability to perform abstract verification based on a public key.
type MusigPartialSig struct {
// sig is the actual musig2 partial signature.
sig *musig2.PartialSignature
// signerNonce is the nonce used by the signer to generate the partial
// signature.
signerNonce lnwire.Musig2Nonce
// combinedNonce is the combined nonce of all signers.
combinedNonce lnwire.Musig2Nonce
// signerKeys is the set of public keys of all signers.
signerKeys []*btcec.PublicKey
// tapscriptTweak is an optional tweak, that if specified, will be used
// instead of the normal BIP 86 tweak when validating the signature.
tapscriptTweak fn.Option[chainhash.Hash]
}
// NewMusigPartialSig creates a new MuSig2 partial signature.
func NewMusigPartialSig(sig *musig2.PartialSignature, signerNonce,
combinedNonce lnwire.Musig2Nonce, signerKeys []*btcec.PublicKey,
tapscriptTweak fn.Option[chainhash.Hash]) *MusigPartialSig {
return &MusigPartialSig{
sig: sig,
signerNonce: signerNonce,
combinedNonce: combinedNonce,
signerKeys: signerKeys,
tapscriptTweak: tapscriptTweak,
}
}
// FromWireSig maps a wire partial sig to this internal type that we'll use to
// perform signature validation.
func (p *MusigPartialSig) FromWireSig(
sig *lnwire.PartialSigWithNonce) *MusigPartialSig {
p.sig = &musig2.PartialSignature{
S: &sig.Sig,
}
p.signerNonce = sig.Nonce
return p
}
// ToWireSig maps the partial signature to something that we can use to write
// out for the wire protocol.
func (p *MusigPartialSig) ToWireSig() *lnwire.PartialSigWithNonce {
return &lnwire.PartialSigWithNonce{
PartialSig: lnwire.NewPartialSig(*p.sig.S),
Nonce: p.signerNonce,
}
}
// Serialize serializes the musig2 partial signature. The serialization
// includes the signer's public nonce _and_ the partial signature, so the
// final encoding is always 98 bytes in length (a 66-byte public nonce plus a
// 32-byte partial sig).
func (p *MusigPartialSig) Serialize() []byte {
var b bytes.Buffer
_ = p.ToWireSig().Encode(&b)
return b.Bytes()
}
// ToSchnorrShell converts the musig partial signature to a regular schnorr
// signature. This schnorr signature uses a zero value for the 'r' field, so
// only the last 32 bytes of the signature are meaningful. This is useful when
// we need to convert an HTLC schnorr signature into something we can send
// using the existing messages.
func (p *MusigPartialSig) ToSchnorrShell() *schnorr.Signature {
var zeroVal btcec.FieldVal
return schnorr.NewSignature(&zeroVal, p.sig.S)
}
// FromSchnorrShell takes a schnorr signature and parses out the last 32 bytes
// as a normal musig2 partial signature.
func (p *MusigPartialSig) FromSchnorrShell(sig *schnorr.Signature) {
var (
partialS btcec.ModNScalar
partialSBytes [32]byte
)
copy(partialSBytes[:], sig.Serialize()[32:])
partialS.SetBytes(&partialSBytes)
p.sig = &musig2.PartialSignature{
S: &partialS,
}
}
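// Taken together, ToSchnorrShell and FromSchnorrShell round-trip the partial
// signature's S value through the schnorr encoding: only the last 32 bytes
// carry information. A minimal sketch (assuming partialSig is a populated
// *MusigPartialSig):
//
//	shell := partialSig.ToSchnorrShell()
//	var recovered MusigPartialSig
//	recovered.FromSchnorrShell(shell)
//	// recovered now holds the same S value, but its nonce fields must be
//	// re-populated before Verify can be called.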
// Verify attempts to verify the partial musig2 signature using the passed
// message and signer public key.
//
// NOTE: This implements the input.Signature interface.
func (p *MusigPartialSig) Verify(msg []byte, pub *btcec.PublicKey) bool {
var m [32]byte
copy(m[:], msg)
// If we have a tapscript tweak, then we'll use that as a tweak
// otherwise, we'll fall back to the normal BIP 86 sign tweak.
signOpts := fn.MapOption(tapscriptRootToSignOpt)(
p.tapscriptTweak,
).UnwrapOr(musig2.WithBip86SignTweak())
return p.sig.Verify(
p.signerNonce, p.combinedNonce, p.signerKeys, pub, m,
musig2.WithSortedKeys(), signOpts,
)
}
// MusigNoncePair holds the two nonces needed to sign/verify a new commitment
// state. The signer nonce is the nonce used by the signer (remote nonce), and
// the verification nonce, the nonce used by the verifier (local nonce).
type MusigNoncePair struct {
// SigningNonce is the nonce used by the signer to sign the commitment.
SigningNonce musig2.Nonces
// VerificationNonce is the nonce used by the verifier to verify the
// commitment.
VerificationNonce musig2.Nonces
}
// String returns a string representation of the MusigNoncePair.
func (n *MusigNoncePair) String() string {
return fmt.Sprintf("NoncePair(verification_nonce=%x, "+
"signing_nonce=%x)", n.VerificationNonce.PubNonce[:],
n.SigningNonce.PubNonce[:])
}
// muSig2TweakToRoot is a helper function that takes a MuSig2 taproot tweak
// and returns the root hash of the tapscript tree it encodes.
func muSig2TweakToRoot(tweak input.MuSig2Tweaks) chainhash.Hash {
var root chainhash.Hash
copy(root[:], tweak.TaprootTweak)
return root
}
// MusigSession abstracts over the details of a logical musig session. A single
// session is used for each commitment transaction. The sessions use a JIT
// nonce style, wherein part of the session can be created using only the
// verifier nonce. Once a new state is signed, then the signer nonce is
// generated. Similarly, the verifier then uses the received signer nonce to
// complete the session and verify the incoming signature.
type MusigSession struct {
// session is the backing musig2 session. We'll use this to interact
// with the musig2 signer.
session *input.MuSig2SessionInfo
// combinedNonce is the combined nonce of all signers.
combinedNonce lnwire.Musig2Nonce
// nonces is the set of nonces that'll be used to generate/verify the
// next commitment.
nonces MusigNoncePair
// inputTxOut is the funding input.
inputTxOut *wire.TxOut
// signerKeys is the set of public keys of all signers.
signerKeys []*btcec.PublicKey
// remoteKey is the key desc of the remote key.
remoteKey keychain.KeyDescriptor
// localKey is the key desc of the local key.
localKey keychain.KeyDescriptor
// signer is the signer that'll be used to interact with the musig
// session.
signer input.MuSig2Signer
// commitType tracks if this is the session for the local or remote
// commitment.
commitType MusigCommitType
// tapscriptTweak is an optional tweak, that if specified, will be used
// instead of the normal BIP 86 tweak when creating the MuSig2
// aggregate key and session.
tapscriptTweak fn.Option[input.MuSig2Tweaks]
}
// NewPartialMusigSession creates a new musig2 session given only the
// verification nonce (local nonce), and the other information that has already
// been bound to the session.
func NewPartialMusigSession(verificationNonce musig2.Nonces,
localKey, remoteKey keychain.KeyDescriptor, signer input.MuSig2Signer,
inputTxOut *wire.TxOut, commitType MusigCommitType,
tapscriptTweak fn.Option[input.MuSig2Tweaks]) *MusigSession {
signerKeys := []*btcec.PublicKey{localKey.PubKey, remoteKey.PubKey}
nonces := MusigNoncePair{
VerificationNonce: verificationNonce,
}
return &MusigSession{
nonces: nonces,
remoteKey: remoteKey,
localKey: localKey,
inputTxOut: inputTxOut,
signerKeys: signerKeys,
signer: signer,
commitType: commitType,
tapscriptTweak: tapscriptTweak,
}
}
// FinalizeSession finalizes the session given the signer nonce. This is
// called before signing or verifying a new commitment.
func (m *MusigSession) FinalizeSession(signingNonce musig2.Nonces) error {
var (
localNonce, remoteNonce musig2.Nonces
err error
)
// First, we'll stash the freshly generated signing nonce. Depending on
// whose commitment we're handling, this'll either be our generated
// nonce, or the one we just got from the remote party.
m.nonces.SigningNonce = signingNonce
switch m.commitType {
// If we're making a session for the remote commitment, then the nonce
// we use to sign will be the signing nonce for the session, and their
// nonce the verification nonce.
case RemoteMusigCommit:
localNonce = m.nonces.SigningNonce
remoteNonce = m.nonces.VerificationNonce
// Otherwise, we're generating/receiving a signature for our local
// commitment (to broadcast), so now our verification nonce is the one
// we've already generated, and we want to bind their new signing
// nonce.
case LocalMusigCommit:
localNonce = m.nonces.VerificationNonce
remoteNonce = m.nonces.SigningNonce
}
tweakDesc := m.tapscriptTweak.UnwrapOr(input.MuSig2Tweaks{
TaprootBIP0086Tweak: true,
})
m.session, err = m.signer.MuSig2CreateSession(
input.MuSig2Version100RC2, m.localKey.KeyLocator, m.signerKeys,
&tweakDesc, [][musig2.PubNonceSize]byte{remoteNonce.PubNonce},
&localNonce,
)
if err != nil {
return err
}
// We'll need the raw combined nonces later to be able to verify
// partial signatures, and also combine partial signatures, so we'll
// generate it now ourselves.
aggNonce, err := musig2.AggregateNonces([][musig2.PubNonceSize]byte{
m.nonces.SigningNonce.PubNonce,
m.nonces.VerificationNonce.PubNonce,
})
if err != nil {
return err
}
m.combinedNonce = aggNonce
return nil
}
// taprootKeyspendSighash generates the sighash for a taproot key spend. As
// this is a musig2 channel output, the keyspend is the only path we can take.
func taprootKeyspendSighash(tx *wire.MsgTx, pkScript []byte,
value int64) ([]byte, error) {
prevOutputFetcher := txscript.NewCannedPrevOutputFetcher(
pkScript, value,
)
sigHashes := txscript.NewTxSigHashes(tx, prevOutputFetcher)
return txscript.CalcTaprootSignatureHash(
sigHashes, txscript.SigHashDefault, tx, 0, prevOutputFetcher,
)
}
// SignCommit signs the passed commitment w/ the current signing (relative
// remote) nonce. Nonces should only ever be used once, so when the method
// returns, a new nonce is returned and the existing nonce is blanked out.
func (m *MusigSession) SignCommit(tx *wire.MsgTx) (*MusigPartialSig, error) {
switch {
// If we don't yet have a session for the remote commitment, then we
// need to create and finalize one now. If a session already exists,
// finalization was done up front (symmetric nonce case, like for co-op
// close), so there's nothing to do here.
case m.session == nil && m.commitType == RemoteMusigCommit:
// Before we can sign a new commitment, we'll need to generate
// a fresh nonce that'll be sent along side our signature. With
// the nonce in hand, we can finalize the session.
txHash := tx.TxHash()
signingNonce, err := musig2.GenNonces(
musig2.WithPublicKey(m.localKey.PubKey),
musig2.WithNonceAuxInput(txHash[:]),
)
if err != nil {
return nil, err
}
if err := m.FinalizeSession(*signingNonce); err != nil {
return nil, err
}
// Otherwise, we're trying to make a new commitment transaction without
// an active session, so we'll error out.
case m.session == nil:
return nil, ErrSessionNotFinalized
}
// Next, in order to sign, we'll need to generate the sighash for their
// commitment transaction.
sigHash, err := taprootKeyspendSighash(
tx, m.inputTxOut.PkScript, m.inputTxOut.Value,
)
if err != nil {
return nil, err
}
// Now that we have our session created, we'll use it to generate the
// initial partial signature over our sighash.
var sigHashMsg [32]byte
copy(sigHashMsg[:], sigHash)
walletLog.Infof("Generating new musig2 sig for session=%x, nonces=%s",
m.session.SessionID[:], m.nonces.String())
sig, err := m.signer.MuSig2Sign(
m.session.SessionID, sigHashMsg, false,
)
if err != nil {
return nil, err
}
tapscriptRoot := fn.MapOption(muSig2TweakToRoot)(m.tapscriptTweak)
return NewMusigPartialSig(
sig, m.session.PublicNonce, m.combinedNonce, m.signerKeys,
tapscriptRoot,
), nil
}
// Refresh is called once we receive a new verification nonce from the remote
// party after sending a signature. This nonce will be carried within the
// remote party's revoke-and-ack message.
func (m *MusigSession) Refresh(verificationNonce *musig2.Nonces,
) (*MusigSession, error) {
return NewPartialMusigSession(
*verificationNonce, m.localKey, m.remoteKey, m.signer,
m.inputTxOut, m.commitType, m.tapscriptTweak,
), nil
}
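// A sketch of the overall session lifecycle (names are illustrative, and the
// wire sig/nonces are assumed to come from the surrounding channel
// machinery):
//
//	local := NewPartialMusigSession(
//		ourVerificationNonce, localKey, remoteKey, signer,
//		fundingTxOut, LocalMusigCommit, fn.None[input.MuSig2Tweaks](),
//	)
//	// Verifying binds the remote signing nonce carried in the wire sig,
//	// checks the partial signature, and hands back our next verification
//	// nonce to transmit to the peer.
//	nextVerNonce, err := local.VerifyCommitSig(commitTx, wireSig)
//	if err != nil {
//		return err
//	}
//	// Once the peer's revoke-and-ack delivers a fresh verification nonce
//	// for the remote session, we rotate to a new session with it.
//	remote, err = remote.Refresh(&theirNewVerificationNonce)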
// VerificationNonce returns the current verification nonce for the session.
func (m *MusigSession) VerificationNonce() *musig2.Nonces {
return &m.nonces.VerificationNonce
}
// musigSessionOpts is a set of options that can be used to modify calls to the
// musig session.
type musigSessionOpts struct {
// customRand is an optional custom random source that can be used to
// generate nonces via a counter scheme.
customRand io.Reader
}
// defaultMusigSessionOpts returns the default set of options for the musig
// session.
func defaultMusigSessionOpts() *musigSessionOpts {
return &musigSessionOpts{}
}
// MusigSessionOpt is a functional option that can be used to modify calls to
// the musig session.
type MusigSessionOpt func(*musigSessionOpts)
// WithLocalCounterNonce is used to generate local nonces based on the shachain
// producer and the current height. This allows us to not have to write secret
// nonce state to disk. Instead, we can use this to derive the nonce we need to
// sign and broadcast our own commitment transaction.
func WithLocalCounterNonce(targetHeight uint64,
shaGen shachain.Producer) MusigSessionOpt {
return func(opt *musigSessionOpts) {
nextPreimage, _ := shaGen.AtIndex(targetHeight)
opt.customRand = bytes.NewBuffer(nextPreimage[:])
}
}
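// A minimal usage sketch, assuming a shachain.Producer for the channel and
// the commitment height currently being verified (names illustrative):
//
//	opt := WithLocalCounterNonce(commitHeight, shaChainProducer)
//	nextNonce, err := session.VerifyCommitSig(commitTx, wireSig, opt)
//
// Because the nonce is derived deterministically from the shachain preimage
// at the target height, the secret nonce never needs to be written to disk.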
// invalidPartialSigError is used to return additional debug information to a
// caller that encounters an invalid partial sig.
type invalidPartialSigError struct {
partialSig []byte
sigHash []byte
signingNonce [musig2.PubNonceSize]byte
verificationNonce [musig2.PubNonceSize]byte
}
// Error returns the error string for the partial sig error.
func (i invalidPartialSigError) Error() string {
return fmt.Sprintf("invalid partial sig: partial_sig=%x, "+
"sig_hash=%x, signing_nonce=%x, verification_nonce=%x",
i.partialSig, i.sigHash, i.signingNonce[:],
i.verificationNonce[:])
}
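// Callers within the package can surface these debug fields with errors.As.
// A minimal sketch (assuming the standard library errors package is imported
// at the call site):
//
//	var sigErr *invalidPartialSigError
//	if errors.As(err, &sigErr) {
//		walletLog.Errorf("rejecting state: %v", sigErr)
//	}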
// VerifyCommitSig attempts to verify the passed partial signature against the
// passed commitment transaction. A keyspend sighash is assumed to generate the
// signed message. As we never re-use nonces, a new verification nonce (our
// relative local nonce) is returned to transmit to the remote party, which
// allows them to generate another signature.
func (m *MusigSession) VerifyCommitSig(commitTx *wire.MsgTx,
sig *lnwire.PartialSigWithNonce,
musigOpts ...MusigSessionOpt) (*musig2.Nonces, error) {
opts := defaultMusigSessionOpts()
for _, optFunc := range musigOpts {
optFunc(opts)
}
if sig == nil {
return nil, fmt.Errorf("sig not provided")
}
// Before we can verify the signature, we'll need to finalize the
// session by binding the remote party's provided signing nonce.
if err := m.FinalizeSession(musig2.Nonces{
PubNonce: sig.Nonce,
}); err != nil {
return nil, err
}
// When we verify a commitment signature, we always assume that we're
// verifying a signature on our local commitment. Therefore, we'll use the
// remote party's signing nonce and public key.
tapscriptRoot := fn.MapOption(muSig2TweakToRoot)(m.tapscriptTweak)
partialSig := NewMusigPartialSig(
&musig2.PartialSignature{S: &sig.Sig},
m.nonces.SigningNonce.PubNonce, m.combinedNonce, m.signerKeys,
tapscriptRoot,
)
// With the partial sig loaded with the proper context, we'll now
// generate the sighash that the remote party should have signed.
sigHash, err := taprootKeyspendSighash(
commitTx, m.inputTxOut.PkScript, m.inputTxOut.Value,
)
if err != nil {
return nil, err
}
walletLog.Infof("Verifying new musig2 sig for session=%x, nonce=%s",
m.session.SessionID[:], m.nonces.String())
if partialSig == nil {
return nil, fmt.Errorf("partial sig not set")
}
if !partialSig.Verify(sigHash, m.remoteKey.PubKey) {
return nil, &invalidPartialSigError{
partialSig: partialSig.Serialize(),
sigHash: sigHash,
verificationNonce: m.nonces.VerificationNonce.PubNonce,
signingNonce: m.nonces.SigningNonce.PubNonce,
}
}
nonceOpts := []musig2.NonceGenOption{
musig2.WithPublicKey(m.localKey.PubKey),
}
if opts.customRand != nil {
nonceOpts = append(
nonceOpts, musig2.WithCustomRand(opts.customRand),
)
}
// At this point, we know that their signature is valid, so we'll
// generate another verification nonce for them, so they can generate a
// new state transition.
nextVerificationNonce, err := musig2.GenNonces(nonceOpts...)
if err != nil {
return nil, fmt.Errorf("unable to gen new nonce: %w", err)
}
return nextVerificationNonce, nil
}
// CombineSigs combines the passed partial signatures into a valid schnorr
// signature.
func (m *MusigSession) CombineSigs(sigs ...*musig2.PartialSignature,
) (*schnorr.Signature, error) {
sig, _, err := m.signer.MuSig2CombineSig(
m.session.SessionID, sigs,
)
if err != nil {
return nil, err
}
return sig, nil
}
// MusigSessionCfg is used to create a new musig2 pair session. It contains the
// keys for both parties, as well as their initial verification nonces.
type MusigSessionCfg struct {
// LocalKey is a key desc for the local key.
LocalKey keychain.KeyDescriptor
// RemoteKey is a key desc for the remote key.
RemoteKey keychain.KeyDescriptor
// LocalNonce is the local party's initial verification nonce.
LocalNonce musig2.Nonces
// RemoteNonce is the remote party's initial verification nonce.
RemoteNonce musig2.Nonces
// Signer is the signer that will be used to generate the session.
Signer input.MuSig2Signer
// InputTxOut is the output that we're signing for. This will be the
// funding input.
InputTxOut *wire.TxOut
// TapscriptTweak is an optional tweak that can be used to modify the
// MuSig2 public key used in the session.
TapscriptTweak fn.Option[chainhash.Hash]
}
// MusigPairSession houses the two musig2 sessions needed to do funding and
// drive forward the state machine. The local session is used to verify
// incoming commitment states. The remote session is used to propose new
// commitment states to the remote party.
type MusigPairSession struct {
// LocalSession is the local party's musig2 session.
LocalSession *MusigSession
// RemoteSession is the remote party's musig2 session.
RemoteSession *MusigSession
// signer is the signer that will be used to drive the session.
signer input.MuSig2Signer
}
// NewMusigPairSession creates a new musig2 pair session.
func NewMusigPairSession(cfg *MusigSessionCfg) *MusigPairSession {
// Given the config passed in, we'll now create our two sessions: one
// for the local commit, and one for the remote commit.
//
// Both sessions will be created using only the verification nonce for
// the local+remote party.
tapscriptTweak := fn.MapOption(TapscriptRootToTweak)(cfg.TapscriptTweak)
localSession := NewPartialMusigSession(
cfg.LocalNonce, cfg.LocalKey, cfg.RemoteKey, cfg.Signer,
cfg.InputTxOut, LocalMusigCommit, tapscriptTweak,
)
remoteSession := NewPartialMusigSession(
cfg.RemoteNonce, cfg.LocalKey, cfg.RemoteKey, cfg.Signer,
cfg.InputTxOut, RemoteMusigCommit, tapscriptTweak,
)
return &MusigPairSession{
LocalSession: localSession,
RemoteSession: remoteSession,
signer: cfg.Signer,
}
}
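// A minimal construction sketch, assuming the keys, verification nonces,
// signer, and funding output have already been derived by the funding flow
// (names illustrative):
//
//	pair := NewMusigPairSession(&MusigSessionCfg{
//		LocalKey:    localKeyDesc,
//		RemoteKey:   remoteKeyDesc,
//		LocalNonce:  ourVerificationNonce,
//		RemoteNonce: theirVerificationNonce,
//		Signer:      musigSigner,
//		InputTxOut:  fundingTxOut,
//	})
//	// pair.LocalSession verifies their sigs over our commitment, while
//	// pair.RemoteSession signs our view of their commitment.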
package lnwallet
import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/mempool"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// RoutingFee100PercentUpTo is the cut-off amount we allow 100% fees to
// be charged up to.
RoutingFee100PercentUpTo = lnwire.NewMSatFromSatoshis(1_000)
)
const (
// DefaultRoutingFeePercentage is the default off-chain routing fee we
// allow to be charged for a payment over the RoutingFee100PercentUpTo
// size.
DefaultRoutingFeePercentage lnwire.MilliSatoshi = 5
)
// DefaultRoutingFeeLimitForAmount returns the default off-chain routing fee
// limit lnd uses if the user does not specify a limit manually. The fee is
// amount dependent because of the base routing fee that is set on many
// channels. For example, the default base fee is 1 satoshi, so sending a payment
// of one satoshi will cost 1 satoshi in fees over most channels, which comes to
// a fee of 100%. That's why for very small amounts we allow 100% fee.
func DefaultRoutingFeeLimitForAmount(a lnwire.MilliSatoshi) lnwire.MilliSatoshi {
// Allow 100% fees up to a certain amount to accommodate for base fees.
if a <= RoutingFee100PercentUpTo {
return a
}
// Everything larger than the cut-off amount will get a default fee
// percentage.
return a * DefaultRoutingFeePercentage / 100
}
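// Two worked examples of the resulting limits (values in milli-satoshis):
//
//	// 1,000 sats is at the cut-off, so the full amount is allowed as fee.
//	DefaultRoutingFeeLimitForAmount(1_000_000) // = 1_000_000 msat (100%)
//	// 100,000 sats is above the cut-off, so the default 5% applies.
//	DefaultRoutingFeeLimitForAmount(100_000_000) // = 5_000_000 msat (5%)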
// DustLimitForSize retrieves the dust limit for a given pkscript size. Given
// the size, it automatically determines whether the script is a witness script
// or not. It calls btcd's GetDustThreshold method under the hood. It must be
// called with a proper size parameter or else a panic occurs.
func DustLimitForSize(scriptSize int) btcutil.Amount {
var (
dustlimit btcutil.Amount
pkscript []byte
)
// With the size of the script, determine which type of pkscript to
// create. This will be used in the call to GetDustThreshold. We pass
// in an empty byte slice since the contents of the script itself don't
// matter.
switch scriptSize {
case input.P2WPKHSize:
pkscript, _ = input.WitnessPubKeyHash([]byte{})
case input.P2WSHSize:
pkscript, _ = input.WitnessScriptHash([]byte{})
case input.P2SHSize:
pkscript, _ = input.GenerateP2SH([]byte{})
case input.P2PKHSize:
pkscript, _ = input.GenerateP2PKH([]byte{})
case input.UnknownWitnessSize:
pkscript, _ = input.GenerateUnknownWitness()
default:
panic("invalid script size")
}
// Call GetDustThreshold with a TxOut containing the generated
// pkscript.
txout := &wire.TxOut{PkScript: pkscript}
dustlimit = btcutil.Amount(mempool.GetDustThreshold(txout))
return dustlimit
}
// DustLimitUnknownWitness returns the dust limit for an UnknownWitnessSize.
func DustLimitUnknownWitness() btcutil.Amount {
return DustLimitForSize(input.UnknownWitnessSize)
}
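// A short usage sketch. The concrete thresholds come from btcd's mempool
// policy; under default relay settings these commonly work out to 546 sats
// for P2PKH and 294 sats for P2WPKH, but the numbers are policy-dependent
// rather than guaranteed:
//
//	p2wpkhDust := DustLimitForSize(input.P2WPKHSize)
//	if outputAmt < p2wpkhDust {
//		// The output would be considered dust; trim or omit it.
//	}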
package lnwallet
import (
"crypto/sha256"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
)
// updateType is the exact type of an entry within the shared HTLC log.
type updateType uint8
const (
// Add is an update type that adds a new HTLC entry into the log.
// Either side can add a new pending HTLC by adding a new Add entry
// into their update log.
Add updateType = iota
// Fail is an update type which removes a prior HTLC entry from the
// log. Adding a Fail entry to one's log will modify the _remote_
// party's update log once a new commitment view has been evaluated
// which contains the Fail entry.
Fail
// MalformedFail is an update type which removes a prior HTLC entry
// from the log. Adding a MalformedFail entry to one's log will modify
// the _remote_ party's update log once a new commitment view has been
// evaluated which contains the MalformedFail entry. The difference
// from the Fail type lies in the different data we have to store.
MalformedFail
// Settle is an update type which settles a prior HTLC crediting the
// balance of the receiving node. Adding a Settle entry to a log will
// result in the settle entry being removed on the log as well as the
// original add entry from the remote party's log after the next state
// transition.
Settle
// FeeUpdate is an update type sent by the channel initiator that
// updates the fee rate used when signing the commitment transaction.
FeeUpdate
)
// String returns a human readable string that uniquely identifies the target
// update type.
func (u updateType) String() string {
switch u {
case Add:
return "Add"
case Fail:
return "Fail"
case MalformedFail:
return "MalformedFail"
case Settle:
return "Settle"
case FeeUpdate:
return "FeeUpdate"
default:
return "<unknown type>"
}
}
// paymentDescriptor represents a commitment state update which either adds,
// settles, or removes an HTLC. paymentDescriptors encapsulate all necessary
// metadata w.r.t. an HTLC, and additional data pairing a settle message to
// the original added HTLC.
//
// TODO(roasbeef): LogEntry interface??
// - need to separate attrs for cancel/add/settle/feeupdate
type paymentDescriptor struct {
// ChanID is the ChannelID of the LightningChannel that this
// paymentDescriptor belongs to. We track this here so we can
// reconstruct the Messages that this paymentDescriptor is built from.
ChanID lnwire.ChannelID
// RHash is the payment hash for this HTLC. The HTLC can be settled iff
// the preimage to this hash is presented.
RHash PaymentHash
// RPreimage is the preimage that settles the HTLC pointed to within the
// log by the ParentIndex.
RPreimage PaymentHash
// Timeout is the absolute timeout in blocks, after which this HTLC
// expires.
Timeout uint32
// Amount is the HTLC amount in milli-satoshis.
Amount lnwire.MilliSatoshi
// LogIndex is the log entry number that this HTLC update has within the
// log. Depending on if IsIncoming is true, this is either an entry the
// remote party added, or one that we added locally.
LogIndex uint64
// HtlcIndex is the index within the main update log for this HTLC.
// Entries within the log of type Add will have this field populated,
// as other entries will point to the entry via this counter.
//
// NOTE: This field will only be populated if EntryType is Add.
HtlcIndex uint64
// ParentIndex is the HTLC index of the entry that this update settles
// or times out.
//
// NOTE: This field will only be populated if EntryType is Fail or
// Settle.
ParentIndex uint64
// SourceRef points to an Add update in a forwarding package owned by
// this channel.
//
// NOTE: This field will only be populated if EntryType is Fail or
// Settle.
SourceRef *channeldb.AddRef
// DestRef points to a Fail/Settle update in another link's forwarding
// package.
//
// NOTE: This field will only be populated if EntryType is Fail or
// Settle, and the forwarded Add successfully included in an outgoing
// link's commitment txn.
DestRef *channeldb.SettleFailRef
// OpenCircuitKey references the incoming Chan/HTLC ID of an Add HTLC
// packet delivered by the switch.
//
// NOTE: This field is only populated for payment descriptors in the
// *local* update log, and if the Add packet was delivered by the
// switch.
OpenCircuitKey *models.CircuitKey
// ClosedCircuitKey references the incoming Chan/HTLC ID of the Add HTLC
// that opened the circuit.
//
// NOTE: This field is only populated for payment descriptors in the
// *local* update log, and if settle/fails have a committed circuit in
// the circuit map.
ClosedCircuitKey *models.CircuitKey
// localOutputIndex is the output index of this HTLC output in the
// commitment transaction of the local node.
//
// NOTE: If the output is dust from the PoV of the local commitment
// chain, then this value will be -1.
localOutputIndex int32
// remoteOutputIndex is the output index of this HTLC output in the
// commitment transaction of the remote node.
//
// NOTE: If the output is dust from the PoV of the remote commitment
// chain, then this value will be -1.
remoteOutputIndex int32
// sig is the signature for the second-level HTLC transaction that
// spends the version of this HTLC on the commitment transaction of the
// local node. This signature is generated by the remote node and
// stored by the local node in the case that the local node needs to
// broadcast their commitment transaction.
sig input.Signature
// addCommitHeight[Remote|Local] encodes the height of the commitment
// which included this HTLC on either the remote or local commitment
// chain. This value is used to determine when an HTLC is fully
// "locked-in".
addCommitHeights lntypes.Dual[uint64]
// removeCommitHeight[Remote|Local] encodes the height of the
// commitment which removed the parent pointer of this
// paymentDescriptor either due to a timeout or a settle. Once both
// these heights are below the tail of both chains, the log entries can
// safely be removed.
removeCommitHeights lntypes.Dual[uint64]
// OnionBlob is an opaque blob which is used to complete multi-hop
// routing.
//
// NOTE: Populated only on add payment descriptor entry types.
OnionBlob [lnwire.OnionPacketSize]byte
// ShaOnionBlob is a sha of the onion blob.
//
// NOTE: Populated only in payment descriptor with MalformedFail type.
ShaOnionBlob [sha256.Size]byte
// FailReason stores the reason why a particular payment was canceled.
//
// NOTE: Populated only in Fail payment descriptor entry types.
FailReason []byte
// FailCode stores the code why a particular payment was canceled.
//
// NOTE: Populated only in payment descriptor with MalformedFail type.
FailCode lnwire.FailCode
// [our|their]PkScript are the raw public key scripts that encode the
// redemption rules for this particular HTLC. These fields will only be
// populated iff the EntryType of this paymentDescriptor is Add.
// ourPkScript is the ourPkScript from the context of our local
// commitment chain. theirPkScript is the latest pkScript from the
// context of the remote commitment chain.
//
// NOTE: These values may change within the logs themselves, however,
// they'll stay consistent within the commitment chain entries
// themselves.
ourPkScript []byte
ourWitnessScript []byte
theirPkScript []byte
theirWitnessScript []byte
// EntryType denotes the exact type of the paymentDescriptor. In the
// case of a Timeout, or Settle type, then the Parent field will point
// into the log to the HTLC being modified.
EntryType updateType
// isForwarded denotes if an incoming HTLC has been forwarded to any
// possible upstream peers in the route.
isForwarded bool
// BlindingPoint is an optional ephemeral key used in route blinding.
// This value is set for nodes that are relaying payments inside of a
// blinded route (ie, not the introduction node) from update_add_htlc's
// TLVs.
BlindingPoint lnwire.BlindingPointRecord
// CustomRecords also stores the set of optional custom records that
// may have been attached to a sent HTLC.
CustomRecords lnwire.CustomRecords
}
// toLogUpdate recovers the underlying LogUpdate from the paymentDescriptor.
// This operation is lossy and will forget some extra information tracked by the
// paymentDescriptor, but the function is total in that all paymentDescriptors
// can be converted back to LogUpdates.
func (pd *paymentDescriptor) toLogUpdate() channeldb.LogUpdate {
var msg lnwire.Message
switch pd.EntryType {
case Add:
msg = &lnwire.UpdateAddHTLC{
ChanID: pd.ChanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
PaymentHash: pd.RHash,
Expiry: pd.Timeout,
OnionBlob: pd.OnionBlob,
BlindingPoint: pd.BlindingPoint,
CustomRecords: pd.CustomRecords.Copy(),
}
case Settle:
msg = &lnwire.UpdateFulfillHTLC{
ChanID: pd.ChanID,
ID: pd.ParentIndex,
PaymentPreimage: pd.RPreimage,
}
case Fail:
msg = &lnwire.UpdateFailHTLC{
ChanID: pd.ChanID,
ID: pd.ParentIndex,
Reason: pd.FailReason,
}
case MalformedFail:
msg = &lnwire.UpdateFailMalformedHTLC{
ChanID: pd.ChanID,
ID: pd.ParentIndex,
ShaOnionBlob: pd.ShaOnionBlob,
FailureCode: pd.FailCode,
}
case FeeUpdate:
// The Amount field holds the feerate denominated in
// msat. Since feerates are only denominated in sat/kw,
// we can convert it without loss of precision.
msg = &lnwire.UpdateFee{
ChanID: pd.ChanID,
FeePerKw: uint32(pd.Amount.ToSatoshis()),
}
}
return channeldb.LogUpdate{
LogIndex: pd.LogIndex,
UpdateMsg: msg,
}
}
// setCommitHeight updates the appropriate addCommitHeight and/or
// removeCommitHeight for whoseCommitChain and locks it in at nextHeight.
func (pd *paymentDescriptor) setCommitHeight(
whoseCommitChain lntypes.ChannelParty, nextHeight uint64) {
switch pd.EntryType {
case Add:
pd.addCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
case Settle, Fail, MalformedFail:
pd.removeCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
case FeeUpdate:
// Fee updates are applied for all commitments
// after they are sent/received, so we consider
// them being added and removed at the same
// height.
pd.addCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
pd.removeCommitHeights.SetForParty(
whoseCommitChain, nextHeight,
)
}
}
package lnwallet
import (
"errors"
"fmt"
"net"
"sync"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chanfunding"
"github.com/lightningnetwork/lnd/lnwire"
)
// CommitmentType is an enum indicating the commitment type we should use for
// the channel we are opening.
type CommitmentType int
const (
// CommitmentTypeLegacy is the legacy commitment format with a tweaked
// to_remote key.
CommitmentTypeLegacy CommitmentType = iota
// CommitmentTypeTweakless is a newer commitment format where the
// to_remote key is static.
CommitmentTypeTweakless
// CommitmentTypeAnchorsZeroFeeHtlcTx is a commitment type that is an
// extension of the outdated CommitmentTypeAnchors, which in addition
// requires second-level HTLC transactions to be signed with a zero
// fee.
CommitmentTypeAnchorsZeroFeeHtlcTx
// CommitmentTypeScriptEnforcedLease is a commitment type that builds
// upon CommitmentTypeTweakless and CommitmentTypeAnchorsZeroFeeHtlcTx,
// which in addition requires a CLTV clause to spend outputs paying to
// the channel initiator. This is intended for use on leased channels to
// guarantee that the channel initiator has no incentives to close a
// leased channel before its maturity date.
CommitmentTypeScriptEnforcedLease
// CommitmentTypeSimpleTaproot is the base commitment type for the
// channels that use a musig2 funding output and the tapscript tree
// where relevant for the commitment transaction pk scripts.
CommitmentTypeSimpleTaproot
// CommitmentTypeSimpleTaprootOverlay builds on the existing
// CommitmentTypeSimpleTaproot type but layers on a special overlay
// protocol.
CommitmentTypeSimpleTaprootOverlay
)
// HasStaticRemoteKey returns whether the commitment type supports remote
// outputs backed by static keys.
func (c CommitmentType) HasStaticRemoteKey() bool {
switch c {
case CommitmentTypeTweakless,
CommitmentTypeAnchorsZeroFeeHtlcTx,
CommitmentTypeScriptEnforcedLease,
CommitmentTypeSimpleTaproot,
CommitmentTypeSimpleTaprootOverlay:
return true
default:
return false
}
}
// HasAnchors returns whether the commitment type supports anchor outputs.
func (c CommitmentType) HasAnchors() bool {
switch c {
case CommitmentTypeAnchorsZeroFeeHtlcTx,
CommitmentTypeScriptEnforcedLease,
CommitmentTypeSimpleTaproot,
CommitmentTypeSimpleTaprootOverlay:
return true
default:
return false
}
}
// IsTaproot returns true if the channel type is a taproot channel.
func (c CommitmentType) IsTaproot() bool {
return c == CommitmentTypeSimpleTaproot ||
c == CommitmentTypeSimpleTaprootOverlay
}
// String returns the name of the CommitmentType.
func (c CommitmentType) String() string {
switch c {
case CommitmentTypeLegacy:
return "legacy"
case CommitmentTypeTweakless:
return "tweakless"
case CommitmentTypeAnchorsZeroFeeHtlcTx:
return "anchors-zero-fee-second-level"
case CommitmentTypeScriptEnforcedLease:
return "script-enforced-lease"
case CommitmentTypeSimpleTaproot:
return "simple-taproot"
case CommitmentTypeSimpleTaprootOverlay:
return "simple-taproot-overlay"
default:
return "invalid"
}
}
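// The predicates above compose when deciding how to treat a channel. A
// minimal sketch:
//
//	ct := CommitmentTypeSimpleTaproot
//	switch {
//	case ct.IsTaproot():
//		// musig2 funding output; nonces must be exchanged.
//	case ct.HasAnchors():
//		// Anchor outputs present; account for the extra commit fee.
//	}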
// ReservationState is a type that represents the current state of a channel
// reservation within the funding workflow.
type ReservationState int
const (
// WaitingToSend is the state either the funder or fundee is in after
// creating a reservation, but before sending any message.
WaitingToSend ReservationState = iota
// SentOpenChannel is the state the funder is in after sending the
// OpenChannel message.
SentOpenChannel
// SentAcceptChannel is the state the fundee is in after sending the
// AcceptChannel message.
SentAcceptChannel
// SentFundingCreated is the state the funder is in after sending the
// FundingCreated message.
SentFundingCreated
)
// ChannelContribution is the primary constituent of the funding workflow
// within lnwallet. Each side first exchanges their respective contributions
// along with channel specific parameters like the min fee/KB. Once
// contributions have been exchanged, each side will then produce signatures
// for all their inputs to the funding transactions, and finally a signature
// for the other party's version of the commitment transaction.
type ChannelContribution struct {
// FundingAmount is the amount of funds contributed to the funding
// transaction.
FundingAmount btcutil.Amount
// Inputs to the funding transaction.
Inputs []*wire.TxIn
// ChangeOutputs are the Outputs to be used in the case that the total
// value of the funding inputs is greater than the total potential
// channel capacity.
ChangeOutputs []*wire.TxOut
// FirstCommitmentPoint is the first commitment point that will be used
// to create the revocation key in the first commitment transaction we
// send to the remote party.
FirstCommitmentPoint *btcec.PublicKey
// ChannelConfig is the concrete contribution that this node is
// offering to the channel. This includes all the various constraints
// such as the min HTLC, and also all the keys which will be used for
// the duration of the channel.
*channeldb.ChannelConfig
// UpfrontShutdown is an optional address to which the channel should be
// paid out to on cooperative close.
UpfrontShutdown lnwire.DeliveryAddress
// LocalNonce is populated if the channel type is a simple taproot
// channel. This stores the public (and secret) nonce that will be used
// to generate commitments for the local party.
LocalNonce *musig2.Nonces
}
// toChanConfig returns the raw channel configuration generated by a node's
// contribution to the channel.
func (c *ChannelContribution) toChanConfig() channeldb.ChannelConfig {
return *c.ChannelConfig
}
// ChannelReservation represents an intent to open a lightning payment channel
// with a counterparty. The funding process from reservation to channel opening
// is a 3-step process. In order to allow for full concurrency during the
// reservation workflow, resources consumed by a contribution are "locked"
// themselves. This prevents a number of race conditions such as two funding
// transactions double-spending the same input. A reservation can also be
// canceled, which removes the resources from limbo, allowing another
// reservation to claim them.
//
// The reservation workflow consists of the following three steps:
// 1. lnwallet.InitChannelReservation
// * One requests the wallet to allocate the necessary resources for a
// channel reservation. These resources are put in limbo for the lifetime
// of a reservation.
// * Once completed the reservation will have the wallet's contribution
// accessible via the .OurContribution() method. This contribution
// contains the necessary items to allow the remote party to build both
// the funding, and commitment transactions.
// 2. ChannelReservation.ProcessContribution/ChannelReservation.ProcessSingleContribution
// * The counterparty presents their contribution to the payment channel.
// This allows us to build the funding, and commitment transactions
// ourselves.
// * We're now able to sign our inputs to the funding transactions, and
// the counterparty's version of the commitment transaction.
// * All signatures crafted by us, are now available via .OurSignatures().
// 3. ChannelReservation.CompleteReservation/ChannelReservation.CompleteReservationSingle
// * The final step in the workflow. The counterparty presents the
// signatures for all their inputs to the funding transaction, as well
// as a signature to our version of the commitment transaction.
// * We then verify the validity of all signatures before considering the
// channel "open". (An illustrative sketch of this flow follows the
// struct definition below.)
type ChannelReservation struct {
// This mutex MUST be held when either reading or modifying any of the
// fields below.
sync.RWMutex
// fundingTx is the funding transaction for this pending channel.
fundingTx *wire.MsgTx
// In order of sorted inputs. Sorting is done in accordance
// with BIP-69: https://github.com/bitcoin/bips/blob/master/bip-0069.mediawiki.
ourFundingInputScripts []*input.Script
theirFundingInputScripts []*input.Script
// Our signature for their version of the commitment transaction.
ourCommitmentSig input.Signature
theirCommitmentSig input.Signature
ourContribution *ChannelContribution
theirContribution *ChannelContribution
partialState *channeldb.OpenChannel
nodeAddr net.Addr
// The ID of this reservation, used to uniquely track the reservation
// throughout its lifetime.
reservationID uint64
// pendingChanID is the pending channel ID for this channel as
// identified within the wire protocol.
pendingChanID [32]byte
// pushMSat is the amount of milli-satoshis that should be pushed to the
// responder of a single funding channel as part of the initial
// commitment state.
pushMSat lnwire.MilliSatoshi
wallet *LightningWallet
chanFunder chanfunding.Assembler
fundingIntent chanfunding.Intent
// nextRevocationKeyLoc stores the key locator information for this
// channel.
nextRevocationKeyLoc keychain.KeyLocator
musigSessions *MusigPairSession
state ReservationState
}
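// An illustrative sketch of the three-step workflow described above, from the
// funder's perspective (error handling elided; variable names are
// illustrative):
//
//	res, err := wallet.InitChannelReservation(req)  // Step 1.
//	ourContrib := res.OurContribution()
//	// ...exchange contributions with the counterparty...
//	err = res.ProcessContribution(theirContrib)     // Step 2.
//	scripts, commitSig := res.OurSignatures()
//	// ...exchange signatures with the counterparty...
//	// Step 3: verify their signatures and consider the channel open.
//	_, err = res.CompleteReservation(theirScripts, theirCommitSig)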
// NewChannelReservation creates a new channel reservation. This function is
// used only internally by lnwallet. In order to ensure concurrency safety, the
// creation of all channel reservations should be carried out via the
// lnwallet.InitChannelReservation interface.
func NewChannelReservation(capacity, localFundingAmt btcutil.Amount,
wallet *LightningWallet, id uint64, chainHash *chainhash.Hash,
thawHeight uint32, req *InitFundingReserveMsg) (*ChannelReservation,
error) {
var (
ourBalance lnwire.MilliSatoshi
theirBalance lnwire.MilliSatoshi
initiator bool
)
// Based on the channel type, we determine the initial commit weight
// and fee.
commitWeight := input.CommitWeight
if req.CommitType.IsTaproot() {
commitWeight = input.TaprootCommitWeight
} else if req.CommitType.HasAnchors() {
commitWeight = input.AnchorCommitWeight
}
commitFee := req.CommitFeePerKw.FeeForWeight(
lntypes.WeightUnit(commitWeight),
)
localFundingMSat := lnwire.NewMSatFromSatoshis(localFundingAmt)
// TODO(halseth): make method take remote funding amount directly
// instead of inferring it from capacity and local amt.
capacityMSat := lnwire.NewMSatFromSatoshis(capacity)
// The total fee paid by the initiator will be the commitment fee in
// addition to the two anchor outputs.
feeMSat := lnwire.NewMSatFromSatoshis(commitFee)
if req.CommitType.HasAnchors() {
feeMSat += 2 * lnwire.NewMSatFromSatoshis(AnchorSize)
}
// Used to cut down on verbosity.
defaultDust := DustLimitUnknownWitness()
// If we're the responder to a single-funder reservation, then we have
// no initial balance in the channel unless the remote party is pushing
// some funds to us within the first commitment state.
if localFundingAmt == 0 {
ourBalance = req.PushMSat
theirBalance = capacityMSat - feeMSat - req.PushMSat
initiator = false
// If the responder doesn't have enough funds to actually pay
// the fees, then we'll bail out early.
if int64(theirBalance) < 0 {
return nil, ErrFunderBalanceDust(
int64(commitFee), int64(theirBalance.ToSatoshis()),
int64(2*defaultDust),
)
}
} else {
// TODO(roasbeef): need to rework fee structure in general and
// also when we "unlock" dual funder within the daemon
if capacity == localFundingAmt {
// If we're initiating a single funder workflow, then
// we pay all the initial fees within the commitment
// transaction. We also deduct our balance by the
// amount pushed as part of the initial state.
ourBalance = capacityMSat - feeMSat - req.PushMSat
theirBalance = req.PushMSat
} else {
// Otherwise, this is a dual funder workflow where both
// sides split the amount funded and the commitment
// fee.
ourBalance = localFundingMSat - (feeMSat / 2)
theirBalance = capacityMSat - localFundingMSat - (feeMSat / 2) + req.PushMSat
}
initiator = true
// If we, the initiator, don't have enough funds to actually pay
// the fees, then we'll exit with an error.
if int64(ourBalance) < 0 {
return nil, ErrFunderBalanceDust(
int64(commitFee), int64(ourBalance.ToSatoshis()),
int64(2*defaultDust),
)
}
}
// If we're the initiator and our starting balance within the channel
// after we take account of fees is below 2x the dust limit, then we'll
// reject this channel creation request.
//
// TODO(roasbeef): reject if 30% goes to fees? dust channel
if initiator && ourBalance.ToSatoshis() <= 2*defaultDust {
return nil, ErrFunderBalanceDust(
int64(commitFee),
int64(ourBalance.ToSatoshis()),
int64(2*defaultDust),
)
}
// Similarly we ensure their balance is reasonable if we are not the
// initiator.
if !initiator && theirBalance.ToSatoshis() <= 2*defaultDust {
return nil, ErrFunderBalanceDust(
int64(commitFee),
int64(theirBalance.ToSatoshis()),
int64(2*defaultDust),
)
}
// Next we'll set the channel type based on what we can ascertain about
// the balances/push amount within the channel.
var chanType channeldb.ChannelType
// If either of the balances are zero at this point, or we have a
// non-zero push amt (there's no pushing for dual funder), then this is
// a single-funder channel.
if ourBalance == 0 || theirBalance == 0 || req.PushMSat != 0 {
// Both the tweakless type and the anchor type are tweakless,
// hence set the bit.
if req.CommitType.HasStaticRemoteKey() {
chanType |= channeldb.SingleFunderTweaklessBit
} else {
chanType |= channeldb.SingleFunderBit
}
switch a := req.ChanFunder.(type) {
// The first channels of a batch shouldn't publish the batch TX
// to avoid problems if some of the funding flows can't be
// completed. Only the last channel of a batch should publish.
case chanfunding.ConditionalPublishAssembler:
if !a.ShouldPublishFundingTx() {
chanType |= channeldb.NoFundingTxBit
}
// Normal funding flow, the assembler creates a TX from the
// internal wallet.
case chanfunding.FundingTxAssembler:
// Do nothing, a FundingTxAssembler has the transaction.
// If this intent isn't one that's able to provide us with a
// funding transaction, then we'll set the chanType bit to
// signal that we don't have access to one.
default:
chanType |= channeldb.NoFundingTxBit
}
} else {
// Otherwise, this is a dual funder channel, and no side is
// technically the "initiator"
initiator = false
chanType |= channeldb.DualFunderBit
}
// We are adding anchor outputs to our commitment. We only support this
// in combination with zero-fee second-level HTLCs.
if req.CommitType.HasAnchors() {
chanType |= channeldb.AnchorOutputsBit
chanType |= channeldb.ZeroHtlcTxFeeBit
}
// Set the appropriate LeaseExpiration/Frozen bit based on the
// reservation parameters.
if req.CommitType == CommitmentTypeScriptEnforcedLease {
if thawHeight == 0 {
return nil, errors.New("missing absolute expiration " +
"for script enforced lease commitment type")
}
chanType |= channeldb.LeaseExpirationBit
} else if thawHeight > 0 {
chanType |= channeldb.FrozenBit
}
if req.CommitType.IsTaproot() {
chanType |= channeldb.SimpleTaprootFeatureBit
}
if req.ZeroConf {
chanType |= channeldb.ZeroConfBit
}
if req.OptionScidAlias {
chanType |= channeldb.ScidAliasChanBit
}
if req.ScidAliasFeature {
chanType |= channeldb.ScidAliasFeatureBit
}
taprootOverlay := req.CommitType == CommitmentTypeSimpleTaprootOverlay
switch {
case taprootOverlay && req.TapscriptRoot.IsNone():
fallthrough
case !taprootOverlay && req.TapscriptRoot.IsSome():
return nil, fmt.Errorf("taproot overlay chans must be set " +
"with tapscript root")
case taprootOverlay && req.TapscriptRoot.IsSome():
chanType |= channeldb.TapscriptRootBit
}
return &ChannelReservation{
ourContribution: &ChannelContribution{
FundingAmount: ourBalance.ToSatoshis(),
ChannelConfig: &channeldb.ChannelConfig{},
},
theirContribution: &ChannelContribution{
FundingAmount: theirBalance.ToSatoshis(),
ChannelConfig: &channeldb.ChannelConfig{},
},
partialState: &channeldb.OpenChannel{
ChanType: chanType,
ChainHash: *chainHash,
IsPending: true,
IsInitiator: initiator,
ChannelFlags: req.Flags,
Capacity: capacity,
LocalCommitment: channeldb.ChannelCommitment{
LocalBalance: ourBalance,
RemoteBalance: theirBalance,
FeePerKw: btcutil.Amount(req.CommitFeePerKw),
CommitFee: commitFee,
},
RemoteCommitment: channeldb.ChannelCommitment{
LocalBalance: ourBalance,
RemoteBalance: theirBalance,
FeePerKw: btcutil.Amount(req.CommitFeePerKw),
CommitFee: commitFee,
},
ThawHeight: thawHeight,
Db: wallet.Cfg.Database,
InitialLocalBalance: ourBalance,
InitialRemoteBalance: theirBalance,
Memo: req.Memo,
TapscriptRoot: req.TapscriptRoot,
},
pushMSat: req.PushMSat,
pendingChanID: req.PendingChanID,
reservationID: id,
wallet: wallet,
chanFunder: req.ChanFunder,
state: WaitingToSend,
}, nil
}
// AddAlias stores the first alias for zero-conf channels.
func (r *ChannelReservation) AddAlias(scid lnwire.ShortChannelID) {
r.Lock()
defer r.Unlock()
r.partialState.ShortChannelID = scid
}
// SetState sets the ReservationState.
func (r *ChannelReservation) SetState(state ReservationState) {
r.Lock()
defer r.Unlock()
r.state = state
}
// State returns the current ReservationState.
func (r *ChannelReservation) State() ReservationState {
r.RLock()
defer r.RUnlock()
return r.state
}
// SetNumConfsRequired sets the number of confirmations that are required for
// the ultimate funding transaction before the channel can be considered open.
// This is distinct from the main reservation workflow as it allows
// implementations a bit more flexibility w.r.t. whether the responder or the
// initiator decides the number of confirmations needed.
func (r *ChannelReservation) SetNumConfsRequired(numConfs uint16) {
r.Lock()
defer r.Unlock()
r.partialState.NumConfsRequired = numConfs
}
// IsZeroConf returns if the reservation's underlying partial channel state is
// a zero-conf channel.
func (r *ChannelReservation) IsZeroConf() bool {
r.RLock()
defer r.RUnlock()
return r.partialState.IsZeroConf()
}
// IsTaproot returns if the reservation's underlying partial channel state is a
// taproot channel.
func (r *ChannelReservation) IsTaproot() bool {
r.RLock()
defer r.RUnlock()
return r.partialState.ChanType.IsTaproot()
}
// CommitConstraints takes the constraints that the remote party specifies for
// the type of commitments that we can generate for them. These constraints
// include several parameters that serve as flow control restricting the amount
// of satoshis that can be transferred in a single commitment. This function
// will also attempt to verify the constraints for sanity, returning an error
// if the parameters seem unsound.
func (r *ChannelReservation) CommitConstraints(
bounds *channeldb.ChannelStateBounds,
commitParams *channeldb.CommitmentParams,
maxLocalCSVDelay uint16,
responder bool) error {
r.Lock()
defer r.Unlock()
// First, verify the sanity of the channel constraints.
err := VerifyConstraints(
bounds, commitParams, maxLocalCSVDelay, r.partialState.Capacity,
)
if err != nil {
return err
}
// Our dust limit should always be less than or equal to our proposed
// channel reserve.
if responder && r.ourContribution.DustLimit > bounds.ChanReserve {
r.ourContribution.DustLimit = bounds.ChanReserve
}
r.ourContribution.ChanReserve = bounds.ChanReserve
r.ourContribution.MaxPendingAmount = bounds.MaxPendingAmount
r.ourContribution.MinHTLC = bounds.MinHTLC
r.ourContribution.MaxAcceptedHtlcs = bounds.MaxAcceptedHtlcs
r.ourContribution.CsvDelay = commitParams.CsvDelay
return nil
}
// validateReserveBounds checks that both ChannelReserve values are above both
// DustLimit values. This not only avoids stuck channels, but is also mandated
// by BOLT#02 even if it's not explicit. This returns true if the bounds are
// valid. This function should be called with the lock held.
func (r *ChannelReservation) validateReserveBounds() bool {
ourDustLimit := r.ourContribution.DustLimit
ourRequiredReserve := r.ourContribution.ChanReserve
theirDustLimit := r.theirContribution.DustLimit
theirRequiredReserve := r.theirContribution.ChanReserve
// We take the smaller of the two ChannelReserves and compare it
// against the larger of the two DustLimits.
minChanReserve := ourRequiredReserve
if minChanReserve > theirRequiredReserve {
minChanReserve = theirRequiredReserve
}
maxDustLimit := ourDustLimit
if maxDustLimit < theirDustLimit {
maxDustLimit = theirDustLimit
}
return minChanReserve >= maxDustLimit
}
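// A worked example: with our side at (reserve=1,000 sat, dust=354 sat) and
// theirs at (reserve=2,000 sat, dust=546 sat), the check compares
// min(1,000, 2,000) = 1,000 against max(354, 546) = 546, and 1,000 >= 546,
// so the bounds validate. Raising either dust limit above 1,000 sats would
// fail the check.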
// OurContribution returns the wallet's fully populated contribution to the
// pending payment channel. See 'ChannelContribution' for further details
// regarding the contents of a contribution.
//
// NOTE: This SHOULD NOT be modified.
// TODO(roasbeef): make copy?
func (r *ChannelReservation) OurContribution() *ChannelContribution {
r.RLock()
defer r.RUnlock()
return r.ourContribution
}
// ProcessContribution verifies the counterparty's contribution to the pending
// payment channel. As a result of this incoming message, lnwallet is able to
// build the funding transaction, and both commitment transactions. Once this
// message has been processed, all signatures to inputs to the funding
// transaction belonging to the wallet are available. Additionally, the wallet
// will generate a signature to the counterparty's version of the commitment
// transaction.
func (r *ChannelReservation) ProcessContribution(theirContribution *ChannelContribution) error {
errChan := make(chan error, 1)
r.wallet.msgChan <- &addContributionMsg{
pendingFundingID: r.reservationID,
contribution: theirContribution,
err: errChan,
}
return <-errChan
}
// IsPsbt returns true if there is a PSBT funding intent mapped to this
// reservation.
func (r *ChannelReservation) IsPsbt() bool {
_, ok := r.fundingIntent.(*chanfunding.PsbtIntent)
return ok
}
// IsCannedShim returns true if there is a canned shim funding intent mapped to
// this reservation.
func (r *ChannelReservation) IsCannedShim() bool {
_, ok := r.fundingIntent.(*chanfunding.ShimIntent)
return ok
}
// ProcessPsbt continues a previously paused funding flow that involves PSBT to
// construct the funding transaction. This method can be called once the PSBT
// is finalized and the signed transaction is available.
func (r *ChannelReservation) ProcessPsbt(
auxFundingDesc fn.Option[AuxFundingDesc]) error {
errChan := make(chan error, 1)
r.wallet.msgChan <- &continueContributionMsg{
auxFundingDesc: auxFundingDesc,
pendingFundingID: r.reservationID,
err: errChan,
}
return <-errChan
}
// RemoteCanceled informs the PSBT funding state machine that the remote peer
// has canceled the pending reservation, likely due to a timeout.
func (r *ChannelReservation) RemoteCanceled() {
psbtIntent, ok := r.fundingIntent.(*chanfunding.PsbtIntent)
if !ok {
return
}
psbtIntent.RemoteCanceled()
}
// ProcessSingleContribution verifies and records the initiator's contribution
// to this pending single funder channel. Internally, no further action is
// taken other than recording the initiator's contribution to the single funder
// channel.
func (r *ChannelReservation) ProcessSingleContribution(theirContribution *ChannelContribution) error {
errChan := make(chan error, 1)
r.wallet.msgChan <- &addSingleContributionMsg{
pendingFundingID: r.reservationID,
contribution: theirContribution,
err: errChan,
}
return <-errChan
}
// TheirContribution returns the counterparty's pending contribution to the
// payment channel. See 'ChannelContribution' for further details regarding the
// contents of a contribution. This attribute will ONLY be available after a
// call to .ProcessContribution().
//
// NOTE: This SHOULD NOT be modified.
func (r *ChannelReservation) TheirContribution() *ChannelContribution {
r.RLock()
defer r.RUnlock()
return r.theirContribution
}
// OurSignatures retrieves the wallet's signatures to all inputs to the funding
// transaction belonging to itself, and also a signature for the counterparty's
// version of the commitment transaction. The signatures for the wallet's
// inputs to the funding transaction are returned in sorted order according to
// BIP-69: https://github.com/bitcoin/bips/blob/master/bip-0069.mediawiki.
//
// NOTE: These signatures will only be populated after a call to
// .ProcessContribution()
func (r *ChannelReservation) OurSignatures() ([]*input.Script,
input.Signature) {
r.RLock()
defer r.RUnlock()
return r.ourFundingInputScripts, r.ourCommitmentSig
}
// CompleteReservation finalizes the pending channel reservation, transitioning
// from a pending payment channel, to an open payment channel. All passed
// signatures to the counterparty's inputs to the funding transaction will be
// fully verified. Signatures are expected to be passed in sorted order
// according to BIP-69:
// https://github.com/bitcoin/bips/blob/master/bip-0069.mediawiki.
// Additionally, verification is performed in order to ensure that the
// counterparty supplied a valid signature to our version of the commitment
// transaction. Once this method returns, callers should broadcast the
// created funding transaction, then call .WaitForChannelOpen() which will
// block until the funding transaction obtains the configured number of
// confirmations. Once the method unblocks, a LightningChannel instance is
// returned, marking the channel available for updates.
func (r *ChannelReservation) CompleteReservation(fundingInputScripts []*input.Script,
commitmentSig input.Signature) (*channeldb.OpenChannel, error) {
// TODO(roasbeef): add flag for watch or not?
errChan := make(chan error, 1)
completeChan := make(chan *channeldb.OpenChannel, 1)
r.wallet.msgChan <- &addCounterPartySigsMsg{
pendingFundingID: r.reservationID,
theirFundingInputScripts: fundingInputScripts,
theirCommitmentSig: commitmentSig,
completeChan: completeChan,
err: errChan,
}
return <-completeChan, <-errChan
}
// CompleteReservationSingle finalizes the pending single funder channel
// reservation. Using the funding outpoint of the constructed funding
// transaction, and the initiator's signature for our version of the commitment
// transaction, we are able to verify the correctness of our commitment
// transaction as crafted by the initiator. Once this method returns, our
// signature for the initiator's version of the commitment transaction is
// available via the .OurSignatures() method. As this method should only be
// called as a response to a single funder channel, only a commitment signature
// will be populated.
func (r *ChannelReservation) CompleteReservationSingle(
fundingPoint *wire.OutPoint, commitSig input.Signature,
auxFundingDesc fn.Option[AuxFundingDesc]) (*channeldb.OpenChannel,
error) {
errChan := make(chan error, 1)
completeChan := make(chan *channeldb.OpenChannel, 1)
r.wallet.msgChan <- &addSingleFunderSigsMsg{
pendingFundingID: r.reservationID,
fundingOutpoint: fundingPoint,
theirCommitmentSig: commitSig,
completeChan: completeChan,
auxFundingDesc: auxFundingDesc,
err: errChan,
}
return <-completeChan, <-errChan
}
// TheirSignatures returns the counterparty's signatures to all inputs to the
// funding transaction belonging to them, as well as their signature for the
// wallet's version of the commitment transaction. This method is provided for
// additional verification, such as needed by tests.
//
// NOTE: These attributes will be unpopulated before a call to
// .CompleteReservation().
func (r *ChannelReservation) TheirSignatures() ([]*input.Script,
input.Signature) {
r.RLock()
defer r.RUnlock()
return r.theirFundingInputScripts, r.theirCommitmentSig
}
// FinalFundingTx returns the finalized, fully signed funding transaction for
// this reservation.
//
// NOTE: If this reservation was created as the non-initiator to a single
// funding workflow, then the full funding transaction will not be available.
// Instead we will only have the final outpoint of the funding transaction.
func (r *ChannelReservation) FinalFundingTx() *wire.MsgTx {
r.RLock()
defer r.RUnlock()
return r.fundingTx
}
// FundingOutpoint returns the outpoint of the funding transaction.
//
// NOTE: The pointer returned will only be set once the .ProcessContribution()
// method is called in the case of the initiator of a single funder workflow,
// and after the .CompleteReservationSingle() method is called in the case of
// a responder to a single funder workflow.
func (r *ChannelReservation) FundingOutpoint() *wire.OutPoint {
r.RLock()
defer r.RUnlock()
return &r.partialState.FundingOutpoint
}
// SetOurUpfrontShutdown sets the upfront shutdown address on our contribution.
func (r *ChannelReservation) SetOurUpfrontShutdown(shutdown lnwire.DeliveryAddress) {
r.Lock()
defer r.Unlock()
r.ourContribution.UpfrontShutdown = shutdown
}
// Capacity returns the channel capacity for this reservation.
func (r *ChannelReservation) Capacity() btcutil.Amount {
r.RLock()
defer r.RUnlock()
return r.partialState.Capacity
}
// LeaseExpiry returns the absolute expiration height for a leased channel using
// the script enforced commitment type. A zero value is returned when the
// channel is not using a script enforced lease commitment type.
func (r *ChannelReservation) LeaseExpiry() uint32 {
if !r.partialState.ChanType.HasLeaseExpiration() {
return 0
}
return r.partialState.ThawHeight
}
// Cancel abandons this channel reservation. This method should be called in
// the scenario that communications with the counterparty break down. Upon
// cancellation, all resources previously reserved for this pending payment
// channel are returned to the free pool, allowing subsequent reservations to
// utilize the now freed resources.
func (r *ChannelReservation) Cancel() error {
errChan := make(chan error, 1)
r.wallet.msgChan <- &fundingReserveCancelMsg{
pendingFundingID: r.reservationID,
err: errChan,
}
return <-errChan
}
// ChanState returns the current open channel state.
func (r *ChannelReservation) ChanState() *channeldb.OpenChannel {
r.RLock()
defer r.RUnlock()
return r.partialState
}
// CommitmentKeyRings returns the local+remote key ring used for the very first
// commitment transaction for both parties.
//
//nolint:ll
func (r *ChannelReservation) CommitmentKeyRings() lntypes.Dual[CommitmentKeyRing] {
r.RLock()
defer r.RUnlock()
chanType := r.partialState.ChanType
ourChanCfg := r.ourContribution.ChannelConfig
theirChanCfg := r.theirContribution.ChannelConfig
localKeys := DeriveCommitmentKeys(
r.ourContribution.FirstCommitmentPoint, lntypes.Local, chanType,
ourChanCfg, theirChanCfg,
)
remoteKeys := DeriveCommitmentKeys(
r.theirContribution.FirstCommitmentPoint, lntypes.Remote,
chanType, ourChanCfg, theirChanCfg,
)
return lntypes.Dual[CommitmentKeyRing]{
Local: *localKeys,
Remote: *remoteKeys,
}
}
// VerifyConstraints is a helper function that can be used to check the sanity
// of various channel constraints.
func VerifyConstraints(bounds *channeldb.ChannelStateBounds,
commitParams *channeldb.CommitmentParams, maxLocalCSVDelay uint16,
channelCapacity btcutil.Amount) error {
// Fail if the csv delay for our funds exceeds our maximum.
if commitParams.CsvDelay > maxLocalCSVDelay {
return ErrCsvDelayTooLarge(
commitParams.CsvDelay, maxLocalCSVDelay,
)
}
// The channel reserve should always be greater or equal to the dust
// limit. The reservation request should otherwise be denied.
if commitParams.DustLimit > bounds.ChanReserve {
return ErrChanReserveTooSmall(
bounds.ChanReserve, commitParams.DustLimit,
)
}
// Validate against the maximum-sized witness script dust limit, and
// also ensure that the DustLimit is not too large.
maxWitnessLimit := DustLimitForSize(input.UnknownWitnessSize)
if commitParams.DustLimit < maxWitnessLimit ||
commitParams.DustLimit > 3*maxWitnessLimit {
return ErrInvalidDustLimit(commitParams.DustLimit)
}
// Fail if we consider the channel reserve to be too large. We
// currently fail if it is greater than 20% of the channel capacity.
maxChanReserve := channelCapacity / 5
if bounds.ChanReserve > maxChanReserve {
return ErrChanReserveTooLarge(
bounds.ChanReserve, maxChanReserve,
)
}
// Fail if the minimum HTLC value is too large. If this is too large,
// the channel won't be useful for sending small payments. This limit
// is currently set to maxValueInFlight, effectively letting the remote
// set this as large as it wants.
if bounds.MinHTLC > bounds.MaxPendingAmount {
return ErrMinHtlcTooLarge(
bounds.MinHTLC, bounds.MaxPendingAmount,
)
}
// Fail if maxHtlcs is above the maximum allowed number of 483. This
// number is specified in BOLT-02.
if bounds.MaxAcceptedHtlcs > uint16(input.MaxHTLCNumber/2) {
return ErrMaxHtlcNumTooLarge(
bounds.MaxAcceptedHtlcs, uint16(input.MaxHTLCNumber/2),
)
}
// Fail if we consider maxHtlcs too small. If this is too small we
// cannot offer many HTLCs to the remote.
const minNumHtlc = 5
if bounds.MaxAcceptedHtlcs < minNumHtlc {
return ErrMaxHtlcNumTooSmall(
bounds.MaxAcceptedHtlcs, minNumHtlc,
)
}
// Fail if we consider maxValueInFlight too small. We currently require
// the remote to at least allow minNumHtlc * minHtlc in flight.
if bounds.MaxPendingAmount < minNumHtlc*bounds.MinHTLC {
return ErrMaxValueInFlightTooSmall(
bounds.MaxPendingAmount, minNumHtlc*bounds.MinHTLC,
)
}
return nil
}
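// A minimal sketch of standalone use, assuming bounds and commitParams were
// parsed from the remote peer's open/accept message:
//
//	err := VerifyConstraints(
//		bounds, commitParams, maxLocalCSVDelay, capacity,
//	)
//	if err != nil {
//		// One of the sanity checks above failed; reject the channel.
//	}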
// OpenChannelDetails wraps the finalized fully confirmed channel which
// resulted from a ChannelReservation instance with details concerning exactly
// _where_ in the chain the channel was ultimately opened.
type OpenChannelDetails struct {
// Channel is the active channel created by an instance of a
// ChannelReservation and the required funding workflow.
Channel *LightningChannel
// ConfirmationHeight is the block height within the chain that
// included the channel.
ConfirmationHeight uint32
// TransactionIndex is the index within the confirming block that the
// transaction resides.
TransactionIndex uint32
}
//go:build !integration
package lnwallet
import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/shachain"
)
// nextRevocationProducer creates a new revocation producer, deriving the
// revocation root by applying ECDH to a new key from our revocation root
// family and the multisig key we use for the channel. For taproot channels a
// related shachain revocation root is also returned.
func (l *LightningWallet) nextRevocationProducer(res *ChannelReservation,
keyRing keychain.KeyRing,
) (shachain.Producer, shachain.Producer, error) {
// Derive the next key in the revocation root family.
nextRevocationKeyDesc, err := keyRing.DeriveNextKey(
keychain.KeyFamilyRevocationRoot,
)
if err != nil {
return nil, nil, err
}
// If the DeriveNextKey call returns the first key with Index 0, we need
// to re-derive the key as the keychain/btcwallet.go DerivePrivKey call
// special-cases Index 0.
if nextRevocationKeyDesc.Index == 0 {
nextRevocationKeyDesc, err = keyRing.DeriveNextKey(
keychain.KeyFamilyRevocationRoot,
)
if err != nil {
return nil, nil, err
}
}
res.nextRevocationKeyLoc = nextRevocationKeyDesc.KeyLocator
// Perform an ECDH operation between the private key described in
// nextRevocationKeyDesc and our public multisig key. The result will be
// used to seed the revocation producer.
revRoot, err := l.ECDH(
nextRevocationKeyDesc, res.ourContribution.MultiSigKey.PubKey,
)
if err != nil {
return nil, nil, err
}
// Once we have the root, we can then generate our shachain producer
// and from that generate the per-commitment point.
shaChainRoot := shachain.NewRevocationProducer(revRoot)
taprootShaChainRoot, err := channeldb.DeriveMusig2Shachain(shaChainRoot)
if err != nil {
return nil, nil, err
}
return shaChainRoot, taprootShaChainRoot, nil
}
package lnwallet
import (
"fmt"
"sync"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwire"
)
const (
// jobBuffer is a constant that represents the buffer of jobs in the two
// main queues. This allows clients to avoid blocking when submitting jobs
// into the queue.
jobBuffer = 100
// TODO(roasbeef): job buffer pool?
)
// VerifyJob is a job sent to the SigPool to verify a signature
// on a transaction. The items contained in the struct are necessary and
// sufficient to verify the full signature. The passed sigHash closure function
// should be set to a function that generates the relevant sighash.
//
// TODO(roasbeef): when we move to ecschnorr, make into batch signature
// verification using bos-coster (or pip?).
type VerifyJob struct {
// PubKey is the public key that was used to generate the purported
// valid signature. Note that with the current channel construction,
// this public key will likely have been tweaked using the current per
// commitment point for a particular commitment transactions.
PubKey *btcec.PublicKey
// Sig is the raw signature generated using the above public key. This
// is the signature to be verified.
Sig input.Signature
// SigHash is a function closure that generates the sighashes that the
// passed signature is known to have signed.
SigHash func() ([]byte, error)
// HtlcIndex is the index of the HTLC from the PoV of the remote
// party's update log.
HtlcIndex uint64
// Cancel is a channel that is closed by the caller if they wish to
// cancel all pending verification jobs part of a single batch. This
// channel is closed in the case that a single signature in a batch has
// been returned as invalid, as there is no need to verify the remainder
// of the signatures.
Cancel <-chan struct{}
// ErrResp is the channel that the result of the signature verification
// is to be sent over. In the case that the signature is valid, a nil
// error will be passed. Otherwise, a concrete error detailing the
// issue will be passed. This channel MUST be buffered.
ErrResp chan *HtlcIndexErr
}
// HtlcIndexErr is a special type of error that also includes a pointer to the
// original validation job. This error message allows us to craft more detailed
// errors at upper layers.
type HtlcIndexErr struct {
error
*VerifyJob
}
// SignJob is a job sent to the SigPool to generate a valid
// signature according to the passed SignDescriptor for the passed transaction.
// Jobs are intended to be sent in batches in order to parallelize the job of
// generating signatures for a new commitment transaction.
type SignJob struct {
// SignDesc is intended to be a fully populated SignDescriptor which
// encodes the necessary material (keys, witness script, etc) required
// to generate a valid signature for the specified input.
SignDesc input.SignDescriptor
// Tx is the transaction to be signed. This is required to generate the
// proper sighash for the input to be signed.
Tx *wire.MsgTx
// OutputIndex is the output index of the HTLC on the commitment
// transaction being signed.
OutputIndex int32
// Cancel is a channel that is closed by the caller if they wish to
// abandon all pending sign jobs part of a single batch. This should
// never be closed by the validator.
Cancel <-chan struct{}
// Resp is the channel that the response to this particular SignJob
// will be sent over. This channel MUST be buffered.
//
// TODO(roasbeef): actually need to allow caller to set, need to retain
// order mark commit sig as special
Resp chan SignJobResp
}
// SignJobResp is the response to a sign job. Both channels are to be read in
// order to ensure no unnecessary goroutine blocking occurs. Additionally, both
// channels should be buffered.
type SignJobResp struct {
// Sig is the generated signature for a particular SignJob. In the case
// of an error during signature generation, then this value sent will
// be nil.
Sig lnwire.Sig
// Err is the error that occurred when executing the specified
// signature job. In the case that no error occurred, this value will
// be nil.
Err error
}
// SigPool is a struct that is meant to allow the current channel state
// machine to parallelize all signature generation and verification. This
// struct is needed as _each_ HTLC when creating a commitment transaction
// requires a signature, and similarly a receiver of a new commitment must
// verify all the HTLC signatures included within the CommitSig message. A pool
// of workers will be maintained by the sigPool. Batches of jobs (either
// to sign or verify) can be sent to the pool of workers which will
// asynchronously perform the specified job.
type SigPool struct {
started sync.Once
stopped sync.Once
signer input.Signer
verifyJobs chan VerifyJob
signJobs chan SignJob
wg sync.WaitGroup
quit chan struct{}
numWorkers int
}
// NewSigPool creates a new signature pool with the specified number of
// workers. The recommended value for the number of workers is the number of
// physical CPU cores available on the target machine.
func NewSigPool(numWorkers int, signer input.Signer) *SigPool {
return &SigPool{
signer: signer,
numWorkers: numWorkers,
verifyJobs: make(chan VerifyJob, jobBuffer),
signJobs: make(chan SignJob, jobBuffer),
quit: make(chan struct{}),
}
}
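// A minimal usage sketch (the signer is assumed to be provided by the
// wallet; runtime.NumCPU() follows the recommendation above):
//
//	sigPool := NewSigPool(runtime.NumCPU(), signer)
//	if err := sigPool.Start(); err != nil {
//		// Handle the error.
//	}
//	defer sigPool.Stop()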
// Start starts all goroutines that the SigPool needs to
// carry out its duties.
func (s *SigPool) Start() error {
s.started.Do(func() {
walletLog.Info("SigPool starting")
for i := 0; i < s.numWorkers; i++ {
s.wg.Add(1)
go s.poolWorker()
}
})
return nil
}
// Stop signals any active workers carrying out jobs to exit so the sigPool can
// gracefully shut down.
func (s *SigPool) Stop() error {
s.stopped.Do(func() {
close(s.quit)
s.wg.Wait()
})
return nil
}
// poolWorker is the main worker goroutine within the SigPool.
// Individual batches are distributed amongst each of the active workers. The
// workers then execute the task based on the type of job, and return the
// result back to the caller.
func (s *SigPool) poolWorker() {
defer s.wg.Done()
for {
select {
// We've just received a new signature job. Given the items
// contained within the message, we'll craft a signature and
// send the result along with a possible error back to the
// caller.
case sigMsg := <-s.signJobs:
rawSig, err := s.signer.SignOutputRaw(
sigMsg.Tx, &sigMsg.SignDesc,
)
if err != nil {
select {
case sigMsg.Resp <- SignJobResp{
Sig: lnwire.Sig{},
Err: err,
}:
continue
case <-sigMsg.Cancel:
continue
case <-s.quit:
return
}
}
// Use the sig mapper to go from the input.Signature
// into the serialized lnwire.Sig that we'll send
// across the wire.
sig, err := lnwire.NewSigFromSignature(rawSig)
select {
case sigMsg.Resp <- SignJobResp{
Sig: sig,
Err: err,
}:
case <-sigMsg.Cancel:
continue
case <-s.quit:
return
}
// We've just received a new verification job from the outside
// world. We'll attempt to construct the sighash, parse the
// signature, and finally verify the signature.
case verifyMsg := <-s.verifyJobs:
sigHash, err := verifyMsg.SigHash()
if err != nil {
select {
case verifyMsg.ErrResp <- &HtlcIndexErr{
error: err,
VerifyJob: &verifyMsg,
}:
continue
case <-verifyMsg.Cancel:
continue
}
}
rawSig := verifyMsg.Sig
if !rawSig.Verify(sigHash, verifyMsg.PubKey) {
err := fmt.Errorf("invalid signature "+
"sighash: %x, sig: %x", sigHash,
rawSig.Serialize())
select {
case verifyMsg.ErrResp <- &HtlcIndexErr{
error: err,
VerifyJob: &verifyMsg,
}:
case <-verifyMsg.Cancel:
case <-s.quit:
return
}
} else {
select {
case verifyMsg.ErrResp <- nil:
case <-verifyMsg.Cancel:
case <-s.quit:
return
}
}
// The SigPool is exiting, so we will as well.
case <-s.quit:
return
}
}
}
// SubmitSignBatch submits a batch of signature jobs to the sigPool. The
// response and cancel channels for each of the SignJobs are expected to be
// fully populated, as the response for each job will be sent over the
// response channel within the job itself.
func (s *SigPool) SubmitSignBatch(signJobs []SignJob) {
for _, job := range signJobs {
select {
case s.signJobs <- job:
case <-job.Cancel:
// TODO(roasbeef): return error?
case <-s.quit:
return
}
}
}
// SubmitVerifyBatch submits a batch of verification jobs to the sigPool. For
// each job submitted, an error will be passed into the returned channel
// denoting if signature verification was valid or not. The passed cancelChan
// allows the caller to cancel all pending jobs in the case that they wish to
// bail early.
func (s *SigPool) SubmitVerifyBatch(verifyJobs []VerifyJob,
cancelChan chan struct{}) <-chan *HtlcIndexErr {
errChan := make(chan *HtlcIndexErr, len(verifyJobs))
for _, job := range verifyJobs {
job.Cancel = cancelChan
job.ErrResp = errChan
select {
case s.verifyJobs <- job:
case <-job.Cancel:
return errChan
}
}
return errChan
}
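// A hedged usage sketch: submit a batch, then drain the returned channel,
// closing cancelChan on the first failure so the remaining jobs in the
// batch are abandoned:
//
//	cancelChan := make(chan struct{})
//	errChan := sigPool.SubmitVerifyBatch(verifyJobs, cancelChan)
//	for range verifyJobs {
//		if htlcErr := <-errChan; htlcErr != nil {
//			close(cancelChan)
//			// htlcErr.HtlcIndex identifies the offending sig.
//			break
//		}
//	}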
package lnwallet
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"io"
prand "math/rand"
"net"
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
var (
// For simplicity a single priv key controls all of our test outputs.
testWalletPrivKey = []byte{
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
}
// We're alice :)
bobsPrivKey = []byte{
0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda,
0x63, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17,
0xd, 0xe7, 0x95, 0xe4, 0xb7, 0x25, 0xb8, 0x4d,
0x1e, 0xb, 0x4c, 0xfd, 0x9e, 0xc5, 0x8c, 0xe9,
}
// Use a hard-coded HD seed.
testHdSeed = chainhash.Hash{
0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab,
0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4,
0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9,
0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53,
}
// A serializable txn for testing funding txn.
testTx = &wire.MsgTx{
Version: 1,
TxIn: []*wire.TxIn{
{
PreviousOutPoint: wire.OutPoint{
Hash: chainhash.Hash{},
Index: 0xffffffff,
},
SignatureScript: []byte{0x04, 0x31, 0xdc, 0x00, 0x1b, 0x01, 0x62},
Sequence: 0xffffffff,
},
},
TxOut: []*wire.TxOut{
{
Value: 5000000000,
PkScript: []byte{
0x41, // OP_DATA_65
0x04, 0xd6, 0x4b, 0xdf, 0xd0, 0x9e, 0xb1, 0xc5,
0xfe, 0x29, 0x5a, 0xbd, 0xeb, 0x1d, 0xca, 0x42,
0x81, 0xbe, 0x98, 0x8e, 0x2d, 0xa0, 0xb6, 0xc1,
0xc6, 0xa5, 0x9d, 0xc2, 0x26, 0xc2, 0x86, 0x24,
0xe1, 0x81, 0x75, 0xe8, 0x51, 0xc9, 0x6b, 0x97,
0x3d, 0x81, 0xb0, 0x1c, 0xc3, 0x1f, 0x04, 0x78,
0x34, 0xbc, 0x06, 0xd6, 0xd6, 0xed, 0xf6, 0x20,
0xd1, 0x84, 0x24, 0x1a, 0x6a, 0xed, 0x8b, 0x63,
0xa6, // 65-byte signature
0xac, // OP_CHECKSIG
},
},
},
LockTime: 5,
}
// A valid, DER-encoded signature (taken from btcec unit tests).
testSigBytes = []byte{
0x30, 0x44, 0x02, 0x20, 0x4e, 0x45, 0xe1, 0x69,
0x32, 0xb8, 0xaf, 0x51, 0x49, 0x61, 0xa1, 0xd3,
0xa1, 0xa2, 0x5f, 0xdf, 0x3f, 0x4f, 0x77, 0x32,
0xe9, 0xd6, 0x24, 0xc6, 0xc6, 0x15, 0x48, 0xab,
0x5f, 0xb8, 0xcd, 0x41, 0x02, 0x20, 0x18, 0x15,
0x22, 0xec, 0x8e, 0xca, 0x07, 0xde, 0x48, 0x60,
0xa4, 0xac, 0xdd, 0x12, 0x90, 0x9d, 0x83, 0x1c,
0xc5, 0x6c, 0xbb, 0xac, 0x46, 0x22, 0x08, 0x22,
0x21, 0xa8, 0x76, 0x8d, 0x1d, 0x09,
}
aliceDustLimit = btcutil.Amount(200)
bobDustLimit = btcutil.Amount(1300)
testChannelCapacity float64 = 10
// ctxb is a context that will never be cancelled, that is used in
// place of a real quit context.
ctxb = context.Background()
)
// CreateTestChannels creates two fully populated channels to be used within
// testing fixtures. The channels will be returned as if the funding process
// has just completed. The channel itself is funded with 10 BTC, with 5 BTC
// allocated to each side. Within the channel, Alice is the initiator. The
// passed chanType dictates the commitment format used by both channels.
func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType,
dbModifiers ...channeldb.OptionModifier) (*LightningChannel,
*LightningChannel, error) {
channelCapacity, err := btcutil.NewAmount(testChannelCapacity)
if err != nil {
return nil, nil, err
}
channelBal := channelCapacity / 2
csvTimeoutAlice := uint32(5)
csvTimeoutBob := uint32(4)
isAliceInitiator := true
prevOut := &wire.OutPoint{
Hash: chainhash.Hash(testHdSeed),
Index: prand.Uint32(),
}
fundingTxIn := wire.NewTxIn(prevOut, nil, nil)
// For each party, we'll create a distinct set of keys in order to
// emulate the typical set up with live channels.
var (
aliceKeys []*btcec.PrivateKey
bobKeys []*btcec.PrivateKey
)
for i := 0; i < 5; i++ {
key := make([]byte, len(testWalletPrivKey))
copy(key[:], testWalletPrivKey[:])
key[0] ^= byte(i + 1)
aliceKey, _ := btcec.PrivKeyFromBytes(key)
aliceKeys = append(aliceKeys, aliceKey)
key = make([]byte, len(bobsPrivKey))
copy(key[:], bobsPrivKey)
key[0] ^= byte(i + 1)
bobKey, _ := btcec.PrivKeyFromBytes(key)
bobKeys = append(bobKeys, bobKey)
}
aliceCfg := channeldb.ChannelConfig{
ChannelStateBounds: channeldb.ChannelStateBounds{
MaxPendingAmount: lnwire.NewMSatFromSatoshis(channelCapacity),
ChanReserve: channelCapacity / 100,
MinHTLC: 0,
MaxAcceptedHtlcs: input.MaxHTLCNumber / 2,
},
CommitmentParams: channeldb.CommitmentParams{
DustLimit: aliceDustLimit,
CsvDelay: uint16(csvTimeoutAlice),
},
MultiSigKey: keychain.KeyDescriptor{
PubKey: aliceKeys[0].PubKey(),
},
RevocationBasePoint: keychain.KeyDescriptor{
PubKey: aliceKeys[1].PubKey(),
},
PaymentBasePoint: keychain.KeyDescriptor{
PubKey: aliceKeys[2].PubKey(),
},
DelayBasePoint: keychain.KeyDescriptor{
PubKey: aliceKeys[3].PubKey(),
},
HtlcBasePoint: keychain.KeyDescriptor{
PubKey: aliceKeys[4].PubKey(),
},
}
bobCfg := channeldb.ChannelConfig{
ChannelStateBounds: channeldb.ChannelStateBounds{
MaxPendingAmount: lnwire.NewMSatFromSatoshis(channelCapacity),
ChanReserve: channelCapacity / 100,
MinHTLC: 0,
MaxAcceptedHtlcs: input.MaxHTLCNumber / 2,
},
CommitmentParams: channeldb.CommitmentParams{
DustLimit: bobDustLimit,
CsvDelay: uint16(csvTimeoutBob),
},
MultiSigKey: keychain.KeyDescriptor{
PubKey: bobKeys[0].PubKey(),
},
RevocationBasePoint: keychain.KeyDescriptor{
PubKey: bobKeys[1].PubKey(),
},
PaymentBasePoint: keychain.KeyDescriptor{
PubKey: bobKeys[2].PubKey(),
},
DelayBasePoint: keychain.KeyDescriptor{
PubKey: bobKeys[3].PubKey(),
},
HtlcBasePoint: keychain.KeyDescriptor{
PubKey: bobKeys[4].PubKey(),
},
}
bobRoot, err := chainhash.NewHash(bobKeys[0].Serialize())
if err != nil {
return nil, nil, err
}
bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot)
bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
if err != nil {
return nil, nil, err
}
bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
aliceRoot, err := chainhash.NewHash(aliceKeys[0].Serialize())
if err != nil {
return nil, nil, err
}
alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot)
aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
if err != nil {
return nil, nil, err
}
aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
aliceCommitTx, bobCommitTx, err := CreateCommitmentTxns(
channelBal, channelBal, &aliceCfg, &bobCfg, aliceCommitPoint,
bobCommitPoint, *fundingTxIn, chanType, isAliceInitiator, 0,
)
if err != nil {
return nil, nil, err
}
dbAlice := channeldb.OpenForTesting(t, t.TempDir(), dbModifiers...)
dbBob := channeldb.OpenForTesting(t, t.TempDir(), dbModifiers...)
estimator := chainfee.NewStaticEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil {
return nil, nil, err
}
commitFee := calcStaticFee(chanType, 0)
var anchorAmt btcutil.Amount
if chanType.HasAnchors() {
anchorAmt += 2 * AnchorSize
}
aliceBalance := lnwire.NewMSatFromSatoshis(
channelBal - commitFee - anchorAmt,
)
bobBalance := lnwire.NewMSatFromSatoshis(channelBal)
aliceLocalCommit := channeldb.ChannelCommitment{
CommitHeight: 0,
LocalBalance: aliceBalance,
RemoteBalance: bobBalance,
CommitFee: commitFee,
FeePerKw: btcutil.Amount(feePerKw),
CommitTx: aliceCommitTx,
CommitSig: testSigBytes,
}
aliceRemoteCommit := channeldb.ChannelCommitment{
CommitHeight: 0,
LocalBalance: aliceBalance,
RemoteBalance: bobBalance,
CommitFee: commitFee,
FeePerKw: btcutil.Amount(feePerKw),
CommitTx: bobCommitTx,
CommitSig: testSigBytes,
}
bobLocalCommit := channeldb.ChannelCommitment{
CommitHeight: 0,
LocalBalance: bobBalance,
RemoteBalance: aliceBalance,
CommitFee: commitFee,
FeePerKw: btcutil.Amount(feePerKw),
CommitTx: bobCommitTx,
CommitSig: testSigBytes,
}
bobRemoteCommit := channeldb.ChannelCommitment{
CommitHeight: 0,
LocalBalance: bobBalance,
RemoteBalance: aliceBalance,
CommitFee: commitFee,
FeePerKw: btcutil.Amount(feePerKw),
CommitTx: aliceCommitTx,
CommitSig: testSigBytes,
}
var chanIDBytes [8]byte
if _, err := io.ReadFull(rand.Reader, chanIDBytes[:]); err != nil {
return nil, nil, err
}
shortChanID := lnwire.NewShortChanIDFromInt(
binary.BigEndian.Uint64(chanIDBytes[:]),
)
aliceChannelState := &channeldb.OpenChannel{
LocalChanCfg: aliceCfg,
RemoteChanCfg: bobCfg,
IdentityPub: aliceKeys[0].PubKey(),
FundingOutpoint: *prevOut,
ShortChannelID: shortChanID,
ChanType: chanType,
IsInitiator: isAliceInitiator,
Capacity: channelCapacity,
RemoteCurrentRevocation: bobCommitPoint,
RevocationProducer: alicePreimageProducer,
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: aliceLocalCommit,
RemoteCommitment: aliceRemoteCommit,
Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: testTx,
}
bobChannelState := &channeldb.OpenChannel{
LocalChanCfg: bobCfg,
RemoteChanCfg: aliceCfg,
IdentityPub: bobKeys[0].PubKey(),
FundingOutpoint: *prevOut,
ShortChannelID: shortChanID,
ChanType: chanType,
IsInitiator: !isAliceInitiator,
Capacity: channelCapacity,
RemoteCurrentRevocation: aliceCommitPoint,
RevocationProducer: bobPreimageProducer,
RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: bobLocalCommit,
RemoteCommitment: bobRemoteCommit,
Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID),
}
// If the channel type has a tapscript root, then we'll also specify
// one here to apply to both channels.
if chanType.HasTapscriptRoot() {
var tapscriptRoot chainhash.Hash
_, err := io.ReadFull(rand.Reader, tapscriptRoot[:])
if err != nil {
return nil, nil, err
}
someRoot := fn.Some(tapscriptRoot)
aliceChannelState.TapscriptRoot = someRoot
bobChannelState.TapscriptRoot = someRoot
}
aliceSigner := input.NewMockSigner(aliceKeys, nil)
bobSigner := input.NewMockSigner(bobKeys, nil)
// TODO(roasbeef): make mock version of pre-image store
auxSigner := NewDefaultAuxSignerMock(t)
alicePool := NewSigPool(1, aliceSigner)
channelAlice, err := NewLightningChannel(
aliceSigner, aliceChannelState, alicePool,
WithLeafStore(&MockAuxLeafStore{}),
WithAuxSigner(auxSigner),
)
if err != nil {
return nil, nil, err
}
alicePool.Start()
t.Cleanup(func() {
require.NoError(t, alicePool.Stop())
})
obfuscator := createStateHintObfuscator(aliceChannelState)
bobPool := NewSigPool(1, bobSigner)
channelBob, err := NewLightningChannel(
bobSigner, bobChannelState, bobPool,
WithLeafStore(&MockAuxLeafStore{}),
WithAuxSigner(auxSigner),
)
if err != nil {
return nil, nil, err
}
bobPool.Start()
t.Cleanup(func() {
require.NoError(t, bobPool.Stop())
})
err = SetStateNumHint(
aliceCommitTx, 0, obfuscator,
)
if err != nil {
return nil, nil, err
}
err = SetStateNumHint(
bobCommitTx, 0, obfuscator,
)
if err != nil {
return nil, nil, err
}
addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18556,
}
if err := channelAlice.channelState.SyncPending(addr, 101); err != nil {
return nil, nil, err
}
addr = &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}
if err := channelBob.channelState.SyncPending(addr, 101); err != nil {
return nil, nil, err
}
// Now that the channels are open, simulate the start of a session by
// having Alice and Bob extend their revocation windows to each other.
err = initRevocationWindows(channelAlice, channelBob)
if err != nil {
return nil, nil, err
}
return channelAlice, channelBob, nil
}
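// A typical test sets the fixture up along these lines (the channel type is
// chosen purely for illustration):
//
//	aliceChan, bobChan, err := CreateTestChannels(
//		t, channeldb.SingleFunderTweaklessBit,
//	)
//	require.NoError(t, err)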
// initMusigNonce is used to manually set up musig2 nonces for a new
// channel, outside the normal channel reestablishment flow.
func initMusigNonce(chanA, chanB *LightningChannel) error {
chanANonces, err := chanA.GenMusigNonces()
if err != nil {
return err
}
chanBNonces, err := chanB.GenMusigNonces()
if err != nil {
return err
}
if err := chanA.InitRemoteMusigNonces(chanBNonces); err != nil {
return err
}
if err := chanB.InitRemoteMusigNonces(chanANonces); err != nil {
return err
}
return nil
}
// initRevocationWindows simulates a new channel being opened within the p2p
// network by populating the initial revocation windows of the passed
// commitment state machines.
func initRevocationWindows(chanA, chanB *LightningChannel) error {
// If these are taproot channels, then we need to also simulate sending
// either FundingLocked or ChannelReestablish by calling
// InitRemoteMusigNonces for both sides.
if chanA.channelState.ChanType.IsTaproot() {
if err := initMusigNonce(chanA, chanB); err != nil {
return err
}
}
aliceNextRevoke, err := chanA.NextRevocationKey()
if err != nil {
return err
}
if err := chanB.InitNextRevocation(aliceNextRevoke); err != nil {
return err
}
bobNextRevoke, err := chanB.NextRevocationKey()
if err != nil {
return err
}
if err := chanA.InitNextRevocation(bobNextRevoke); err != nil {
return err
}
return nil
}
// pubkeyFromHex parses a Bitcoin public key from a hex encoded string.
func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) {
bytes, err := hex.DecodeString(keyHex)
if err != nil {
return nil, err
}
return btcec.ParsePubKey(bytes)
}
// privkeyFromHex parses a Bitcoin private key from a hex encoded string.
func privkeyFromHex(keyHex string) (*btcec.PrivateKey, error) {
bytes, err := hex.DecodeString(keyHex)
if err != nil {
return nil, err
}
key, _ := btcec.PrivKeyFromBytes(bytes)
return key, nil
}
// blockFromHex parses a full Bitcoin block from a hex encoded string.
func blockFromHex(blockHex string) (*btcutil.Block, error) {
bytes, err := hex.DecodeString(blockHex)
if err != nil {
return nil, err
}
return btcutil.NewBlockFromBytes(bytes)
}
// txFromHex parses a full Bitcoin transaction from a hex encoded string.
func txFromHex(txHex string) (*btcutil.Tx, error) {
bytes, err := hex.DecodeString(txHex)
if err != nil {
return nil, err
}
return btcutil.NewTxFromBytes(bytes)
}
// calcStaticFee calculates appropriate fees for commitment transactions. This
// function provides a simple way to allow test balance assertions to take fee
// calculations into account.
//
// TODO(bvu): Refactor when dynamic fee estimation is added.
func calcStaticFee(chanType channeldb.ChannelType, numHTLCs int) btcutil.Amount {
const (
htlcWeight = 172
feePerKw = btcutil.Amount(24/4) * 1000
)
htlcsWeight := htlcWeight * int64(numHTLCs)
totalWeight := CommitWeight(chanType) + lntypes.WeightUnit(htlcsWeight)
return feePerKw * (btcutil.Amount(totalWeight)) / 1000
}
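// As a worked example, assuming a base commitment weight of 724 WU for a
// non-anchor channel type: with feePerKw = 6,000 sat/kw and zero HTLCs the
// static fee comes out to
//
//	6000 * 724 / 1000 = 4344 sats,
//
// and each additional HTLC adds 172 WU, i.e. about 1,032 sats at this rate.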
// ForceStateTransition executes the necessary interaction between the two
// commitment state machines to transition to a new state locking in any
// pending updates. This method is useful when testing interactions between two
// live state machines.
func ForceStateTransition(chanA, chanB *LightningChannel) error {
aliceNewCommit, err := chanA.SignNextCommitment(ctxb)
if err != nil {
return err
}
err = chanB.ReceiveNewCommitment(aliceNewCommit.CommitSigs)
if err != nil {
return err
}
bobRevocation, _, _, err := chanB.RevokeCurrentCommitment()
if err != nil {
return err
}
bobNewCommit, err := chanB.SignNextCommitment(ctxb)
if err != nil {
return err
}
_, _, err = chanA.ReceiveRevocation(bobRevocation)
if err != nil {
return err
}
err = chanA.ReceiveNewCommitment(bobNewCommit.CommitSigs)
if err != nil {
return err
}
aliceRevocation, _, _, err := chanA.RevokeCurrentCommitment()
if err != nil {
return err
}
_, _, err = chanB.ReceiveRevocation(aliceRevocation)
if err != nil {
return err
}
return nil
}
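// A short usage sketch: after Alice adds an HTLC and Bob receives it, one
// call locks the update in on both commitments (AddHTLC/ReceiveHTLC are the
// usual LightningChannel helpers):
//
//	if _, err := aliceChan.AddHTLC(htlc, nil); err != nil {
//		return err
//	}
//	if _, err := bobChan.ReceiveHTLC(htlc); err != nil {
//		return err
//	}
//	if err := ForceStateTransition(aliceChan, bobChan); err != nil {
//		return err
//	}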
func NewDefaultAuxSignerMock(t *testing.T) *MockAuxSigner {
auxSigner := NewAuxSignerMock(EmptyMockJobHandler)
type testSigBlob struct {
BlobInt tlv.RecordT[tlv.TlvType65634, uint16]
}
var sigBlobBuf bytes.Buffer
sigBlob := testSigBlob{
BlobInt: tlv.NewPrimitiveRecord[tlv.TlvType65634, uint16](5),
}
tlvStream, err := tlv.NewStream(sigBlob.BlobInt.Record())
require.NoError(t, err, "unable to create tlv stream")
require.NoError(t, tlvStream.Encode(&sigBlobBuf))
auxSigner.On(
"SubmitSecondLevelSigBatch", mock.Anything, mock.Anything,
mock.Anything,
).Return(nil)
auxSigner.On(
"PackSigs", mock.Anything,
).Return(fn.Ok(fn.Some(sigBlobBuf.Bytes())))
auxSigner.On(
"UnpackSigs", mock.Anything,
).Return(fn.Ok([]fn.Option[tlv.Blob]{
fn.Some(sigBlobBuf.Bytes()),
}))
auxSigner.On(
"VerifySecondLevelSigs", mock.Anything, mock.Anything,
mock.Anything,
).Return(nil)
return auxSigner
}
package lnwallet
import (
"encoding/binary"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input"
)
const (
// StateHintSize is the total number of bytes used between the sequence
// number and locktime of the commitment transaction used to encode a hint
// to the state number of a particular commitment transaction.
StateHintSize = 6
// maxStateHint is the maximum state number we're able to encode using
// StateHintSize bytes amongst the sequence number and locktime fields
// of the commitment transaction.
maxStateHint uint64 = (1 << 48) - 1
)
var (
// TimelockShift is used to make sure the commitment transaction is
// spendable by setting the locktime with it so that it is larger than
// 500,000,000, thus interpreting it as Unix epoch timestamp and not
// a block height. It is also smaller than the current timestamp which
// has bit (1 << 30) set, so there is no risk of having the commitment
// transaction be rejected. This way we can safely use the lower 24 bits
// of the locktime field for part of the obscured commitment transaction
// number.
TimelockShift = uint32(1 << 29)
)
// CreateHtlcSuccessTx creates a transaction that spends the output on the
// commitment transaction of the peer that receives an HTLC. This transaction
// essentially acts as an off-chain covenant as it's only permitted to spend
// the designated HTLC output, and also that spend can _only_ be used as a
// state transition to create another output which actually allows redemption
// or revocation of an HTLC.
//
// In order to spend the segwit v0 HTLC output, the witness for the passed
// transaction should be:
// - <0> <sender sig> <recvr sig> <preimage>
//
// In order to spend the segwit v1 (taproot) HTLC output, the witness for the
// passed transaction should be:
// - <sender sig> <receiver sig> <preimage> <success_script> <control_block>
func CreateHtlcSuccessTx(chanType channeldb.ChannelType, initiator bool,
htlcOutput wire.OutPoint, htlcAmt btcutil.Amount, csvDelay,
leaseExpiry uint32, revocationKey, delayKey *btcec.PublicKey,
auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) {
// Create a version two transaction (as the success version of this
// spends an output with a CSV timeout).
successTx := wire.NewMsgTx(2)
// The input to the transaction is the outpoint that creates the
// original HTLC on the sender's commitment transaction. Set the
// sequence number based on the channel type.
txin := &wire.TxIn{
PreviousOutPoint: htlcOutput,
Sequence: HtlcSecondLevelInputSequence(chanType),
}
successTx.AddTxIn(txin)
// Next, we'll generate the script used as the output for all second
// level HTLC which forces a covenant w.r.t what can be done with all
// HTLC outputs.
scriptInfo, err := SecondLevelHtlcScript(
chanType, initiator, revocationKey, delayKey, csvDelay,
leaseExpiry, auxLeaf,
)
if err != nil {
return nil, err
}
// Finally, the output is simply the amount of the HTLC (minus the
// required fees), paying to the timeout script.
successTx.AddTxOut(&wire.TxOut{
Value: int64(htlcAmt),
PkScript: scriptInfo.PkScript(),
})
return successTx, nil
}
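// A hedged sketch of building the second-level success transaction for an
// incoming HTLC (keys and amounts are placeholders; input.NoneTapLeaf() is
// assumed for a channel without aux leaves):
//
//	successTx, err := CreateHtlcSuccessTx(
//		chanType, isInitiator, htlcOutPoint, htlcAmt,
//		csvDelay, leaseExpiry, revocationKey, delayKey,
//		input.NoneTapLeaf(),
//	)
//	if err != nil {
//		// Handle the error.
//	}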
// CreateHtlcTimeoutTx creates a transaction that spends the HTLC output on the
// commitment transaction of the peer that created an HTLC (the sender). This
// transaction essentially acts as an off-chain covenant as it spends a 2-of-2
// multi-sig output. This output requires a signature from both the sender and
// receiver of the HTLC. By using a distinct transaction, we're able to
// uncouple the timeout and delay clauses of the HTLC contract. This
// transaction is locked with an absolute lock-time so the sender can only
// attempt to claim the output using it after the lock time has passed.
//
// In order to spend the HTLC output for segwit v0, the witness for the passed
// transaction should be:
// - <0> <sender sig> <receiver sig> <0>
//
// In order to spend the HTLC output for segwit v1, the witness for the passed
// transaction should be:
// - <sender sig> <receiver sig> <timeout_script> <control_block>
//
// NOTE: The passed amount for the HTLC should take into account the required
// fee rate at the time the HTLC was created. The fee should be able to
// entirely pay for this (tiny: 1-in 1-out) transaction.
func CreateHtlcTimeoutTx(chanType channeldb.ChannelType, initiator bool,
htlcOutput wire.OutPoint, htlcAmt btcutil.Amount,
cltvExpiry, csvDelay, leaseExpiry uint32,
revocationKey, delayKey *btcec.PublicKey,
auxLeaf input.AuxTapLeaf) (*wire.MsgTx, error) {
// Create a version two transaction (as the success version of this
// spends an output with a CSV timeout), and set the lock-time to the
// specified absolute lock-time in blocks.
timeoutTx := wire.NewMsgTx(2)
timeoutTx.LockTime = cltvExpiry
// The input to the transaction is the outpoint that creates the
// original HTLC on the sender's commitment transaction. Set the
// sequence number based on the channel type.
txin := &wire.TxIn{
PreviousOutPoint: htlcOutput,
SignatureScript: []byte{},
Witness: [][]byte{},
Sequence: HtlcSecondLevelInputSequence(chanType),
}
timeoutTx.AddTxIn(txin)
// Next, we'll generate the script used as the output for all second
// level HTLC which forces a covenant w.r.t what can be done with all
// HTLC outputs.
scriptInfo, err := SecondLevelHtlcScript(
chanType, initiator, revocationKey, delayKey, csvDelay,
leaseExpiry, auxLeaf,
)
if err != nil {
return nil, err
}
// Finally, the output is simply the amount of the HTLC (minus the
// required fees), paying to the regular second level HTLC script.
timeoutTx.AddTxOut(&wire.TxOut{
Value: int64(htlcAmt),
PkScript: scriptInfo.PkScript(),
})
return timeoutTx, nil
}
// SetStateNumHint encodes the current state number within the passed
// commitment transaction by re-purposing the locktime and sequence fields in
// the commitment transaction to encode the obfuscated state number. The state
// number is encoded using 48 bits. The lower 24 bits of the lock time are the
// lower 24 bits of the obfuscated state number and the lower 24 bits of the
// sequence field are the higher 24 bits. Finally before encoding, the
// obfuscator is XOR'd against the state number in order to hide the exact
// state number from the PoV of outside parties.
func SetStateNumHint(commitTx *wire.MsgTx, stateNum uint64,
obfuscator [StateHintSize]byte) error {
// With the current schema we are only able to encode state num
// hints up to 2^48. Therefore if the passed height is greater than our
// state hint ceiling, then exit early.
if stateNum > maxStateHint {
return fmt.Errorf("unable to encode state, %v is greater "+
"state num that max of %v", stateNum, maxStateHint)
}
if len(commitTx.TxIn) != 1 {
return fmt.Errorf("commitment tx must have exactly 1 input, "+
"instead has %v", len(commitTx.TxIn))
}
// Convert the obfuscator into a uint64, then XOR that against the
// targeted height in order to obfuscate the state number of the
// commitment transaction in the case that either commitment
// transaction is broadcast directly on chain.
var obfs [8]byte
copy(obfs[2:], obfuscator[:])
xorInt := binary.BigEndian.Uint64(obfs[:])
stateNum = stateNum ^ xorInt
// Set the height bit of the sequence number in order to disable any
// sequence locks semantics.
commitTx.TxIn[0].Sequence = uint32(stateNum>>24) | wire.SequenceLockTimeDisabled
commitTx.LockTime = uint32(stateNum&0xFFFFFF) | TimelockShift
return nil
}
// GetStateNumHint recovers the current state number given a commitment
// transaction which has previously had the state number encoded within it via
// SetStateNumHint and a shared obfuscator.
//
// See SetStateNumHint for further details w.r.t exactly how the state-hints
// are encoded.
func GetStateNumHint(commitTx *wire.MsgTx, obfuscator [StateHintSize]byte) uint64 {
// Convert the obfuscator into a uint64, this will be used to
// de-obfuscate the final recovered state number.
var obfs [8]byte
copy(obfs[2:], obfuscator[:])
xorInt := binary.BigEndian.Uint64(obfs[:])
// Retrieve the state hint from the sequence number and locktime
// of the transaction.
stateNumXor := uint64(commitTx.TxIn[0].Sequence&0xFFFFFF) << 24
stateNumXor |= uint64(commitTx.LockTime & 0xFFFFFF)
// Finally, to obtain the final state number, we XOR by the obfuscator
// value to de-obfuscate the state number.
return stateNumXor ^ xorInt
}
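// Round-trip sketch: encoding and then decoding recovers the original state
// number, provided both sides share the same obfuscator (sharedSecret is a
// placeholder here):
//
//	var obfuscator [StateHintSize]byte
//	copy(obfuscator[:], sharedSecret[:StateHintSize])
//	if err := SetStateNumHint(commitTx, 42, obfuscator); err != nil {
//		// Handle the error.
//	}
//	stateNum := GetStateNumHint(commitTx, obfuscator) // == 42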
package lnwallet
import (
"github.com/lightningnetwork/lnd/fn/v2"
)
// updateLog is an append-only log that stores updates to a node's commitment
// chain. This structure can be seen as the "mempool" within Lightning where
// changes are stored before they're committed to the chain. Once an entry has
// been committed in both the local and remote commitment chain, then it can be
// removed from this log.
//
// TODO(roasbeef): create lightning package, move commitment and update to
// package?
// - also move state machine, separate from lnwallet package
// - possible embed updateLog within commitmentChain.
type updateLog struct {
// logIndex is a monotonically increasing integer that tracks the total
// number of update entries ever applied to the log. When sending new
// commitment states, we include all updates up to this index.
logIndex uint64
// htlcCounter is a monotonically increasing integer that tracks the
// total number of offered HTLCs by the owner of this update log,
// hence the `Add` update type. We use a distinct index for this
// purpose, as updates that remove entries from the log will be
// indexed using this counter.
htlcCounter uint64
// List is the update log itself. We embed this value so updateLog has
// access to all the methods of an fn.List.
*fn.List[*paymentDescriptor]
// updateIndex maps a `logIndex` to a particular update entry. It
// deals with the four update types:
// `Fail|MalformedFail|Settle|FeeUpdate`
updateIndex map[uint64]*fn.Node[*paymentDescriptor]
// htlcIndex maps a `htlcCounter` to an offered HTLC entry, hence the
// `Add` update.
htlcIndex map[uint64]*fn.Node[*paymentDescriptor]
// modifiedHtlcs is a set that keeps track of all the current modified
// htlcs, hence update types `Fail|MalformedFail|Settle`. A modified
// HTLC is one that's present in the log and has a pending fail or
// settle attempting to consume it.
modifiedHtlcs fn.Set[uint64]
}
// newUpdateLog creates a new updateLog instance.
func newUpdateLog(logIndex, htlcCounter uint64) *updateLog {
return &updateLog{
List: fn.NewList[*paymentDescriptor](),
updateIndex: make(map[uint64]*fn.Node[*paymentDescriptor]),
htlcIndex: make(map[uint64]*fn.Node[*paymentDescriptor]),
logIndex: logIndex,
htlcCounter: htlcCounter,
modifiedHtlcs: fn.NewSet[uint64](),
}
}
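// A brief sketch of the intended life cycle of log entries (addPd and
// settlePd are placeholder *paymentDescriptor values):
//
//	log := newUpdateLog(0, 0)
//	log.appendHtlc(addPd)      // Offered HTLC, consumes htlc index 0.
//	log.appendUpdate(settlePd) // Settle referencing the HTLC above.
//	htlc := log.lookupHtlc(0)  // Fetches the offered HTLC back.
//	log.removeUpdate(settlePd.LogIndex)
//	log.removeHtlc(0)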
// restoreHtlc will "restore" a prior HTLC to the updateLog. We say restore as
// this method is intended to be used when recovering a prior commitment
// state. This function differs from appendHtlc in that it won't increment
// either of log's counters. If the HTLC is already present, then it is
// ignored.
func (u *updateLog) restoreHtlc(pd *paymentDescriptor) {
if _, ok := u.htlcIndex[pd.HtlcIndex]; ok {
return
}
u.htlcIndex[pd.HtlcIndex] = u.PushBack(pd)
}
// appendUpdate appends a new update to the tip of the updateLog. The entry is
// also added to the update index accordingly.
func (u *updateLog) appendUpdate(pd *paymentDescriptor) {
u.updateIndex[u.logIndex] = u.PushBack(pd)
u.logIndex++
}
// restoreUpdate appends a new update to the tip of the updateLog. The entry is
// also added to the update index accordingly. This function differs from
// appendUpdate in that it won't increment the log index counter.
func (u *updateLog) restoreUpdate(pd *paymentDescriptor) {
u.updateIndex[pd.LogIndex] = u.PushBack(pd)
}
// appendHtlc appends a new HTLC offer to the tip of the update log. The entry
// is also added to the offer index accordingly.
func (u *updateLog) appendHtlc(pd *paymentDescriptor) {
u.htlcIndex[u.htlcCounter] = u.PushBack(pd)
u.htlcCounter++
u.logIndex++
}
// lookupHtlc attempts to look up an offered HTLC according to its offer
// index. If the entry isn't found, then a nil pointer is returned.
func (u *updateLog) lookupHtlc(i uint64) *paymentDescriptor {
htlc, ok := u.htlcIndex[i]
if !ok {
return nil
}
return htlc.Value
}
// removeUpdate attempts to remove an entry from the update log. If the entry is
// found, then the entry will be removed from the update log and index.
func (u *updateLog) removeUpdate(i uint64) {
entry := u.updateIndex[i]
u.Remove(entry)
delete(u.updateIndex, i)
}
// removeHtlc attempts to remove an HTLC offer from the update log. If the
// entry is found, then the entry will be removed from both the main log and
// the offer index.
func (u *updateLog) removeHtlc(i uint64) {
entry := u.htlcIndex[i]
u.Remove(entry)
delete(u.htlcIndex, i)
u.modifiedHtlcs.Remove(i)
}
// htlcHasModification returns true if the HTLC identified by the passed index
// has a pending modification within the log.
func (u *updateLog) htlcHasModification(i uint64) bool {
return u.modifiedHtlcs.Contains(i)
}
// markHtlcModified marks an HTLC as modified based on its HTLC index. After a
// call to this method, htlcHasModification will return true until the HTLC is
// removed.
func (u *updateLog) markHtlcModified(i uint64) {
u.modifiedHtlcs.Add(i)
}
// compactLogs performs garbage collection within the log removing HTLCs which
// have been removed from the point-of-view of the tail of both chains. The
// entries which timeout/settle HTLCs are also removed.
func compactLogs(ourLog, theirLog *updateLog,
localChainTail, remoteChainTail uint64) {
compactLog := func(logA, logB *updateLog) {
var nextA *fn.Node[*paymentDescriptor]
for e := logA.Front(); e != nil; e = nextA {
// Assign next iteration element at top of loop because
// we may remove the current element from the list,
// which can change the iterated sequence.
nextA = e.Next()
htlc := e.Value
rmvHeights := htlc.removeCommitHeights
// We skip Adds, as they will be removed along with the
// fail/settles below.
if htlc.EntryType == Add {
continue
}
// If the HTLC hasn't yet been removed from either
// chain, then skip it.
if rmvHeights.Remote == 0 || rmvHeights.Local == 0 {
continue
}
// Otherwise if the height of the tail of both chains
// is at least the height in which the HTLC was
// removed, then evict the settle/timeout entry along
// with the original add entry.
if remoteChainTail >= rmvHeights.Remote &&
localChainTail >= rmvHeights.Local {
// Fee updates have no parent htlcs, so we only
// remove the update itself.
if htlc.EntryType == FeeUpdate {
logA.removeUpdate(htlc.LogIndex)
continue
}
// The other types (fail/settle) do have a
// parent HTLC, so we'll remove that HTLC from
// the other log.
logA.removeUpdate(htlc.LogIndex)
logB.removeHtlc(htlc.ParentIndex)
}
}
}
compactLog(ourLog, theirLog)
compactLog(theirLog, ourLog)
}
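// As a concrete scenario: Alice offered an HTLC (an Add in her log) and Bob
// later settled it, with the settle removed at commit height 5 on the local
// chain and height 6 on the remote chain. Once both chain tails have
// reached those heights,
//
//	compactLogs(aliceLog, bobLog, 5, 6)
//
// evicts the settle entry from Bob's log and the parent Add from Alice's,
// leaving only live updates behind.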
package lnwallet
import (
"bytes"
"crypto/sha256"
"errors"
"fmt"
"math"
"net"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/psbt"
"github.com/btcsuite/btcd/btcutil/txsort"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/wallet"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwallet/chanfunding"
"github.com/lightningnetwork/lnd/lnwallet/chanvalidate"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// The size of the buffered queue of requests to the wallet from the
// outside world.
msgBufferSize = 100
// AnchorChanReservedValue is the amount we'll keep around in the
// wallet in case we have to fee bump anchor channels on force close.
// TODO(halseth): update constant to target a specific commit size at
// set fee rate.
AnchorChanReservedValue = btcutil.Amount(10_000)
// MaxAnchorChanReservedValue is the maximum value we'll reserve for
// anchor channel fee bumping. We cap it at 10 times the per-channel
// amount such that nodes with a high number of channels don't have to
// keep around a very large amount for the unlikely scenario that they
// all close at the same time.
MaxAnchorChanReservedValue = 10 * AnchorChanReservedValue
)
var (
// ErrPsbtFundingRequired is the error that is returned during the
// contribution handling process if the process should be paused for
// the construction of a PSBT outside of lnd's wallet.
ErrPsbtFundingRequired = errors.New("PSBT funding required")
// ErrReservedValueInvalidated is returned if we try to publish a
// transaction that would take the walletbalance below what we require
// to keep around to fee bump our open anchor channels.
ErrReservedValueInvalidated = errors.New("reserved wallet balance " +
"invalidated: transaction would leave insufficient funds for " +
"fee bumping anchor channel closings (see debug log for details)")
// ErrEmptyPendingChanID is returned when an empty value is used for
// the pending channel ID.
ErrEmptyPendingChanID = errors.New("pending channel ID is empty")
// ErrDuplicatePendingChanID is returned when an existing pending
// channel ID is registered again.
ErrDuplicatePendingChanID = errors.New("duplicate pending channel ID")
)
// PsbtFundingRequired is a type that implements the error interface and
// contains the information needed to construct a PSBT.
type PsbtFundingRequired struct {
// Intent is the pending PSBT funding intent that needs to be funded
// if the wrapping error is returned.
Intent *chanfunding.PsbtIntent
}
// Error returns the underlying error.
//
// NOTE: This method is part of the error interface.
func (p *PsbtFundingRequired) Error() string {
return ErrPsbtFundingRequired.Error()
}
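// Example (an illustrative sketch, not a prescribed pattern): a caller that
// receives an error from the contribution phase can detect the paused PSBT
// flow and extract the pending intent roughly like so:
//
//	var psbtErr *PsbtFundingRequired
//	if errors.As(err, &psbtErr) {
//		// Hand psbtErr.Intent to the external PSBT funding logic.
//	}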
// AuxFundingDesc stores a series of attributes that may be used to modify the
// way the channel funding occurs. This struct contains information that can
// only be derived once both sides have received and sent their contributions
// to the channel (keys, etc.).
type AuxFundingDesc struct {
// CustomFundingBlob is a custom blob that'll be stored in the database
// within the OpenChannel struct. This should represent information
// static to the channel lifetime.
CustomFundingBlob tlv.Blob
// CustomLocalCommitBlob is a custom blob that'll be stored in the
// first commitment entry for the local party.
CustomLocalCommitBlob tlv.Blob
// CustomRemoteCommitBlob is a custom blob that'll be stored in the
// first commitment entry for the remote party.
CustomRemoteCommitBlob tlv.Blob
// LocalInitAuxLeaves is the set of aux leaves that'll be used for our
// very first commitment state.
LocalInitAuxLeaves CommitAuxLeaves
// RemoteInitAuxLeaves is the set of aux leaves that'll be used for the
// very first commitment state for the remote party.
RemoteInitAuxLeaves CommitAuxLeaves
}
// InitFundingReserveMsg is the first message sent to initiate the workflow
// required to open a payment channel with a remote peer. The initial required
// parameters are configurable across channels. These parameters are to be
// chosen depending on the fee climate within the network, and time value of
// funds to be locked up within the channel. Upon success a ChannelReservation
// will be created in order to track the lifetime of this pending channel.
// Outputs selected will be 'locked', making them unavailable for any other
// pending reservations. Therefore, all channels in reservation limbo will be
// periodically timed out after an idle period in order to avoid "exhaustion"
// attacks.
type InitFundingReserveMsg struct {
// ChainHash denotes the chain to be used to ultimately open the
// target channel.
ChainHash *chainhash.Hash
// PendingChanID is the pending channel ID for this funding flow as
// used in the wire protocol.
PendingChanID [32]byte
// NodeID is the ID of the remote node we would like to open a channel
// with.
NodeID *btcec.PublicKey
// NodeAddr is the address and port that we used to either establish
// or accept the connection which led to the negotiation of this
// funding workflow.
NodeAddr net.Addr
// SubtractFees should be set if we intend to spend exactly
// LocalFundingAmt when opening the channel, subtracting the fees from
// the funding output. This can be used for instance to use all our
// remaining funds to open the channel, since it will take fees into
// account.
SubtractFees bool
// LocalFundingAmt is the amount of funds requested from us for this
// channel.
LocalFundingAmt btcutil.Amount
// RemoteFundingAmt is the amount of funds the remote will contribute
// to this channel.
RemoteFundingAmt btcutil.Amount
// FundUpToMaxAmt defines if channel funding should try to add as many
// funds to the channel opening as possible up to this amount. If used,
// then MinFundAmt is treated as the minimum amount of funds that must
// be available to open the channel. If set to zero it is ignored.
FundUpToMaxAmt btcutil.Amount
// MinFundAmt denotes the minimum channel capacity that has to be
// allocated iff the FundUpToMaxAmt is set.
MinFundAmt btcutil.Amount
// Outpoints is a list of client-selected outpoints that should be used
// for funding a channel. If LocalFundingAmt is specified then this
// amount is allocated from the sum of outpoints towards funding. If the
// FundUpToMaxAmt is specified the entirety of selected funds is
// allocated towards channel funding.
Outpoints []wire.OutPoint
// RemoteChanReserve is the channel reserve we require of the remote
// peer.
RemoteChanReserve btcutil.Amount
// CommitFeePerKw is the starting accepted satoshis/Kw fee for the set
// of initial commitment transactions. In order to ensure timely
// confirmation, it is recommended that this fee should be generous,
// paying some multiple of the accepted base fee rate of the network.
CommitFeePerKw chainfee.SatPerKWeight
// FundingFeePerKw is the fee rate in sat/kw to use for the initial
// funding transaction.
FundingFeePerKw chainfee.SatPerKWeight
// PushMSat is the number of milli-satoshis that should be pushed over
// the responder as part of the initial channel creation.
PushMSat lnwire.MilliSatoshi
// Flags are the channel flags specified by the initiator in the
// open_channel message.
Flags lnwire.FundingFlag
// MinConfs indicates the minimum number of confirmations that each
// output selected to fund the channel should satisfy.
MinConfs int32
// CommitType indicates what type of commitment type the channel should
// be using, like tweakless or anchors.
CommitType CommitmentType
// ChanFunder is an optional channel funder that allows the caller to
// control exactly how the channel funding is carried out. If not
// specified, then the default chanfunding.WalletAssembler will be
// used.
ChanFunder chanfunding.Assembler
// AllowUtxoForFunding enables the channel funding workflow to restrict
// the selection of utxos when selecting the inputs for the channel
// opening. This ONLY applies to the internal wallet-backed channel
// opening case.
//
// NOTE: This is very useful when opening channels with unconfirmed
// inputs to make sure stable non-replaceable inputs are used.
AllowUtxoForFunding func(Utxo) bool
// ZeroConf is a boolean that is true if a zero-conf channel was
// negotiated.
ZeroConf bool
// OptionScidAlias is a boolean that is true if an option-scid-alias
// channel type was explicitly negotiated.
OptionScidAlias bool
// ScidAliasFeature is true if the option-scid-alias feature bit was
// negotiated.
ScidAliasFeature bool
// Memo is any arbitrary information we wish to store locally about the
// channel that will be useful to our future selves.
Memo []byte
// TapscriptRoot is an optional tapscript root that if provided, will
// be used to create the combined key for musig2 based channels.
TapscriptRoot fn.Option[chainhash.Hash]
// err is a channel across which all errors will be sent. A nil error
// will be sent if this initial step is successful.
//
// NOTE: In order to avoid deadlocks, this channel MUST be buffered.
err chan error
// resp is the channel across which a ChannelReservation with our
// contribution filled in will be sent in the case of a successful
// reservation initiation. In the case of an error, this will read a
// nil pointer.
//
// NOTE: In order to avoid deadlocks, this channel MUST be buffered.
resp chan *ChannelReservation
}
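// Example (an illustrative sketch; all values below are hypothetical
// placeholders): a caller typically populates the exported fields and hands
// the message to InitChannelReservation, which fills in the internal response
// channels itself:
//
//	req := &InitFundingReserveMsg{
//		ChainHash:       &genesisHash,
//		PendingChanID:   pendingID,
//		NodeID:          peerPub,
//		NodeAddr:        peerAddr,
//		LocalFundingAmt: btcutil.Amount(1_000_000),
//		CommitFeePerKw:  feeRate,
//		FundingFeePerKw: feeRate,
//		MinConfs:        1,
//		CommitType:      commitType,
//	}
//	reservation, err := wallet.InitChannelReservation(req)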
// fundingReserveCancelMsg is a message used to cancel an existing channel
// reservation identified by its reservation ID. Cancelling a reservation
// frees up its locked outputs for inclusion within future reservations.
type fundingReserveCancelMsg struct {
pendingFundingID uint64
// NOTE: In order to avoid deadlocks, this channel MUST be buffered.
err chan error // Buffered
}
// addContributionMsg represents a message executing the second phase of the
// channel reservation workflow. This message carries the counterparty's
// "contribution" to the payment channel. In the case that this message is
// processed without generating any errors, then channel reservation will then
// be able to construct the funding tx, both commitment transactions, and
// finally generate signatures for all our inputs to the funding transaction,
// and for the remote node's version of the commitment transaction.
type addContributionMsg struct {
pendingFundingID uint64
contribution *ChannelContribution
// NOTE: In order to avoid deadlocks, this channel MUST be buffered.
err chan error
}
// continueContributionMsg represents a message that signals that the
// interrupted funding process involving a PSBT can now be continued because the
// finalized transaction is now available.
type continueContributionMsg struct {
pendingFundingID uint64
// auxFundingDesc is an optional descriptor that contains information
// about the custom channel funding flow.
auxFundingDesc fn.Option[AuxFundingDesc]
// NOTE: In order to avoid deadlocks, this channel MUST be buffered.
err chan error
}
// addSingleContributionMsg represents a message executing the second phase of
// a single funder channel reservation workflow. This message carries the
// counterparty's "contribution" to the payment channel. As this message is
// sent when on the responding side to a single funder workflow, no further
// action apart from storing the provided contribution is carried out.
type addSingleContributionMsg struct {
pendingFundingID uint64
contribution *ChannelContribution
// NOTE: In order to avoid deadlocks, this channel MUST be buffered.
err chan error
}
// addCounterPartySigsMsg represents the final message required to complete,
// and 'open' a payment channel. This message carries the counterparty's
// signatures for each of their inputs to the funding transaction, and also a
// signature allowing us to spend our version of the commitment transaction.
// If we're able to verify all the signatures are valid, the funding transaction
// will be broadcast to the network. After the funding transaction gains a
// configurable number of confirmations, the channel is officially considered
// 'open'.
type addCounterPartySigsMsg struct {
pendingFundingID uint64
// These should be in the same order as their sorted inputs. Sorting is
// done in accordance with BIP-69:
// https://github.com/bitcoin/bips/blob/master/bip-0069.mediawiki.
theirFundingInputScripts []*input.Script
// This should be 1/2 of the signatures needed to successfully spend our
// version of the commitment transaction.
theirCommitmentSig input.Signature
// This channel is used to return the completed channel after the wallet
// has completed all of its stages in the funding process.
completeChan chan *channeldb.OpenChannel
// NOTE: In order to avoid deadlocks, this channel MUST be buffered.
err chan error
}
// addSingleFunderSigsMsg represents the next-to-last message required to
// complete a single-funder channel workflow. Once the initiator is able to
// construct the funding transaction, they send both the outpoint and a
// signature for our version of the commitment transaction. Once this message
// is processed we (the responder) are able to construct both commitment
// transactions, signing the remote party's version.
type addSingleFunderSigsMsg struct {
pendingFundingID uint64
// auxFundingDesc is an optional descriptor that contains information
// about the custom channel funding flow.
auxFundingDesc fn.Option[AuxFundingDesc]
// fundingOutpoint is the outpoint of the completed funding
// transaction as assembled by the workflow initiator.
fundingOutpoint *wire.OutPoint
// theirCommitmentSig is 1/2 of the signatures needed to
// successfully spend our version of the commitment transaction.
theirCommitmentSig input.Signature
// This channel is used to return the completed channel after the wallet
// has completed all of its stages in the funding process.
completeChan chan *channeldb.OpenChannel
// NOTE: In order to avoid deadlocks, this channel MUST be buffered.
err chan error
}
// CheckReservedValueTxReq is the request struct used to call
// CheckReservedValueTx with. It contains the transaction to check as well as
// an optional explicitly defined index to denote a change output that is not
// watched by the wallet.
type CheckReservedValueTxReq struct {
// Tx is the transaction to check the outputs for.
Tx *wire.MsgTx
// ChangeIndex denotes an optional output index that can be explicitly
// set for a change that is not being watched by the wallet and would
// otherwise not be recognized as a change output.
ChangeIndex *int
}
// LightningWallet is a domain specific, yet general Bitcoin wallet capable of
// executing the workflows required to interact with the Lightning Network. It is
// domain specific in the sense that it understands all the fancy scripts used
// within the Lightning Network, channel lifetimes, etc. However, it embeds a
// general purpose Bitcoin wallet within it. Therefore, it is also able to
// serve as a regular Bitcoin wallet which uses HD keys. The wallet is highly
// concurrent internally. All communication, and requests towards the wallet
// are dispatched as messages over channels, ensuring thread safety across all
// operations. Interaction has been designed independent of any peer-to-peer
// communication protocol, allowing the wallet to be self-contained and
// embeddable within future projects interacting with the Lightning Network.
//
// NOTE: At the moment the wallet requires a btcd full node, as it's dependent
// on btcd's websockets notifications as event triggers during the lifetime of a
// channel. However, once the chainntnfs package is complete, the wallet will
// be compatible with multiple RPC/notification services such as Electrum,
// Bitcoin Core + ZeroMQ, etc. Eventually, the wallet won't require a full-node
// at all, as SPV support is integrated into btcwallet.
type LightningWallet struct {
started int32 // To be used atomically.
shutdown int32 // To be used atomically.
nextFundingID uint64 // To be used atomically.
// Cfg is the configuration struct that will be used by the wallet to
// access the necessary interfaces and defaults it needs to carry out
// its duties.
Cfg Config
// WalletController is the core wallet, all non Lightning Network
// specific interaction is proxied to the internal wallet.
WalletController
// SecretKeyRing is the interface we'll use to derive any keys related
// to our purpose within the network including: multi-sig keys, node
// keys, revocation keys, etc.
keychain.SecretKeyRing
// This mutex MUST be held when performing coin selection in order to
// avoid inadvertently creating multiple funding transactions which
// double spend inputs across each other.
coinSelectMtx sync.RWMutex
// All messages to the wallet are to be sent across this channel.
msgChan chan interface{}
// Incomplete payment channels are stored in the map below. An intent
// to create a payment channel is tracked as a "reservation" within
// limbo. Once the final signatures have been exchanged, a reservation
// is removed from limbo. Each reservation is tracked by a unique,
// monotonically increasing integer. All requests concerning the
// channel MUST carry a valid, active funding ID.
fundingLimbo map[uint64]*ChannelReservation
// reservationIDs maps a pending channel ID to the reservation ID used
// as key in the fundingLimbo map. Used to easily look up a channel
// reservation given a pending channel ID.
reservationIDs map[[32]byte]uint64
limboMtx sync.RWMutex
// lockedOutPoints is a set of the currently locked outpoints. This
// information is kept in order to provide an easy way to unlock all
// the currently locked outpoints.
lockedOutPoints map[wire.OutPoint]struct{}
// fundingIntents houses all the "interception" registered by a caller
// using the RegisterFundingIntent method.
intentMtx sync.RWMutex
fundingIntents map[[32]byte]chanfunding.Intent
quit chan struct{}
wg sync.WaitGroup
}
// NewLightningWallet creates/opens and initializes a LightningWallet instance.
// If the wallet has never been created before (according to the passed
// configuration), first-time setup is executed.
func NewLightningWallet(Cfg Config) (*LightningWallet, error) {
return &LightningWallet{
Cfg: Cfg,
SecretKeyRing: Cfg.SecretKeyRing,
WalletController: Cfg.WalletController,
msgChan: make(chan interface{}, msgBufferSize),
nextFundingID: 0,
fundingLimbo: make(map[uint64]*ChannelReservation),
reservationIDs: make(map[[32]byte]uint64),
lockedOutPoints: make(map[wire.OutPoint]struct{}),
fundingIntents: make(map[[32]byte]chanfunding.Intent),
quit: make(chan struct{}),
}, nil
}
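// Example (an illustrative sketch, assuming a fully populated Config named
// cfg):
//
//	wallet, err := NewLightningWallet(cfg)
//	if err != nil {
//		// Handle the construction error.
//	}
//	if err := wallet.Startup(); err != nil {
//		// Handle the startup error.
//	}
//	defer func() { _ = wallet.Shutdown() }()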
// Startup establishes a connection to the RPC source, and spins up all
// goroutines required to handle incoming messages.
func (l *LightningWallet) Startup() error {
// Already started?
if atomic.AddInt32(&l.started, 1) != 1 {
return nil
}
// Start the underlying wallet controller.
if err := l.Start(); err != nil {
return err
}
if l.Cfg.Rebroadcaster != nil {
go func() {
if err := l.Cfg.Rebroadcaster.Start(); err != nil {
walletLog.Errorf("unable to start "+
"rebroadcaster: %v", err)
}
}()
}
l.wg.Add(1)
go l.requestHandler()
return nil
}
// Shutdown gracefully stops the wallet, and all active goroutines.
func (l *LightningWallet) Shutdown() error {
if atomic.AddInt32(&l.shutdown, 1) != 1 {
return nil
}
// Signal the underlying wallet controller to shut down, waiting until
// all active goroutines have been shut down.
if err := l.Stop(); err != nil {
return err
}
if l.Cfg.Rebroadcaster != nil && l.Cfg.Rebroadcaster.Started() {
l.Cfg.Rebroadcaster.Stop()
}
close(l.quit)
l.wg.Wait()
return nil
}
// PublishTransaction wraps the wallet controller tx publish method with an
// extra rebroadcaster layer if the sub-system is configured.
func (l *LightningWallet) PublishTransaction(tx *wire.MsgTx,
label string) error {
sendTxToWallet := func() error {
return l.WalletController.PublishTransaction(tx, label)
}
// If we don't have a rebroadcaster then we can exit early (and send
// only to the wallet).
if l.Cfg.Rebroadcaster == nil || !l.Cfg.Rebroadcaster.Started() {
return sendTxToWallet()
}
// We pass this into the rebroadcaster first, so the initial attempt
// will succeed if the transaction isn't yet in the mempool. However,
// we ignore the error here as the transaction might be resent on
// startup and already exist in the mempool.
_ = l.Cfg.Rebroadcaster.Broadcast(tx)
// Then we pass things into the wallet as normal, which'll add the
// transaction label on disk.
if err := sendTxToWallet(); err != nil {
return err
}
// TODO(roasbeef): want diff height actually? no context though
_, bestHeight, err := l.Cfg.ChainIO.GetBestBlock()
if err != nil {
return err
}
txHash := tx.TxHash()
go func() {
const numConfs = 6
txConf, err := l.Cfg.Notifier.RegisterConfirmationsNtfn(
&txHash, tx.TxOut[0].PkScript, numConfs,
uint32(bestHeight),
)
if err != nil {
return
}
select {
case <-txConf.Confirmed:
// TODO(roasbeef): also want to remove from
// rebroadcaster if conflict happens...deeper wallet
// integration?
l.Cfg.Rebroadcaster.MarkAsConfirmed(tx.TxHash())
case <-l.quit:
return
}
}()
return nil
}
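// Example (an illustrative sketch): publishing a labeled transaction; the
// rebroadcaster layer, if one is configured, is engaged transparently:
//
//	if err := wallet.PublishTransaction(fundingTx, "channel funding"); err != nil {
//		// Handle the publish failure.
//	}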
// ConfirmedBalance returns the current confirmed balance of a wallet account.
// This method wraps the internal WalletController method so we're able to
// properly hold the coin select mutex while we compute the balance.
func (l *LightningWallet) ConfirmedBalance(confs int32,
account string) (btcutil.Amount, error) {
l.coinSelectMtx.Lock()
defer l.coinSelectMtx.Unlock()
return l.WalletController.ConfirmedBalance(confs, account)
}
// ListUnspentWitnessFromDefaultAccount returns all unspent outputs from the
// default wallet account which are version 0 witness programs. The 'minConfs'
// and 'maxConfs' parameters indicate the minimum and maximum number of
// confirmations an output needs in order to be returned by this method. Passing
// -1 as 'minConfs' indicates that even unconfirmed outputs should be returned.
// Using MaxInt32 as 'maxConfs' implies returning all outputs with at least
// 'minConfs'.
//
// NOTE: This method requires the global coin selection lock to be held.
func (l *LightningWallet) ListUnspentWitnessFromDefaultAccount(
minConfs, maxConfs int32) ([]*Utxo, error) {
return l.WalletController.ListUnspentWitness(
minConfs, maxConfs, DefaultAccountName,
)
}
// LockedOutpoints returns a list of all currently locked outpoints.
func (l *LightningWallet) LockedOutpoints() []*wire.OutPoint {
outPoints := make([]*wire.OutPoint, 0, len(l.lockedOutPoints))
for outPoint := range l.lockedOutPoints {
outPoint := outPoint
outPoints = append(outPoints, &outPoint)
}
return outPoints
}
// ResetReservations resets the volatile wallet state which tracks all
// currently active reservations.
func (l *LightningWallet) ResetReservations() {
l.nextFundingID = 0
l.fundingLimbo = make(map[uint64]*ChannelReservation)
l.reservationIDs = make(map[[32]byte]uint64)
for outpoint := range l.lockedOutPoints {
_ = l.ReleaseOutput(chanfunding.LndInternalLockID, outpoint)
}
l.lockedOutPoints = make(map[wire.OutPoint]struct{})
}
// ActiveReservations returns a slice of all the currently active
// (non-canceled) reservations.
func (l *LightningWallet) ActiveReservations() []*ChannelReservation {
reservations := make([]*ChannelReservation, 0, len(l.fundingLimbo))
for _, reservation := range l.fundingLimbo {
reservations = append(reservations, reservation)
}
return reservations
}
// requestHandler is the primary goroutine responsible for handling and
// dispatching replies to all messages.
func (l *LightningWallet) requestHandler() {
defer l.wg.Done()
out:
for {
select {
case m := <-l.msgChan:
switch msg := m.(type) {
case *InitFundingReserveMsg:
l.handleFundingReserveRequest(msg)
case *fundingReserveCancelMsg:
l.handleFundingCancelRequest(msg)
case *addSingleContributionMsg:
l.handleSingleContribution(msg)
case *addContributionMsg:
l.handleContributionMsg(msg)
case *continueContributionMsg:
l.handleChanPointReady(msg)
case *addSingleFunderSigsMsg:
l.handleSingleFunderSigs(msg)
case *addCounterPartySigsMsg:
l.handleFundingCounterPartySigs(msg)
}
case <-l.quit:
// TODO: do some clean up
break out
}
}
}
// InitChannelReservation kicks off the 3-step workflow required to successfully
// open a payment channel with a remote node. As part of the funding
// reservation, the inputs selected for the funding transaction are 'locked'.
// This ensures that multiple channel reservations aren't double spending the
// same inputs in the funding transaction. If reservation initialization is
// successful, a ChannelReservation containing our completed contribution is
// returned. Our contribution contains all the items necessary to allow the
// counterparty to build the funding transaction, and both versions of the
// commitment transaction. Otherwise, a nil pointer along with an error
// is returned.
//
// Once a ChannelReservation has been obtained, two additional steps must be
// processed before a payment channel can be considered 'open'. The second step
// validates, and processes the counterparty's channel contribution. The third,
// and final step verifies all signatures for the inputs of the funding
// transaction, and that the signature we record for our version of the
// commitment transaction is valid.
func (l *LightningWallet) InitChannelReservation(
req *InitFundingReserveMsg) (*ChannelReservation, error) {
req.resp = make(chan *ChannelReservation, 1)
req.err = make(chan error, 1)
select {
case l.msgChan <- req:
case <-l.quit:
return nil, errors.New("wallet shutting down")
}
return <-req.resp, <-req.err
}
// RegisterFundingIntent allows a caller to signal to the wallet that if a
// pending channel ID of expectedID is found, then it can skip constructing a
// new chanfunding.Assembler, and instead use the specified chanfunding.Intent.
// As an example, this lets some of the parameters of the funding
// transaction be negotiated outside the regular funding protocol.
func (l *LightningWallet) RegisterFundingIntent(expectedID [32]byte,
shimIntent chanfunding.Intent) error {
l.intentMtx.Lock()
defer l.intentMtx.Unlock()
// Sanity check the pending channel ID is not empty.
var zeroID [32]byte
if expectedID == zeroID {
return ErrEmptyPendingChanID
}
if _, ok := l.fundingIntents[expectedID]; ok {
return fmt.Errorf("%w: already has intent registered: %x",
ErrDuplicatePendingChanID, expectedID[:])
}
l.fundingIntents[expectedID] = shimIntent
return nil
}
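// Example (an illustrative sketch; the construction of the intent itself is
// elided, as it depends on the chanfunding package):
//
//	var shimIntent chanfunding.Intent // e.g. a *chanfunding.ShimIntent
//	err := wallet.RegisterFundingIntent(pendingChanID, shimIntent)
//	if err != nil {
//		// May be ErrEmptyPendingChanID or ErrDuplicatePendingChanID.
//	}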
// PsbtFundingVerify looks up a previously registered funding intent by its
// pending channel ID and tries to advance the state machine by verifying the
// passed PSBT.
func (l *LightningWallet) PsbtFundingVerify(pendingChanID [32]byte,
packet *psbt.Packet, skipFinalize bool) error {
l.intentMtx.Lock()
defer l.intentMtx.Unlock()
intent, ok := l.fundingIntents[pendingChanID]
if !ok {
return fmt.Errorf("no funding intent found for "+
"pendingChannelID(%x)", pendingChanID[:])
}
psbtIntent, ok := intent.(*chanfunding.PsbtIntent)
if !ok {
return fmt.Errorf("incompatible funding intent")
}
if skipFinalize && psbtIntent.ShouldPublishFundingTX() {
return fmt.Errorf("cannot set skip_finalize for channel that " +
"did not set no_publish")
}
err := psbtIntent.Verify(packet, skipFinalize)
if err != nil {
return fmt.Errorf("error verifying PSBT: %w", err)
}
// Get the channel reservation that corresponds to this pending
// channel ID.
l.limboMtx.Lock()
pid, ok := l.reservationIDs[pendingChanID]
if !ok {
l.limboMtx.Unlock()
return fmt.Errorf("no channel reservation found for "+
"pendingChannelID(%x)", pendingChanID[:])
}
pendingReservation, ok := l.fundingLimbo[pid]
l.limboMtx.Unlock()
if !ok {
return fmt.Errorf("no channel reservation found for "+
"reservation ID %v", pid)
}
// Now that the PSBT has been populated and verified, we can again
// check whether the value reserved for anchor fee bumping is
// respected.
isPublic := pendingReservation.partialState.ChannelFlags&lnwire.FFAnnounceChannel != 0
hasAnchors := pendingReservation.partialState.ChanType.HasAnchors()
return l.enforceNewReservedValue(intent, isPublic, hasAnchors)
}
// PsbtFundingFinalize looks up a previously registered funding intent by its
// pending channel ID and tries to advance the state machine by finalizing the
// passed PSBT.
func (l *LightningWallet) PsbtFundingFinalize(pid [32]byte, packet *psbt.Packet,
rawTx *wire.MsgTx) error {
l.intentMtx.Lock()
defer l.intentMtx.Unlock()
intent, ok := l.fundingIntents[pid]
if !ok {
return fmt.Errorf("no funding intent found for "+
"pendingChannelID(%x)", pid[:])
}
psbtIntent, ok := intent.(*chanfunding.PsbtIntent)
if !ok {
return fmt.Errorf("incompatible funding intent")
}
// Either the PSBT or the raw TX must be set.
switch {
case packet != nil && rawTx == nil:
err := psbtIntent.Finalize(packet)
if err != nil {
return fmt.Errorf("error finalizing PSBT: %w", err)
}
case rawTx != nil && packet == nil:
err := psbtIntent.FinalizeRawTX(rawTx)
if err != nil {
return fmt.Errorf("error finalizing raw TX: %w", err)
}
default:
return fmt.Errorf("either a PSBT or raw TX must be specified")
}
return nil
}
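// Example (an illustrative sketch): exactly one of the PSBT packet or the raw
// transaction must be provided when finalizing:
//
//	// Finalize with a fully signed PSBT...
//	err := wallet.PsbtFundingFinalize(pid, signedPacket, nil)
//
//	// ...or, alternatively, with a pre-signed raw transaction.
//	err = wallet.PsbtFundingFinalize(pid, nil, rawTx)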
// CancelFundingIntent allows a caller to cancel a previously registered
// funding intent. If no intent was found, then an error will be returned.
func (l *LightningWallet) CancelFundingIntent(pid [32]byte) error {
l.intentMtx.Lock()
defer l.intentMtx.Unlock()
intent, ok := l.fundingIntents[pid]
if !ok {
return fmt.Errorf("no funding intent found for "+
"pendingChannelID(%x)", pid[:])
}
// Give the intent a chance to clean up after itself, removing coin
// locks or similar reserved resources.
intent.Cancel()
delete(l.fundingIntents, pid)
return nil
}
// handleFundingReserveRequest processes a message intending to create, and
// validate a funding reservation request.
func (l *LightningWallet) handleFundingReserveRequest(req *InitFundingReserveMsg) {
noFundsCommitted := req.LocalFundingAmt == 0 &&
req.RemoteFundingAmt == 0 && req.FundUpToMaxAmt == 0
// It isn't possible to create a channel with zero funds committed.
if noFundsCommitted {
err := ErrZeroCapacity()
req.err <- err
req.resp <- nil
return
}
// If the funding request is for a different chain than the one the
// wallet is aware of, then we'll reject the request.
if !bytes.Equal(l.Cfg.NetParams.GenesisHash[:], req.ChainHash[:]) {
err := ErrChainMismatch(
l.Cfg.NetParams.GenesisHash, req.ChainHash,
)
req.err <- err
req.resp <- nil
return
}
// We need to avoid enforcing the reserved value in the middle of PSBT
// funding because some of the following steps may add UTXOs funding
// the on-chain wallet. The enforcement still happens at the last step,
// in PsbtFundingVerify.
enforceNewReservedValue := true
// If no chanFunder was provided, then we'll assume the default
// assembler, which is backed by the wallet's internal coin selection.
if req.ChanFunder == nil {
// We use the P2WSH dust limit since it is larger than the
// P2WPKH dust limit and to avoid threading through two
// different dust limits.
cfg := chanfunding.WalletConfig{
CoinSource: NewCoinSource(
l, req.AllowUtxoForFunding,
),
CoinSelectLocker: l,
CoinLeaser: l,
Signer: l.Cfg.Signer,
DustLimit: DustLimitForSize(
input.P2WSHSize,
),
CoinSelectionStrategy: l.Cfg.CoinSelectionStrategy,
}
req.ChanFunder = chanfunding.NewWalletAssembler(cfg)
} else {
_, isPsbtFunder := req.ChanFunder.(*chanfunding.PsbtAssembler)
enforceNewReservedValue = !isPsbtFunder
}
localFundingAmt := req.LocalFundingAmt
remoteFundingAmt := req.RemoteFundingAmt
hasAnchors := req.CommitType.HasAnchors()
var (
fundingIntent chanfunding.Intent
err error
)
// If we've just received an inbound funding request that we have a
// registered shim intent for, then we'll obtain the backing intent now.
// In this case, we're doing a special funding workflow that allows
// more advanced constructions such as channel factories to be
// instantiated.
l.intentMtx.Lock()
fundingIntent, ok := l.fundingIntents[req.PendingChanID]
l.intentMtx.Unlock()
// Otherwise, this is a normal funding flow, so we'll use the chan
// funder in the attached request to provision the inputs/outputs
// that'll ultimately be used to construct the funding transaction.
if !ok {
var err error
var numAnchorChans int
// Get the number of anchor channels to determine if there is a
// reserved value that must be respected when funding up to the
// maximum amount. Since private channels (most likely) won't be
// used for routing other than the last hop, they bear a smaller
// risk that we must force close them in order to resolve an HTLC
// up/downstream. Hence we exclude them from the count of anchor
// channels in order to attribute the respective anchor amount
// to the channel capacity.
if req.FundUpToMaxAmt > 0 && req.MinFundAmt > 0 {
numAnchorChans, err = l.CurrentNumAnchorChans()
if err != nil {
req.err <- err
req.resp <- nil
return
}
isPublic := req.Flags&lnwire.FFAnnounceChannel != 0
if hasAnchors && isPublic {
numAnchorChans++
}
}
// Coin selection is done on the basis of sat/kw, so we'll use
// the fee rate passed in to perform coin selection.
fundingReq := &chanfunding.Request{
RemoteAmt: req.RemoteFundingAmt,
LocalAmt: req.LocalFundingAmt,
FundUpToMaxAmt: req.FundUpToMaxAmt,
MinFundAmt: req.MinFundAmt,
RemoteChanReserve: req.RemoteChanReserve,
PushAmt: lnwire.MilliSatoshi.ToSatoshis(
req.PushMSat,
),
WalletReserve: l.RequiredReserve(
uint32(numAnchorChans),
),
Outpoints: req.Outpoints,
MinConfs: req.MinConfs,
SubtractFees: req.SubtractFees,
FeeRate: req.FundingFeePerKw,
ChangeAddr: func() (btcutil.Address, error) {
return l.NewAddress(
TaprootPubkey, true, DefaultAccountName,
)
},
Musig2: req.CommitType.IsTaproot(),
}
fundingIntent, err = req.ChanFunder.ProvisionChannel(
fundingReq,
)
if err != nil {
req.err <- err
req.resp <- nil
return
}
// Register the funding intent now in case we need to access it
// again later, as is the case for the PSBT state machine, for
// example.
err = l.RegisterFundingIntent(req.PendingChanID, fundingIntent)
if err != nil {
req.err <- err
req.resp <- nil
return
}
walletLog.Debugf("Registered funding intent for "+
"PendingChanID: %x", req.PendingChanID)
localFundingAmt = fundingIntent.LocalFundingAmt()
remoteFundingAmt = fundingIntent.RemoteFundingAmt()
}
// At this point there _has_ to be a funding intent, otherwise something
// went really wrong.
if fundingIntent == nil {
req.err <- fmt.Errorf("no funding intent present")
req.resp <- nil
return
}
// If this is a shim intent, then it may be attempting to use an
// existing set of keys for the funding workflow. In this case, we'll
// make a simple wrapper keychain.KeyRing that will proxy certain
// derivation calls to future callers.
var (
keyRing keychain.KeyRing = l.SecretKeyRing
thawHeight uint32
)
if shimIntent, ok := fundingIntent.(*chanfunding.ShimIntent); ok {
keyRing = &shimKeyRing{
KeyRing: keyRing,
ShimIntent: shimIntent,
}
// As this was a registered shim intent, we'll obtain the thaw
// height of the intent, if present at all. If this is
// non-zero, then we'll mark this as the proper channel type.
thawHeight = shimIntent.ThawHeight()
}
// Now that we have a funding intent, we'll check whether funding a
// channel using it would violate our reserved value for anchor channel
// fee bumping.
//
// Check the reserved value using the inputs and outputs given by the
// intent. Note that for the PSBT intent type we don't yet have the
// funding tx ready, so this will always pass. We'll do another check
// when the PSBT has been verified.
isPublic := req.Flags&lnwire.FFAnnounceChannel != 0
if enforceNewReservedValue {
err = l.enforceNewReservedValue(fundingIntent, isPublic, hasAnchors)
if err != nil {
fundingIntent.Cancel()
req.err <- err
req.resp <- nil
return
}
}
// The total channel capacity will be the size of the funding output we
// created plus the remote contribution.
capacity := localFundingAmt + remoteFundingAmt
id := atomic.AddUint64(&l.nextFundingID, 1)
reservation, err := NewChannelReservation(
capacity, localFundingAmt, l, id, l.Cfg.NetParams.GenesisHash,
thawHeight, req,
)
if err != nil {
fundingIntent.Cancel()
req.err <- err
req.resp <- nil
return
}
err = l.initOurContribution(
reservation, fundingIntent, req.NodeAddr, req.NodeID, keyRing,
)
if err != nil {
fundingIntent.Cancel()
req.err <- err
req.resp <- nil
return
}
// Create a limbo and record entry for this newly pending funding
// request.
l.limboMtx.Lock()
l.fundingLimbo[id] = reservation
l.reservationIDs[req.PendingChanID] = id
l.limboMtx.Unlock()
// Funding reservation request successfully handled. The funding inputs
// will be marked as unavailable until the reservation is either
// completed, or canceled.
req.resp <- reservation
req.err <- nil
walletLog.Debugf("Successfully handled funding reservation with "+
"pendingChanID: %x, reservationID: %v",
reservation.pendingChanID, reservation.reservationID)
}
// enforceNewReservedValue enforces that the wallet, upon a new channel being
// opened, meets the minimum amount of funds required for each advertised
// anchor channel.
//
// We only enforce the reserve if we are contributing funds to the channel.
// This is done to still allow incoming channels even though we have no UTXOs
// available, such as during bootstrapping phases.
func (l *LightningWallet) enforceNewReservedValue(fundingIntent chanfunding.Intent,
isPublic, hasAnchors bool) error {
// Only enforce the reserve when an advertised channel that we are
// contributing funds to is being opened. This ensures we never dip
// below the reserve.
if !isPublic || fundingIntent.LocalFundingAmt() == 0 {
return nil
}
numAnchors, err := l.CurrentNumAnchorChans()
if err != nil {
return err
}
// Add the to-be-opened channel.
if hasAnchors {
numAnchors++
}
return l.WithCoinSelectLock(func() error {
_, err := l.CheckReservedValue(
fundingIntent.Inputs(), fundingIntent.Outputs(),
numAnchors,
)
return err
})
}
// CurrentNumAnchorChans returns the current number of non-private anchor
// channels the wallet should be ready to fee bump if needed.
func (l *LightningWallet) CurrentNumAnchorChans() (int, error) {
// Count all anchor channels that are open, pending open, or waiting
// close.
chans, err := l.Cfg.Database.FetchAllChannels()
if err != nil {
return 0, err
}
var numAnchors int
cntChannel := func(c *channeldb.OpenChannel) {
// We skip private channels, as we assume they won't be used
// for routing.
if c.ChannelFlags&lnwire.FFAnnounceChannel == 0 {
return
}
// Count anchor channels.
if c.ChanType.HasAnchors() {
numAnchors++
}
}
for _, c := range chans {
cntChannel(c)
}
// We also count pending close channels.
pendingClosed, err := l.Cfg.Database.FetchClosedChannels(
true,
)
if err != nil {
return 0, err
}
for _, c := range pendingClosed {
c, err := l.Cfg.Database.FetchHistoricalChannel(
&c.ChanPoint,
)
if err != nil {
// We don't have a guarantee that all channels are found
// in the historical channels bucket, so we continue.
walletLog.Warnf("Unable to fetch historical "+
"channel: %v", err)
continue
}
cntChannel(c)
}
return numAnchors, nil
}
// CheckReservedValue checks whether publishing a transaction with the given
// inputs and outputs would violate the value we reserve in the wallet for
// bumping the fee of anchor channels. The numAnchorChans argument should be
// set to the number of open anchor channels controlled by the wallet after
// the transaction has been published.
//
// If the reserved value is violated, the returned error will be
// ErrReservedValueInvalidated. The method will also return the current
// reserved value, both in case of success and in case of
// ErrReservedValueInvalidated.
//
// NOTE: This method should only be run with the CoinSelectLock held.
func (l *LightningWallet) CheckReservedValue(in []wire.OutPoint,
out []*wire.TxOut, numAnchorChans int) (btcutil.Amount, error) {
// Get all unspent coins in the wallet. We only care about those part of
// the wallet's default account as we know we can readily sign for those
// at any time.
witnessOutputs, err := l.ListUnspentWitnessFromDefaultAccount(
0, math.MaxInt32,
)
if err != nil {
return 0, err
}
ourInput := make(map[wire.OutPoint]struct{})
for _, op := range in {
ourInput[op] = struct{}{}
}
// When crafting a transaction with inputs from the wallet, these coins
// will usually be locked in the process, and not be returned when
// listing unspents. In this case they have already been deducted from
// the wallet balance. In case they haven't been properly locked, we
// check whether they are still listed among our unspents and deduct
// them.
var walletBalance btcutil.Amount
for _, in := range witnessOutputs {
// Spending an unlocked wallet UTXO, don't add it to the
// balance.
if _, ok := ourInput[in.OutPoint]; ok {
continue
}
walletBalance += in.Value
}
// Now we go through the outputs of the transaction, if any of the
// outputs are paying into the wallet (likely a change output), we add
// it to our final balance.
for _, txOut := range out {
_, addrs, _, err := txscript.ExtractPkScriptAddrs(
txOut.PkScript, &l.Cfg.NetParams,
)
if err != nil {
// Non-standard outputs can safely be skipped because
// they're not supported by the wallet.
continue
}
for _, addr := range addrs {
if !l.IsOurAddress(addr) {
continue
}
walletBalance += btcutil.Amount(txOut.Value)
// We break since we don't want to double count the output.
break
}
}
// We reserve a given amount for each anchor channel.
reserved := l.RequiredReserve(uint32(numAnchorChans))
if walletBalance < reserved {
walletLog.Debugf("Reserved value=%v above final "+
"walletbalance=%v with %d anchor channels open",
reserved, walletBalance, numAnchorChans)
return reserved, ErrReservedValueInvalidated
}
return reserved, nil
}
// CheckReservedValueTx calls CheckReservedValue with the inputs and outputs
// from the given tx, with the number of anchor channels currently open in the
// database.
//
// NOTE: This method should only be run with the CoinSelectLock held.
func (l *LightningWallet) CheckReservedValueTx(req CheckReservedValueTxReq) (
btcutil.Amount, error) {
numAnchors, err := l.CurrentNumAnchorChans()
if err != nil {
return 0, err
}
var inputs []wire.OutPoint
for _, txIn := range req.Tx.TxIn {
inputs = append(inputs, txIn.PreviousOutPoint)
}
reservedVal, err := l.CheckReservedValue(
inputs, req.Tx.TxOut, numAnchors,
)
switch {
// If the error returned from CheckReservedValue is
// ErrReservedValueInvalidated, then it did nonetheless return
// the required reserved value and we check for the optional
// change index.
case errors.Is(err, ErrReservedValueInvalidated):
// Without a change index provided there is nothing more to
// check and the error is returned.
if req.ChangeIndex == nil {
return reservedVal, err
}
// If a change index was provided, we only make sure that it
// would leave sufficient funds for the reserved balance value.
//
// Note: This is used if a change output index is explicitly set
// but may not be watched by the wallet, and therefore is not
// picked up by the call to CheckReservedValue above.
chIdx := *req.ChangeIndex
if chIdx < 0 || chIdx >= len(req.Tx.TxOut) ||
req.Tx.TxOut[chIdx].Value < int64(reservedVal) {
return reservedVal, err
}
case err != nil:
return reservedVal, err
}
return reservedVal, nil
}
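// Example (an illustrative sketch): checking a transaction whose change
// output at index 0 is not watched by the wallet and must therefore be
// pointed out explicitly. Note that the coin select lock must be held, e.g.
// via WithCoinSelectLock:
//
//	idx := 0
//	_, err := wallet.CheckReservedValueTx(CheckReservedValueTxReq{
//		Tx:          tx,
//		ChangeIndex: &idx,
//	})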
// initOurContribution initializes the given ChannelReservation with our coins
// and change reserved for the channel, and derives the keys to use for this
// channel.
func (l *LightningWallet) initOurContribution(reservation *ChannelReservation,
fundingIntent chanfunding.Intent, nodeAddr net.Addr,
nodeID *btcec.PublicKey, keyRing keychain.KeyRing) error {
// Grab the mutex on the ChannelReservation to ensure thread-safety
reservation.Lock()
defer reservation.Unlock()
// At this point, if we have a funding intent, we'll use it to populate
// the existing reservation state entries for our coin selection.
if fundingIntent != nil {
if intent, ok := fundingIntent.(*chanfunding.FullIntent); ok {
for _, coin := range intent.InputCoins {
reservation.ourContribution.Inputs = append(
reservation.ourContribution.Inputs,
&wire.TxIn{
PreviousOutPoint: coin.OutPoint,
},
)
}
reservation.ourContribution.ChangeOutputs = intent.ChangeOutputs
}
reservation.fundingIntent = fundingIntent
}
reservation.nodeAddr = nodeAddr
reservation.partialState.IdentityPub = nodeID
var err error
reservation.ourContribution.MultiSigKey, err = keyRing.DeriveNextKey(
keychain.KeyFamilyMultiSig,
)
if err != nil {
return err
}
reservation.ourContribution.RevocationBasePoint, err = keyRing.DeriveNextKey(
keychain.KeyFamilyRevocationBase,
)
if err != nil {
return err
}
reservation.ourContribution.HtlcBasePoint, err = keyRing.DeriveNextKey(
keychain.KeyFamilyHtlcBase,
)
if err != nil {
return err
}
reservation.ourContribution.PaymentBasePoint, err = keyRing.DeriveNextKey(
keychain.KeyFamilyPaymentBase,
)
if err != nil {
return err
}
reservation.ourContribution.DelayBasePoint, err = keyRing.DeriveNextKey(
keychain.KeyFamilyDelayBase,
)
if err != nil {
return err
}
// With the above keys created, we'll also need to initialize our
// revocation tree state, and from that generate the per-commitment
// point.
producer, taprootNonceProducer, err := l.nextRevocationProducer(
reservation, keyRing,
)
if err != nil {
return err
}
firstPreimage, err := producer.AtIndex(0)
if err != nil {
return err
}
reservation.ourContribution.FirstCommitmentPoint = input.ComputeCommitmentPoint(
firstPreimage[:],
)
reservation.partialState.RevocationProducer = producer
reservation.ourContribution.CommitmentParams.DustLimit =
DustLimitUnknownWitness()
// If taproot channels are active, then we'll generate our verification
// nonce here. We'll use this nonce to verify the signature for our
// local commitment transaction. If we need to force close, then this
// is also what'll be used to sign that transaction.
if reservation.partialState.ChanType.IsTaproot() {
firstNoncePreimage, err := taprootNonceProducer.AtIndex(0)
if err != nil {
return err
}
// As we'd like the local nonce we send over to be generated
// deterministically, we'll provide a custom reader that
// actually just uses our sha-chain pre-image as the primary
// randomness source.
shaChainRand := musig2.WithCustomRand(
bytes.NewBuffer(firstNoncePreimage[:]),
)
pubKeyOpt := musig2.WithPublicKey(
reservation.ourContribution.MultiSigKey.PubKey,
)
reservation.ourContribution.LocalNonce, err = musig2.GenNonces(
pubKeyOpt, shaChainRand,
)
if err != nil {
return err
}
}
return nil
}
// handleFundingCancelRequest cancels an existing channel reservation. As part
// of the cancellation, outputs previously selected as inputs for the funding
// transaction via coin selection are freed, allowing future reservations to
// include them.
func (l *LightningWallet) handleFundingCancelRequest(req *fundingReserveCancelMsg) {
l.limboMtx.Lock()
defer l.limboMtx.Unlock()
pendingReservation, ok := l.fundingLimbo[req.pendingFundingID]
if !ok {
// TODO(roasbeef): make new error, "unknown funding state" or something
req.err <- fmt.Errorf("attempted to cancel non-existent funding state")
return
}
// Grab the mutex on the ChannelReservation to ensure thread-safety
pendingReservation.Lock()
defer pendingReservation.Unlock()
// Mark all previously locked outpoints as usable for future funding
// requests.
for _, unusedInput := range pendingReservation.ourContribution.Inputs {
delete(l.lockedOutPoints, unusedInput.PreviousOutPoint)
_ = l.ReleaseOutput(
chanfunding.LndInternalLockID,
unusedInput.PreviousOutPoint,
)
}
delete(l.fundingLimbo, req.pendingFundingID)
pid := pendingReservation.pendingChanID
delete(l.reservationIDs, pid)
l.intentMtx.Lock()
if intent, ok := l.fundingIntents[pid]; ok {
intent.Cancel()
delete(l.fundingIntents, pendingReservation.pendingChanID)
}
l.intentMtx.Unlock()
req.err <- nil
}
// createCommitOpts is a struct that holds the options for creating a new
// commitment transaction.
type createCommitOpts struct {
localAuxLeaves fn.Option[CommitAuxLeaves]
remoteAuxLeaves fn.Option[CommitAuxLeaves]
}
// defaultCommitOpts returns a new createCommitOpts with default values.
func defaultCommitOpts() createCommitOpts {
return createCommitOpts{}
}
// WithAuxLeaves is a functional option that can be used to set the aux leaves
// for a new commitment transaction.
func WithAuxLeaves(localLeaves,
remoteLeaves fn.Option[CommitAuxLeaves]) CreateCommitOpt {
return func(o *createCommitOpts) {
o.localAuxLeaves = localLeaves
o.remoteAuxLeaves = remoteLeaves
}
}
// CreateCommitOpt is a functional option that can be used to modify the way a
// new commitment transaction is created.
type CreateCommitOpt func(*createCommitOpts)
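// Example (an illustrative sketch): callers that have aux leaves available
// can thread them through CreateCommitmentTxns via the functional option:
//
//	ourTx, theirTx, err := CreateCommitmentTxns(
//		localBalance, remoteBalance, ourChanCfg, theirChanCfg,
//		localCommitPoint, remoteCommitPoint, fundingTxIn, chanType,
//		initiator, leaseExpiry,
//		WithAuxLeaves(localLeaves, remoteLeaves),
//	)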
// CreateCommitmentTxns is a helper function that creates the initial
// commitment transactions for both parties. This function is used during the
// initial funding workflow as both sides must generate a signature for the
// remote party's commitment transaction, and verify the signature for their
// version of the commitment transaction.
func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount,
ourChanCfg, theirChanCfg *channeldb.ChannelConfig,
localCommitPoint, remoteCommitPoint *btcec.PublicKey,
fundingTxIn wire.TxIn, chanType channeldb.ChannelType, initiator bool,
leaseExpiry uint32, opts ...CreateCommitOpt) (*wire.MsgTx, *wire.MsgTx,
error) {
options := defaultCommitOpts()
for _, optFunc := range opts {
optFunc(&options)
}
localCommitmentKeys := DeriveCommitmentKeys(
localCommitPoint, lntypes.Local, chanType, ourChanCfg,
theirChanCfg,
)
remoteCommitmentKeys := DeriveCommitmentKeys(
remoteCommitPoint, lntypes.Remote, chanType, ourChanCfg,
theirChanCfg,
)
ourCommitTx, err := CreateCommitTx(
chanType, fundingTxIn, localCommitmentKeys, ourChanCfg,
theirChanCfg, localBalance, remoteBalance, 0, initiator,
leaseExpiry, options.localAuxLeaves,
)
if err != nil {
return nil, nil, err
}
otxn := btcutil.NewTx(ourCommitTx)
if err := blockchain.CheckTransactionSanity(otxn); err != nil {
return nil, nil, err
}
theirCommitTx, err := CreateCommitTx(
chanType, fundingTxIn, remoteCommitmentKeys, theirChanCfg,
ourChanCfg, remoteBalance, localBalance, 0, !initiator,
leaseExpiry, options.remoteAuxLeaves,
)
if err != nil {
return nil, nil, err
}
ttxn := btcutil.NewTx(theirCommitTx)
if err := blockchain.CheckTransactionSanity(ttxn); err != nil {
return nil, nil, err
}
return ourCommitTx, theirCommitTx, nil
}
// handleContributionMsg processes the second workflow step for the lifetime of
// a channel reservation. Upon completion, the reservation will carry a
// completed funding transaction (minus the counterparty's input signatures),
// both versions of the commitment transaction, and our signature for their
// version of the commitment transaction.
func (l *LightningWallet) handleContributionMsg(req *addContributionMsg) {
l.limboMtx.Lock()
pendingReservation, ok := l.fundingLimbo[req.pendingFundingID]
l.limboMtx.Unlock()
if !ok {
req.err <- fmt.Errorf("attempted to update non-existent funding state")
return
}
// Grab the mutex on the ChannelReservation to ensure thread-safety
pendingReservation.Lock()
defer pendingReservation.Unlock()
// If UpfrontShutdownScript is set, validate that it is a valid script.
shutdown := req.contribution.UpfrontShutdown
if len(shutdown) > 0 {
// Validate the shutdown script.
if !ValidateUpfrontShutdown(shutdown, &l.Cfg.NetParams) {
req.err <- fmt.Errorf("invalid shutdown script")
return
}
}
// Some temporary variables to cut down on the resolution verbosity.
pendingReservation.theirContribution = req.contribution
theirContribution := req.contribution
ourContribution := pendingReservation.ourContribution
// Perform bounds-checking on both ChannelReserve and DustLimit
// parameters.
if !pendingReservation.validateReserveBounds() {
req.err <- fmt.Errorf("invalid reserve and dust bounds")
return
}
var (
chanPoint *wire.OutPoint
err error
)
// At this point, we can now construct our channel point. Depending on
// which type of intent we obtained from our chanfunding.Assembler,
// we'll carry out a distinct set of steps.
switch fundingIntent := pendingReservation.fundingIntent.(type) {
// The transaction was created outside of the wallet and might already
// be published. Nothing left to do other than using the correct
// outpoint.
case *chanfunding.ShimIntent:
chanPoint, err = fundingIntent.ChanPoint()
if err != nil {
req.err <- fmt.Errorf("unable to obtain chan point: %w",
err)
return
}
pendingReservation.partialState.FundingOutpoint = *chanPoint
// The user has signaled that they want to use a PSBT to construct the
// funding transaction. Because we now have the multisig keys from both
// parties, we can create the multisig script that needs to be funded
// and then pause the process until the user supplies the PSBT
// containing the eventual funding transaction.
case *chanfunding.PsbtIntent:
if fundingIntent.PendingPsbt != nil {
req.err <- fmt.Errorf("PSBT funding already in" +
"progress")
return
}
// Now that we know our contribution, we can bind both the local
// and remote key, which will be needed to calculate the multisig
// funding output in the next step.
pendingChanID := pendingReservation.pendingChanID
walletLog.Debugf("Advancing PSBT funding flow for "+
"pending_id(%x), binding keys local_key=%v, "+
"remote_key=%x", pendingChanID,
&ourContribution.MultiSigKey,
theirContribution.MultiSigKey.PubKey.SerializeCompressed())
fundingIntent.BindKeys(
&ourContribution.MultiSigKey,
theirContribution.MultiSigKey.PubKey,
)
// We might have a tapscript root, so we'll bind that now to
// ensure we make the proper funding output.
fundingIntent.BindTapscriptRoot(
pendingReservation.partialState.TapscriptRoot,
)
// Exit early because we can't continue the funding flow yet.
req.err <- &PsbtFundingRequired{
Intent: fundingIntent,
}
return
case *chanfunding.FullIntent:
// Now that we know their public key, we can bind theirs as
// well as ours to the funding intent.
fundingIntent.BindKeys(
&pendingReservation.ourContribution.MultiSigKey,
theirContribution.MultiSigKey.PubKey,
)
// With our keys bound, we can now construct+sign the final
// funding transaction and also obtain the chanPoint that
// creates the channel.
fundingTx, err := fundingIntent.CompileFundingTx(
theirContribution.Inputs,
theirContribution.ChangeOutputs,
)
if err != nil {
req.err <- fmt.Errorf("unable to construct funding "+
"tx: %v", err)
return
}
chanPoint, err = fundingIntent.ChanPoint()
if err != nil {
req.err <- fmt.Errorf("unable to obtain chan "+
"point: %v", err)
return
}
// Finally, we'll populate the relevant information in our
// pendingReservation so the rest of the funding flow can
// continue as normal.
pendingReservation.fundingTx = fundingTx
pendingReservation.partialState.FundingOutpoint = *chanPoint
pendingReservation.ourFundingInputScripts = make(
[]*input.Script, 0, len(ourContribution.Inputs),
)
for _, txIn := range fundingTx.TxIn {
_, err := l.FetchOutpointInfo(&txIn.PreviousOutPoint)
if err != nil {
continue
}
pendingReservation.ourFundingInputScripts = append(
pendingReservation.ourFundingInputScripts,
&input.Script{
Witness: txIn.Witness,
SigScript: txIn.SignatureScript,
},
)
}
walletLog.Tracef("Funding tx for ChannelPoint(%v) "+
"generated: %v", chanPoint, spew.Sdump(fundingTx))
}
// If we landed here and didn't exit early, it means we already have
// the channel point ready. We can jump directly to the next step.
l.handleChanPointReady(&continueContributionMsg{
pendingFundingID: req.pendingFundingID,
err: req.err,
})
}
// genMusigSession generates a new musig2 pair session that we can use to sign
// the commitment transaction for the remote party, and verify their incoming
// partial signature.
func genMusigSession(ourContribution, theirContribution *ChannelContribution,
signer input.MuSig2Signer, fundingOutput *wire.TxOut,
tapscriptRoot fn.Option[chainhash.Hash]) *MusigPairSession {
return NewMusigPairSession(&MusigSessionCfg{
LocalKey: ourContribution.MultiSigKey,
RemoteKey: theirContribution.MultiSigKey,
LocalNonce: *ourContribution.LocalNonce,
RemoteNonce: *theirContribution.LocalNonce,
Signer: signer,
InputTxOut: fundingOutput,
TapscriptTweak: tapscriptRoot,
})
}
// signCommitTx generates a valid input.Signature to send to the remote party
// for their version of the commitment transaction. For regular channels, this
// will be a normal ECDSA signature. For taproot channels, this will instead be
// a musig2 partial signature that also includes the nonce used to generate it.
func (l *LightningWallet) signCommitTx(pendingReservation *ChannelReservation,
commitTx *wire.MsgTx, fundingOutput *wire.TxOut,
fundingWitnessScript []byte) (input.Signature, error) {
ourContribution := pendingReservation.ourContribution
theirContribution := pendingReservation.theirContribution
var (
sigTheirCommit input.Signature
err error
)
switch {
// For regular channels, we can just send over a normal ECDSA signature
// w/o any extra steps.
case !pendingReservation.partialState.ChanType.IsTaproot():
ourKey := ourContribution.MultiSigKey
signDesc := input.SignDescriptor{
WitnessScript: fundingWitnessScript,
KeyDesc: ourKey,
Output: fundingOutput,
HashType: txscript.SigHashAll,
SigHashes: input.NewTxSigHashesV0Only(
commitTx,
),
InputIndex: 0,
}
sigTheirCommit, err = l.Cfg.Signer.SignOutputRaw(
commitTx, &signDesc,
)
if err != nil {
return nil, err
}
// If this is a taproot channel, then we'll need to create an initial
// musig2 session here as we'll be sending over a _partial_ signature.
default:
// We're now ready to sign the first commitment. However, we'll
// only create the session if that hasn't been done already.
if pendingReservation.musigSessions == nil {
musigSessions := genMusigSession(
ourContribution, theirContribution,
l.Cfg.Signer, fundingOutput,
pendingReservation.partialState.TapscriptRoot,
)
pendingReservation.musigSessions = musigSessions
}
// Now that we have the funding outpoint, we'll generate a
// musig2 signature for their version of the commitment
// transaction. We use the remote session as this is for the
// remote commitment transaction.
musigSessions := pendingReservation.musigSessions
partialSig, err := musigSessions.RemoteSession.SignCommit(
commitTx,
)
if err != nil {
return nil, fmt.Errorf("unable to sign "+
"commitment: %w", err)
}
sigTheirCommit = partialSig
}
return sigTheirCommit, nil
}
// handleChanPointReady continues the funding process once the channel point is
// known and the funding transaction can be completed.
func (l *LightningWallet) handleChanPointReady(req *continueContributionMsg) {
l.limboMtx.Lock()
pendingReservation, ok := l.fundingLimbo[req.pendingFundingID]
l.limboMtx.Unlock()
if !ok {
req.err <- fmt.Errorf("attempted to update non-existent " +
"funding state")
return
}
chanState := pendingReservation.partialState
// If we have an aux funding desc, then we can use it to populate some
// of the optional, but opaque TLV blobs we'll carry for the channel.
chanState.CustomBlob = fn.MapOption(func(desc AuxFundingDesc) tlv.Blob {
return desc.CustomFundingBlob
})(req.auxFundingDesc)
chanState.LocalCommitment.CustomBlob = fn.MapOption(
func(desc AuxFundingDesc) tlv.Blob {
return desc.CustomLocalCommitBlob
},
)(req.auxFundingDesc)
chanState.RemoteCommitment.CustomBlob = fn.MapOption(
func(desc AuxFundingDesc) tlv.Blob {
return desc.CustomRemoteCommitBlob
},
)(req.auxFundingDesc)
ourContribution := pendingReservation.ourContribution
theirContribution := pendingReservation.theirContribution
chanPoint := pendingReservation.partialState.FundingOutpoint
// If we're in the PSBT funding flow, we now should have everything that
// is needed to construct and publish the full funding transaction.
intent := pendingReservation.fundingIntent
if psbtIntent, ok := intent.(*chanfunding.PsbtIntent); ok {
// With our keys bound, we can now construct and possibly sign
// the final funding transaction and also obtain the chanPoint
// that creates the channel. We _have_ to call CompileFundingTx
// even if we don't publish ourselves as that sets the actual
// funding outpoint in stone for this channel.
fundingTx, err := psbtIntent.CompileFundingTx()
if err != nil {
req.err <- fmt.Errorf("unable to construct funding "+
"tx: %v", err)
return
}
chanPointPtr, err := psbtIntent.ChanPoint()
if err != nil {
req.err <- fmt.Errorf("unable to obtain chan "+
"point: %v", err)
return
}
pendingReservation.partialState.FundingOutpoint = *chanPointPtr
chanPoint = *chanPointPtr
// Finally, we'll populate the relevant information in our
// pendingReservation so the rest of the funding flow can
// continue as normal in case we are going to publish ourselves.
if psbtIntent.ShouldPublishFundingTX() {
pendingReservation.fundingTx = fundingTx
pendingReservation.ourFundingInputScripts = make(
[]*input.Script, 0, len(ourContribution.Inputs),
)
for _, txIn := range fundingTx.TxIn {
pendingReservation.ourFundingInputScripts = append(
pendingReservation.ourFundingInputScripts,
&input.Script{
Witness: txIn.Witness,
SigScript: txIn.SignatureScript,
},
)
}
}
}
// Initialize an empty sha-chain for them, tracking the current pending
// revocation hash (we don't yet know the preimage so we can't add it
// to the chain).
s := shachain.NewRevocationStore()
pendingReservation.partialState.RevocationStore = s
// Store their current commitment point. We'll need this after the
// first state transition in order to verify the authenticity of the
// revocation.
chanState.RemoteCurrentRevocation = theirContribution.FirstCommitmentPoint
// Create the txin to our commitment transaction; required to construct
// the commitment transactions.
fundingTxIn := wire.TxIn{
PreviousOutPoint: chanPoint,
}
// With the funding tx complete, create both commitment transactions.
localBalance := pendingReservation.partialState.LocalCommitment.LocalBalance.ToSatoshis()
remoteBalance := pendingReservation.partialState.LocalCommitment.RemoteBalance.ToSatoshis()
var leaseExpiry uint32
if pendingReservation.partialState.ChanType.HasLeaseExpiration() {
leaseExpiry = pendingReservation.partialState.ThawHeight
}
localAuxLeaves := fn.MapOption(
func(desc AuxFundingDesc) CommitAuxLeaves {
return desc.LocalInitAuxLeaves
},
)(req.auxFundingDesc)
remoteAuxLeaves := fn.MapOption(
func(desc AuxFundingDesc) CommitAuxLeaves {
return desc.RemoteInitAuxLeaves
},
)(req.auxFundingDesc)
ourCommitTx, theirCommitTx, err := CreateCommitmentTxns(
localBalance, remoteBalance, ourContribution.ChannelConfig,
theirContribution.ChannelConfig,
ourContribution.FirstCommitmentPoint,
theirContribution.FirstCommitmentPoint, fundingTxIn,
pendingReservation.partialState.ChanType,
pendingReservation.partialState.IsInitiator, leaseExpiry,
WithAuxLeaves(localAuxLeaves, remoteAuxLeaves),
)
if err != nil {
req.err <- err
return
}
// With both commitment transactions constructed, generate the state
// obfuscator then use it to encode the current state number within
// both commitment transactions.
var stateObfuscator [StateHintSize]byte
if chanState.ChanType.IsSingleFunder() {
stateObfuscator = DeriveStateHintObfuscator(
ourContribution.PaymentBasePoint.PubKey,
theirContribution.PaymentBasePoint.PubKey,
)
} else {
ourSer := ourContribution.PaymentBasePoint.PubKey.SerializeCompressed()
theirSer := theirContribution.PaymentBasePoint.PubKey.SerializeCompressed()
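// For dual-funded channels, order the keys lexicographically by their
// compressed serializations so that both parties derive the same
// obfuscator.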
switch bytes.Compare(ourSer, theirSer) {
case -1:
stateObfuscator = DeriveStateHintObfuscator(
ourContribution.PaymentBasePoint.PubKey,
theirContribution.PaymentBasePoint.PubKey,
)
default:
stateObfuscator = DeriveStateHintObfuscator(
theirContribution.PaymentBasePoint.PubKey,
ourContribution.PaymentBasePoint.PubKey,
)
}
}
err = initStateHints(ourCommitTx, theirCommitTx, stateObfuscator)
if err != nil {
req.err <- err
return
}
// Sort both transactions according to the agreed upon canonical
// ordering. This lets us skip sending the entire transaction over;
// instead, we'll just send signatures.
txsort.InPlaceSort(ourCommitTx)
txsort.InPlaceSort(theirCommitTx)
walletLog.Tracef("Local commit tx for ChannelPoint(%v): %v",
chanPoint, spew.Sdump(ourCommitTx))
walletLog.Tracef("Remote commit tx for ChannelPoint(%v): %v",
chanPoint, spew.Sdump(theirCommitTx))
// Record newly available information within the open channel state.
chanState.FundingOutpoint = chanPoint
chanState.LocalCommitment.CommitTx = ourCommitTx
chanState.RemoteCommitment.CommitTx = theirCommitTx
// Next, we'll obtain the funding witness script, and the funding
// output itself so we can generate a valid signature for the remote
// party.
fundingIntent := pendingReservation.fundingIntent
fundingWitnessScript, fundingOutput, err := fundingIntent.FundingOutput()
if err != nil {
req.err <- fmt.Errorf("unable to obtain funding "+
"output: %w", err)
return
}
// Generate a signature for their version of the initial commitment
// transaction.
sigTheirCommit, err := l.signCommitTx(
pendingReservation, theirCommitTx, fundingOutput,
fundingWitnessScript,
)
if err != nil {
req.err <- err
return
}
pendingReservation.ourCommitmentSig = sigTheirCommit
req.err <- nil
}
// handleSingleContribution is called as the second step to a single funder
// workflow to which we are the responder. It simply saves the remote peer's
// contribution to the channel, as the remote peer alone will be contributing
// funds to the channel.
func (l *LightningWallet) handleSingleContribution(req *addSingleContributionMsg) {
l.limboMtx.Lock()
pendingReservation, ok := l.fundingLimbo[req.pendingFundingID]
l.limboMtx.Unlock()
if !ok {
req.err <- fmt.Errorf("attempted to update non-existent funding state")
return
}
// Grab the mutex on the channelReservation to ensure thread-safety.
pendingReservation.Lock()
defer pendingReservation.Unlock()
// Validate that the remote's UpfrontShutdownScript is a valid script
// if it's set.
shutdown := req.contribution.UpfrontShutdown
if len(shutdown) > 0 {
// Validate the shutdown script.
if !ValidateUpfrontShutdown(shutdown, &l.Cfg.NetParams) {
req.err <- fmt.Errorf("invalid shutdown script")
return
}
}
// Simply record the counterparty's contribution into the pending
// reservation data, as they'll be solely funding the channel.
pendingReservation.theirContribution = req.contribution
theirContribution := pendingReservation.theirContribution
chanState := pendingReservation.partialState
// Perform bounds checking on both ChannelReserve and DustLimit
// parameters. The ChannelReserve may have been changed by the
// ChannelAcceptor RPC, so this is necessary.
if !pendingReservation.validateReserveBounds() {
req.err <- fmt.Errorf("invalid reserve and dust bounds")
return
}
// Initialize an empty sha-chain for them, tracking the current pending
// revocation hash (we don't yet know the preimage so we can't add it
// to the chain).
remotePreimageStore := shachain.NewRevocationStore()
chanState.RevocationStore = remotePreimageStore
// Now that we've received their first commitment point, we'll store it
// within the channel state so we can sync it to disk once the funding
// process is complete.
chanState.RemoteCurrentRevocation = theirContribution.FirstCommitmentPoint
req.err <- nil
}
// verifyFundingInputs attempts to verify all remote inputs to the funding
// transaction.
func (l *LightningWallet) verifyFundingInputs(fundingTx *wire.MsgTx,
remoteInputScripts []*input.Script) error {
sigIndex := 0
fundingHashCache := input.NewTxSigHashesV0Only(fundingTx)
inputScripts := remoteInputScripts
for i, txin := range fundingTx.TxIn {
if len(inputScripts) != 0 && len(txin.Witness) == 0 {
// Attach the input scripts so we can verify it below.
txin.Witness = inputScripts[sigIndex].Witness
txin.SignatureScript = inputScripts[sigIndex].SigScript
// Fetch the alleged previous output along with the
// pkscript referenced by this input.
//
// TODO(roasbeef): when dual funder pass actual
// height-hint
//
// TODO(roasbeef): this fails for neutrino always as it
// treats the height hint as an exact birthday of the
// utxo rather than a lower bound
pkScript, err := txscript.ComputePkScript(
txin.SignatureScript, txin.Witness,
)
if err != nil {
return fmt.Errorf("cannot create script: %w",
err)
}
output, err := l.Cfg.ChainIO.GetUtxo(
&txin.PreviousOutPoint,
pkScript.Script(), 0, l.quit,
)
if output == nil {
return fmt.Errorf("input to funding tx does "+
"not exist: %v", err)
}
// Ensure that the witness+sigScript combo is valid.
vm, err := txscript.NewEngine(
output.PkScript, fundingTx, i,
txscript.StandardVerifyFlags, nil,
fundingHashCache, output.Value,
txscript.NewCannedPrevOutputFetcher(
output.PkScript, output.Value,
),
)
if err != nil {
return fmt.Errorf("cannot create script "+
"engine: %s", err)
}
if err = vm.Execute(); err != nil {
return fmt.Errorf("cannot validate "+
"transaction: %s", err)
}
sigIndex++
}
}
return nil
}
// verifyCommitSig verifies an incoming signature for our version of the
// commitment transaction. For normal channels, this will verify that the ECDSA
// signature is valid. For taproot channels, we'll verify that their partial
// signature is valid, so it can properly be combined with our eventual
// signature when we need to broadcast.
func (l *LightningWallet) verifyCommitSig(res *ChannelReservation,
commitSig input.Signature, commitTx *wire.MsgTx) error {
localKey := res.ourContribution.MultiSigKey.PubKey
remoteKey := res.theirContribution.MultiSigKey.PubKey
channelValue := int64(res.partialState.Capacity)
switch {
// If this isn't a taproot channel, then we'll construct a segwit v0
// p2wsh sighash.
case !res.partialState.ChanType.IsTaproot():
hashCache := input.NewTxSigHashesV0Only(commitTx)
witnessScript, _, err := input.GenFundingPkScript(
localKey.SerializeCompressed(),
remoteKey.SerializeCompressed(), channelValue,
)
if err != nil {
return err
}
sigHash, err := txscript.CalcWitnessSigHash(
witnessScript, hashCache, txscript.SigHashAll,
commitTx, 0, channelValue,
)
if err != nil {
return err
}
// Verify that we've received a valid signature from the remote
// party for our version of the commitment transaction.
if !commitSig.Verify(sigHash, remoteKey) {
return fmt.Errorf("counterparty's commitment " +
"signature is invalid")
}
return nil
// Otherwise for taproot channels, we'll compute the segwit v1 sighash,
// which is slightly different.
default:
// First, check to see if we've generated the musig session
// already. If we're the responder in the funding flow, we may
// not have generated it already.
if res.musigSessions == nil {
_, fundingOutput, err := input.GenTaprootFundingScript(
localKey, remoteKey, channelValue,
res.partialState.TapscriptRoot,
)
if err != nil {
return err
}
res.musigSessions = genMusigSession(
res.ourContribution, res.theirContribution,
l.Cfg.Signer, fundingOutput,
res.partialState.TapscriptRoot,
)
}
// For the musig2 based channels, we'll use the generated local
// musig2 session to verify the signature.
localSession := res.musigSessions.LocalSession
// At this point, the commitment signature passed in should
// actually be a wrapped musig2 signature, so we'll do a type
// assertion to get the signature we actually need.
partialSig, ok := commitSig.(*MusigPartialSig)
if !ok {
return fmt.Errorf("expected *musig2.PartialSignature, "+
"got: %T", commitSig)
}
_, err := localSession.VerifyCommitSig(
commitTx, partialSig.ToWireSig(),
)
return err
}
}
// handleFundingCounterPartySigs is the final step in the channel reservation
// workflow. During this step, we validate *all* the received signatures for
// inputs to the funding transaction. If any of these are invalid, we bail,
// and forcibly cancel this funding request. Additionally, we ensure that the
// signature we received from the counterparty for our version of the commitment
// transaction allows us to spend from the funding output with the addition of
// our signature.
func (l *LightningWallet) handleFundingCounterPartySigs(msg *addCounterPartySigsMsg) {
l.limboMtx.RLock()
res, ok := l.fundingLimbo[msg.pendingFundingID]
l.limboMtx.RUnlock()
if !ok {
msg.err <- fmt.Errorf("attempted to update non-existent funding state")
return
}
// Grab the mutex on the ChannelReservation to ensure thread-safety
res.Lock()
defer res.Unlock()
// Now we can complete the funding transaction by adding their
// signatures to their inputs.
res.theirFundingInputScripts = msg.theirFundingInputScripts
inputScripts := msg.theirFundingInputScripts
// Only if we have the final funding transaction do we need to verify
// the final set of inputs. Otherwise, it may be the case that the
// channel was funded via an external wallet.
fundingTx := res.fundingTx
if res.partialState.ChanType.HasFundingTx() {
err := l.verifyFundingInputs(fundingTx, inputScripts)
if err != nil {
msg.err <- err
msg.completeChan <- nil
return
}
}
// At this point, we can also record and verify their signature for our
// commitment transaction.
res.theirCommitmentSig = msg.theirCommitmentSig
commitTx := res.partialState.LocalCommitment.CommitTx
err := l.verifyCommitSig(res, msg.theirCommitmentSig, commitTx)
if err != nil {
msg.err <- fmt.Errorf("counterparty's commitment signature is "+
"invalid: %w", err)
msg.completeChan <- nil
return
}
theirCommitSigBytes := msg.theirCommitmentSig.Serialize()
res.partialState.LocalCommitment.CommitSig = theirCommitSigBytes
// Funding complete, this entry can be removed from limbo.
l.limboMtx.Lock()
delete(l.fundingLimbo, res.reservationID)
delete(l.reservationIDs, res.pendingChanID)
l.limboMtx.Unlock()
l.intentMtx.Lock()
delete(l.fundingIntents, res.pendingChanID)
l.intentMtx.Unlock()
// As we're about to broadcast the funding transaction, we'll take note
// of the current height for record keeping purposes.
_, bestHeight, err := l.Cfg.ChainIO.GetBestBlock()
if err != nil {
msg.err <- err
msg.completeChan <- nil
return
}
// As we've completed the funding process, we'll now convert the
// contribution structs into their underlying channel config objects
// to be stored within the database.
res.partialState.LocalChanCfg = res.ourContribution.toChanConfig()
res.partialState.RemoteChanCfg = res.theirContribution.toChanConfig()
// We'll also record the finalized funding txn, which will allow us to
// rebroadcast on startup in case we fail.
res.partialState.FundingTxn = fundingTx
// Set optional upfront shutdown scripts on the channel state so that they
// are persisted. These values may be nil.
res.partialState.LocalShutdownScript =
res.ourContribution.UpfrontShutdown
res.partialState.RemoteShutdownScript =
res.theirContribution.UpfrontShutdown
res.partialState.RevocationKeyLocator = res.nextRevocationKeyLoc
// Add the complete funding transaction to the DB, in its open bucket
// which will be used for the lifetime of this channel.
nodeAddr := res.nodeAddr
err = res.partialState.SyncPending(nodeAddr, uint32(bestHeight))
if err != nil {
msg.err <- err
msg.completeChan <- nil
return
}
msg.completeChan <- res.partialState
msg.err <- nil
}
// handleSingleFunderSigs is called once the remote peer who initiated the
// single funder workflow has assembled the funding transaction, and generated
// a signature for our version of the commitment transaction. This method
// progresses the workflow by generating a signature for the remote peer's
// version of the commitment transaction.
func (l *LightningWallet) handleSingleFunderSigs(req *addSingleFunderSigsMsg) {
l.limboMtx.RLock()
pendingReservation, ok := l.fundingLimbo[req.pendingFundingID]
l.limboMtx.RUnlock()
if !ok {
req.err <- fmt.Errorf("attempted to update non-existent funding state")
req.completeChan <- nil
return
}
// Grab the mutex on the ChannelReservation to ensure thread-safety
pendingReservation.Lock()
defer pendingReservation.Unlock()
chanState := pendingReservation.partialState
// If we have an aux funding desc, then we can use it to populate some
// of the optional, but opaque TLV blobs we'll carry for the channel.
chanState.CustomBlob = fn.MapOption(func(desc AuxFundingDesc) tlv.Blob {
return desc.CustomFundingBlob
})(req.auxFundingDesc)
chanState.LocalCommitment.CustomBlob = fn.MapOption(
func(desc AuxFundingDesc) tlv.Blob {
return desc.CustomLocalCommitBlob
},
)(req.auxFundingDesc)
chanState.RemoteCommitment.CustomBlob = fn.MapOption(
func(desc AuxFundingDesc) tlv.Blob {
return desc.CustomRemoteCommitBlob
},
)(req.auxFundingDesc)
chanType := pendingReservation.partialState.ChanType
chanState.FundingOutpoint = *req.fundingOutpoint
fundingTxIn := wire.NewTxIn(req.fundingOutpoint, nil, nil)
// Now that we have the funding outpoint, we can generate both versions
// of the commitment transaction, and generate a signature for the
// remote node's commitment transactions.
localBalance := pendingReservation.partialState.LocalCommitment.LocalBalance.ToSatoshis()
remoteBalance := pendingReservation.partialState.LocalCommitment.RemoteBalance.ToSatoshis()
var leaseExpiry uint32
if pendingReservation.partialState.ChanType.HasLeaseExpiration() {
leaseExpiry = pendingReservation.partialState.ThawHeight
}
localAuxLeaves := fn.MapOption(
func(desc AuxFundingDesc) CommitAuxLeaves {
return desc.LocalInitAuxLeaves
},
)(req.auxFundingDesc)
remoteAuxLeaves := fn.MapOption(
func(desc AuxFundingDesc) CommitAuxLeaves {
return desc.RemoteInitAuxLeaves
},
)(req.auxFundingDesc)
ourCommitTx, theirCommitTx, err := CreateCommitmentTxns(
localBalance, remoteBalance,
pendingReservation.ourContribution.ChannelConfig,
pendingReservation.theirContribution.ChannelConfig,
pendingReservation.ourContribution.FirstCommitmentPoint,
pendingReservation.theirContribution.FirstCommitmentPoint,
*fundingTxIn, chanType,
pendingReservation.partialState.IsInitiator, leaseExpiry,
WithAuxLeaves(localAuxLeaves, remoteAuxLeaves),
)
if err != nil {
req.err <- err
req.completeChan <- nil
return
}
// With both commitment transactions constructed, we can now use the
// generated state obfuscator to encode the current state number within
// both commitment transactions.
stateObfuscator := DeriveStateHintObfuscator(
pendingReservation.theirContribution.PaymentBasePoint.PubKey,
pendingReservation.ourContribution.PaymentBasePoint.PubKey,
)
err = initStateHints(ourCommitTx, theirCommitTx, stateObfuscator)
if err != nil {
req.err <- err
req.completeChan <- nil
return
}
// Sort both transactions according to the agreed upon canonical
// ordering. This ensures that both parties sign the same sighash
// without further synchronization.
txsort.InPlaceSort(ourCommitTx)
txsort.InPlaceSort(theirCommitTx)
chanState.LocalCommitment.CommitTx = ourCommitTx
chanState.RemoteCommitment.CommitTx = theirCommitTx
walletLog.Debugf("Local commit tx for ChannelPoint(%v): %v",
req.fundingOutpoint, spew.Sdump(ourCommitTx))
walletLog.Debugf("Remote commit tx for ChannelPoint(%v): %v",
req.fundingOutpoint, spew.Sdump(theirCommitTx))
// With both commitment transactions created, we'll now verify their
// signature on our commitment.
err = l.verifyCommitSig(
pendingReservation, req.theirCommitmentSig, ourCommitTx,
)
if err != nil {
req.err <- err
req.completeChan <- nil
return
}
theirCommitSigBytes := req.theirCommitmentSig.Serialize()
chanState.LocalCommitment.CommitSig = theirCommitSigBytes
channelValue := int64(pendingReservation.partialState.Capacity)
theirKey := pendingReservation.theirContribution.MultiSigKey
ourKey := pendingReservation.ourContribution.MultiSigKey
var (
fundingWitnessScript []byte
fundingTxOut *wire.TxOut
)
if chanType.IsTaproot() {
//nolint:ll
fundingWitnessScript, fundingTxOut, err = input.GenTaprootFundingScript(
ourKey.PubKey, theirKey.PubKey, channelValue,
pendingReservation.partialState.TapscriptRoot,
)
} else {
//nolint:ll
fundingWitnessScript, fundingTxOut, err = input.GenFundingPkScript(
ourKey.PubKey.SerializeCompressed(),
theirKey.PubKey.SerializeCompressed(), channelValue,
)
}
if err != nil {
req.err <- err
req.completeChan <- nil
return
}
// With their signature for our version of the commitment transactions
// verified, we can now generate a signature for their version,
// allowing the funding transaction to be safely broadcast.
sigTheirCommit, err := l.signCommitTx(
pendingReservation, theirCommitTx, fundingTxOut,
fundingWitnessScript,
)
if err != nil {
req.err <- err
req.completeChan <- nil
return
}
pendingReservation.ourCommitmentSig = sigTheirCommit
_, bestHeight, err := l.Cfg.ChainIO.GetBestBlock()
if err != nil {
req.err <- err
req.completeChan <- nil
return
}
// Set optional upfront shutdown scripts on the channel state so that they
// are persisted. These values may be nil.
chanState.LocalShutdownScript =
pendingReservation.ourContribution.UpfrontShutdown
chanState.RemoteShutdownScript =
pendingReservation.theirContribution.UpfrontShutdown
// Add the complete funding transaction to the DB, in its open bucket
// which will be used for the lifetime of this channel.
chanState.LocalChanCfg = pendingReservation.ourContribution.toChanConfig()
chanState.RemoteChanCfg = pendingReservation.theirContribution.toChanConfig()
chanState.RevocationKeyLocator = pendingReservation.nextRevocationKeyLoc
err = chanState.SyncPending(pendingReservation.nodeAddr, uint32(bestHeight))
if err != nil {
req.err <- err
req.completeChan <- nil
return
}
req.completeChan <- chanState
req.err <- nil
l.limboMtx.Lock()
delete(l.fundingLimbo, req.pendingFundingID)
delete(l.reservationIDs, pendingReservation.pendingChanID)
l.limboMtx.Unlock()
l.intentMtx.Lock()
delete(l.fundingIntents, pendingReservation.pendingChanID)
l.intentMtx.Unlock()
}
// WithCoinSelectLock will execute the passed function closure in a
// synchronized manner preventing any coin selection operations from proceeding
// while the closure is executing. This can be seen as the ability to execute a
// function closure under an exclusive coin selection lock.
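//
// A minimal usage sketch (selectCoins is a hypothetical caller-supplied
// closure):
//
//	err := l.WithCoinSelectLock(func() error {
//		// No other coin selection operation can proceed until this
//		// closure returns.
//		return selectCoins()
//	})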
func (l *LightningWallet) WithCoinSelectLock(f func() error) error {
l.coinSelectMtx.Lock()
defer l.coinSelectMtx.Unlock()
return f()
}
// DeriveStateHintObfuscator derives the bytes to be used for obfuscating the
// state hints from the root to be used for a new channel. The obfuscator is
// generated via the following computation:
//
// - sha256(initiatorKey || responderKey)[26:]
// -- where both keys are the multi-sig keys of the respective parties
//
// The last 6 bytes of the resulting hash are used as the state hint.
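//
// As an illustrative sketch (initiatorKey and responderKey are hypothetical
// *btcec.PublicKey values), deriving and applying the obfuscator looks like:
//
//	obfuscator := DeriveStateHintObfuscator(initiatorKey, responderKey)
//	if err := initStateHints(ourCommit, theirCommit, obfuscator); err != nil {
//		// Handle the error.
//	}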
func DeriveStateHintObfuscator(key1, key2 *btcec.PublicKey) [StateHintSize]byte {
h := sha256.New()
h.Write(key1.SerializeCompressed())
h.Write(key2.SerializeCompressed())
sha := h.Sum(nil)
var obfuscator [StateHintSize]byte
copy(obfuscator[:], sha[26:])
return obfuscator
}
// initStateHints properly sets the obfuscated state hints on both commitment
// transactions using the passed obfuscator.
func initStateHints(commit1, commit2 *wire.MsgTx,
obfuscator [StateHintSize]byte) error {
if err := SetStateNumHint(commit1, 0, obfuscator); err != nil {
return err
}
if err := SetStateNumHint(commit2, 0, obfuscator); err != nil {
return err
}
return nil
}
// ValidateChannel will attempt to fully validate a newly mined channel, given
// its funding transaction and existing channel state. If this method returns
// an error, then the mined channel is invalid, and shouldn't be used.
func (l *LightningWallet) ValidateChannel(channelState *channeldb.OpenChannel,
fundingTx *wire.MsgTx) error {
var chanOpts []ChannelOpt
l.Cfg.AuxLeafStore.WhenSome(func(s AuxLeafStore) {
chanOpts = append(chanOpts, WithLeafStore(s))
})
l.Cfg.AuxSigner.WhenSome(func(s AuxSigner) {
chanOpts = append(chanOpts, WithAuxSigner(s))
})
// First, we'll obtain a fully signed commitment transaction so we can
// pass it to the chanvalidate package for verification.
channel, err := NewLightningChannel(
l.Cfg.Signer, channelState, nil, chanOpts...,
)
if err != nil {
return err
}
localKey := channelState.LocalChanCfg.MultiSigKey.PubKey
remoteKey := channelState.RemoteChanCfg.MultiSigKey.PubKey
// We'll also need the multi-sig witness script itself so the
// chanvalidate package can check it for correctness against the
// funding transaction, and also commitment validity.
var fundingScript []byte
if channelState.ChanType.IsTaproot() {
fundingScript, _, err = input.GenTaprootFundingScript(
localKey, remoteKey, int64(channel.Capacity),
channelState.TapscriptRoot,
)
if err != nil {
return err
}
} else {
witnessScript, err := input.GenMultiSigScript(
localKey.SerializeCompressed(),
remoteKey.SerializeCompressed(),
)
if err != nil {
return err
}
fundingScript, err = input.WitnessScriptHash(witnessScript)
if err != nil {
return err
}
}
signedCommitTx, err := channel.getSignedCommitTx()
if err != nil {
return err
}
commitCtx := &chanvalidate.CommitmentContext{
Value: channel.Capacity,
FullySignedCommitTx: signedCommitTx,
}
// Finally, we'll pass in all the necessary context needed to fully
// validate that this channel is indeed what we expect, and can be
// used.
_, err = chanvalidate.Validate(&chanvalidate.Context{
Locator: &chanvalidate.OutPointChanLocator{
ChanPoint: channelState.FundingOutpoint,
},
MultiSigPkScript: fundingScript,
FundingTx: fundingTx,
CommitCtx: commitCtx,
})
if err != nil {
return err
}
return nil
}
// CancelRebroadcast cancels the rebroadcast of the given transaction.
func (l *LightningWallet) CancelRebroadcast(txid chainhash.Hash) {
// For neutrino, we don't configure the rebroadcaster for the wallet,
// as neutrino manages the rebroadcasting logic itself.
if l.Cfg.Rebroadcaster != nil {
l.Cfg.Rebroadcaster.MarkAsConfirmed(txid)
}
}
// CoinSource is a wrapper around the wallet that implements the
// chanfunding.CoinSource interface.
type CoinSource struct {
wallet *LightningWallet
allowUtxo func(Utxo) bool
}
// NewCoinSource creates a new instance of the CoinSource wrapper struct.
func NewCoinSource(w *LightningWallet, allowUtxo func(Utxo) bool) *CoinSource {
return &CoinSource{
wallet: w,
allowUtxo: allowUtxo,
}
}
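// A usage sketch (w is an initialized *LightningWallet; the Confirmations
// field on Utxo is assumed here for illustration):
//
//	src := NewCoinSource(w, func(u Utxo) bool {
//		// Only permit coins with at least one confirmation.
//		return u.Confirmations >= 1
//	})
//	coins, err := src.ListCoins(1, math.MaxInt32)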
// ListCoins returns all UTXOs from the source that have between
// minConfs and maxConfs number of confirmations.
func (c *CoinSource) ListCoins(minConfs int32,
maxConfs int32) ([]wallet.Coin, error) {
utxos, err := c.wallet.ListUnspentWitnessFromDefaultAccount(
minConfs, maxConfs,
)
if err != nil {
return nil, err
}
var coins []wallet.Coin
for _, utxo := range utxos {
// If a filter function is supplied, all utxos not adhering to its
// conditions will be discarded.
if c.allowUtxo != nil && !c.allowUtxo(*utxo) {
walletLog.Infof("Not using utxo=%v, it was "+
"rejected by the coin selection filter",
utxo.OutPoint)
continue
}
coins = append(coins, wallet.Coin{
TxOut: wire.TxOut{
Value: int64(utxo.Value),
PkScript: utxo.PkScript,
},
OutPoint: utxo.OutPoint,
})
}
return coins, nil
}
// CoinFromOutPoint attempts to locate details pertaining to a coin based on
// its outpoint. If the coin isn't under the control of the backing CoinSource,
// then an error should be returned.
func (c *CoinSource) CoinFromOutPoint(op wire.OutPoint) (*wallet.Coin, error) {
inputInfo, err := c.wallet.FetchOutpointInfo(&op)
if err != nil {
return nil, err
}
return &wallet.Coin{
TxOut: wire.TxOut{
Value: int64(inputInfo.Value),
PkScript: inputInfo.PkScript,
},
OutPoint: inputInfo.OutPoint,
}, nil
}
// shimKeyRing is a wrapper struct that's used to provide the proper multi-sig
// key for an initiated external funding flow.
type shimKeyRing struct {
keychain.KeyRing
*chanfunding.ShimIntent
}
// DeriveNextKey intercepts the normal DeriveNextKey call to a keychain.KeyRing
// instance, and supplies the multi-sig key specified by the ShimIntent. This
// allows us to transparently insert new keys into the existing funding flow,
// as these keys may not come from the wallet itself.
func (s *shimKeyRing) DeriveNextKey(keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
if keyFam != keychain.KeyFamilyMultiSig {
return s.KeyRing.DeriveNextKey(keyFam)
}
fundingKeys, err := s.ShimIntent.MultiSigKeys()
if err != nil {
return keychain.KeyDescriptor{}, err
}
return *fundingKeys.LocalKey, nil
}
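// A construction sketch (baseRing is an existing keychain.KeyRing and intent
// a *chanfunding.ShimIntent): only multi-sig derivations are intercepted,
// while everything else falls through to the wrapped key ring:
//
//	ring := &shimKeyRing{KeyRing: baseRing, ShimIntent: intent}
//	keyDesc, err := ring.DeriveNextKey(keychain.KeyFamilyMultiSig)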
// ValidateUpfrontShutdown checks whether the provided upfront_shutdown_script
// is of a valid type that we accept.
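//
// For example (sketch; keyHash is a hypothetical 20-byte pubkey hash), a
// P2WPKH script is accepted while a bare OP_RETURN script is not:
//
//	p2wpkh := append([]byte{0x00, 0x14}, keyHash...)
//	_ = ValidateUpfrontShutdown(p2wpkh, &chaincfg.MainNetParams)       // true
//	_ = ValidateUpfrontShutdown([]byte{0x6a}, &chaincfg.MainNetParams) // false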
func ValidateUpfrontShutdown(shutdown lnwire.DeliveryAddress,
params *chaincfg.Params) bool {
// We don't need to worry about a large UpfrontShutdownScript since it
// was already checked in lnwire when decoding from the wire.
scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs(shutdown, params)
switch {
case scriptClass == txscript.WitnessV0PubKeyHashTy,
scriptClass == txscript.WitnessV0ScriptHashTy,
scriptClass == txscript.WitnessV1TaprootTy:
// The above three types are permitted according to BOLT#02 and
// BOLT#05. Everything else is disallowed.
return true
// In this case, we don't know the actual script template, but it
// might still be a valid witness program with version 1-16, so we'll
// check that now.
case txscript.IsWitnessProgram(shutdown):
version, _, err := txscript.ExtractWitnessProgramInfo(shutdown)
if err != nil {
walletLog.Warnf("unable to extract witness program "+
"version (script=%x): %v", shutdown, err)
return false
}
return version >= 1 && version <= 16
default:
return false
}
}
// WalletPrevOutputFetcher is a txscript.PrevOutputFetcher that can fetch
// outputs from a given wallet controller.
type WalletPrevOutputFetcher struct {
wc WalletController
}
// A compile time assertion that WalletPrevOutputFetcher implements the
// txscript.PrevOutputFetcher interface.
var _ txscript.PrevOutputFetcher = (*WalletPrevOutputFetcher)(nil)
// NewWalletPrevOutputFetcher creates a new WalletPrevOutputFetcher that fetches
// previous outputs from the given wallet controller.
func NewWalletPrevOutputFetcher(wc WalletController) *WalletPrevOutputFetcher {
return &WalletPrevOutputFetcher{
wc: wc,
}
}
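// A usage sketch (wc is an initialized WalletController): the fetcher can
// back a txscript sighash cache for inputs whose outpoints the wallet knows:
//
//	fetcher := NewWalletPrevOutputFetcher(wc)
//	hashCache := txscript.NewTxSigHashes(tx, fetcher)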
// FetchPrevOutput attempts to fetch the previous output referenced by the
// passed outpoint. A nil value will be returned if the passed outpoint doesn't
// exist.
func (w *WalletPrevOutputFetcher) FetchPrevOutput(op wire.OutPoint) *wire.TxOut {
utxo, err := w.wc.FetchOutpointInfo(&op)
if err != nil {
return nil
}
return &wire.TxOut{
Value: int64(utxo.Value),
PkScript: utxo.PkScript,
}
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/tlv"
)
// AcceptChannel is the message Bob sends to Alice after she initiates the
// single funder channel workflow via an OpenChannel message. Once Alice
// receives Bob's response, she has all the items necessary to construct the
// funding transaction and both commitment transactions.
type AcceptChannel struct {
// PendingChannelID serves to uniquely identify the future channel
// created by the initiated single funder workflow.
PendingChannelID [32]byte
// DustLimit is the specific dust limit the sender of this message
// would like enforced on their version of the commitment transaction.
// Any output below this value will be "trimmed" from the commitment
// transaction, with the trimmed amount effectively going to fees.
DustLimit btcutil.Amount
// MaxValueInFlight represents the maximum amount of coins that can be
// pending within the channel at any given time. If the amount of funds
// in limbo exceeds this amount, then the channel will be failed.
MaxValueInFlight MilliSatoshi
// ChannelReserve is the amount of BTC that the receiving party MUST
// maintain a balance above at all times. This is a safety mechanism to
// ensure that both sides always have skin in the game during the
// channel's lifetime.
ChannelReserve btcutil.Amount
// HtlcMinimum is the smallest HTLC that the sender of this message
// will accept.
HtlcMinimum MilliSatoshi
// MinAcceptDepth is the minimum number of confirmations that the
// initiator of the channel should wait for before considering the
// channel open.
MinAcceptDepth uint32
// CsvDelay is the number of blocks to use for the relative time lock
// in the pay-to-self output of both commitment transactions.
CsvDelay uint16
// MaxAcceptedHTLCs is the total number of incoming HTLCs that the
// sender of this message will accept.
//
// TODO(roasbeef): acks the initiator's, same with max in flight?
MaxAcceptedHTLCs uint16
// FundingKey is the key that should be used on behalf of the sender
// within the 2-of-2 multi-sig output that is contained within the
// funding transaction.
FundingKey *btcec.PublicKey
// RevocationPoint is the base revocation point for the sending party.
// Any commitment transaction belonging to the receiver of this message
// should use this key and their per-commitment point to derive the
// revocation key for the commitment transaction.
RevocationPoint *btcec.PublicKey
// PaymentPoint is the base payment point for the sending party. This
// key should be combined with the per commitment point for a
// particular commitment state in order to create the key that should
// be used in any output that pays directly to the sending party, and
// also within the HTLC covenant transactions.
PaymentPoint *btcec.PublicKey
// DelayedPaymentPoint is the delay point for the sending party. This
// key should be combined with the per commitment point to derive the
// keys that are used in outputs of the sender's commitment transaction
// where they claim funds.
DelayedPaymentPoint *btcec.PublicKey
// HtlcPoint is the base point used to derive the set of keys for this
// party that will be used within the HTLC public key scripts. This
// value is combined with the receiver's revocation base point in order
// to derive the keys that are used within HTLC scripts.
HtlcPoint *btcec.PublicKey
// FirstCommitmentPoint is the first commitment point for the sending
// party. This value should be combined with the receiver's revocation
// base point in order to derive the revocation keys that are placed
// within the commitment transaction of the sender.
FirstCommitmentPoint *btcec.PublicKey
// UpfrontShutdownScript is the script to which the channel funds should
// be paid when mutually closing the channel. This field is optional
// and has a length prefix, so a zero length will be written if it is
// not set, and its length followed by the script will be written if it
// is set.
UpfrontShutdownScript DeliveryAddress
// ChannelType is the explicit channel type the initiator wishes to
// open.
ChannelType *ChannelType
// LeaseExpiry represents the absolute expiration height of a channel
// lease. This is a custom TLV record that will only apply when a leased
// channel is being opened using the script enforced lease commitment
// type.
LeaseExpiry *LeaseExpiry
// LocalNonce is an optional field that transmits the
// local/verification nonce for a party. This nonce will be used to
// verify the very first commitment transaction signature.
// This will only be populated if the simple taproot channels type was
// negotiated.
LocalNonce OptMusig2NonceTLV
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
//
// NOTE: Since the upfront shutdown script MUST be present (though can
// be zero-length) if any TLV data is available, the script will be
// extracted and removed from this blob when decoding. ExtraData will
// contain all TLV records _except_ the DeliveryAddress record in that
// case.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure AcceptChannel implements the lnwire.Message
// interface.
var _ Message = (*AcceptChannel)(nil)
// A compile time check to ensure AcceptChannel implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*AcceptChannel)(nil)
// Encode serializes the target AcceptChannel into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
//
// This is part of the lnwire.Message interface.
func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := []tlv.RecordProducer{&a.UpfrontShutdownScript}
if a.ChannelType != nil {
recordProducers = append(recordProducers, a.ChannelType)
}
if a.LeaseExpiry != nil {
recordProducers = append(recordProducers, a.LeaseExpiry)
}
a.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, &localNonce)
})
err := EncodeMessageExtraData(&a.ExtraData, recordProducers...)
if err != nil {
return err
}
if err := WriteBytes(w, a.PendingChannelID[:]); err != nil {
return err
}
if err := WriteSatoshi(w, a.DustLimit); err != nil {
return err
}
if err := WriteMilliSatoshi(w, a.MaxValueInFlight); err != nil {
return err
}
if err := WriteSatoshi(w, a.ChannelReserve); err != nil {
return err
}
if err := WriteMilliSatoshi(w, a.HtlcMinimum); err != nil {
return err
}
if err := WriteUint32(w, a.MinAcceptDepth); err != nil {
return err
}
if err := WriteUint16(w, a.CsvDelay); err != nil {
return err
}
if err := WriteUint16(w, a.MaxAcceptedHTLCs); err != nil {
return err
}
if err := WritePublicKey(w, a.FundingKey); err != nil {
return err
}
if err := WritePublicKey(w, a.RevocationPoint); err != nil {
return err
}
if err := WritePublicKey(w, a.PaymentPoint); err != nil {
return err
}
if err := WritePublicKey(w, a.DelayedPaymentPoint); err != nil {
return err
}
if err := WritePublicKey(w, a.HtlcPoint); err != nil {
return err
}
if err := WritePublicKey(w, a.FirstCommitmentPoint); err != nil {
return err
}
return WriteBytes(w, a.ExtraData)
}
// Decode deserializes the serialized AcceptChannel stored in the passed
// io.Reader into the target AcceptChannel using the deserialization rules
// defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error {
// Read all the mandatory fields in the accept message.
err := ReadElements(r,
a.PendingChannelID[:],
&a.DustLimit,
&a.MaxValueInFlight,
&a.ChannelReserve,
&a.HtlcMinimum,
&a.MinAcceptDepth,
&a.CsvDelay,
&a.MaxAcceptedHTLCs,
&a.FundingKey,
&a.RevocationPoint,
&a.PaymentPoint,
&a.DelayedPaymentPoint,
&a.HtlcPoint,
&a.FirstCommitmentPoint,
)
if err != nil {
return err
}
// For backwards compatibility, the optional extra data blob for
// AcceptChannel must contain an entry for the upfront shutdown script.
// We'll read it out and attempt to parse it.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
// Next we'll parse out the set of known records, keeping the raw tlv
// bytes untouched to ensure we don't drop any bytes erroneously.
var (
chanType ChannelType
leaseExpiry LeaseExpiry
localNonce = a.LocalNonce.Zero()
)
typeMap, err := tlvRecords.ExtractRecords(
&a.UpfrontShutdownScript, &chanType, &leaseExpiry,
&localNonce,
)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[ChannelTypeRecordType]; ok && val == nil {
a.ChannelType = &chanType
}
if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil {
a.LeaseExpiry = &leaseExpiry
}
if val, ok := typeMap[a.LocalNonce.TlvType()]; ok && val == nil {
a.LocalNonce = tlv.SomeRecordT(localNonce)
}
a.ExtraData = tlvRecords
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as an AcceptChannel on the wire.
//
// This is part of the lnwire.Message interface.
func (a *AcceptChannel) MsgType() MessageType {
return MsgAcceptChannel
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (a *AcceptChannel) SerializedSize() (uint32, error) {
return MessageSerializedSize(a)
}
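// A minimal encode/decode round-trip sketch (msg is a populated
// *AcceptChannel; the protocol version argument is unused by this codec):
//
//	var buf bytes.Buffer
//	if err := msg.Encode(&buf, 0); err != nil {
//		// Handle the error.
//	}
//	var decoded AcceptChannel
//	if err := decoded.Decode(&buf, 0); err != nil {
//		// Handle the error.
//	}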
package lnwire
import (
"bytes"
"io"
)
// AnnounceSignatures1 is a direct message between two endpoints of a
// channel and serves as an opt-in mechanism to allow the announcement of
// the channel to the rest of the network. It contains the necessary
// signatures by the sender to construct the channel announcement message.
type AnnounceSignatures1 struct {
// ChannelID is the unique description of the funding transaction.
// The channel ID is better for users and debugging, while the short
// channel ID is used for a quick existence test of the particular utxo
// within the block chain, because it contains information about the
// block.
ChannelID ChannelID
// ShortChannelID is the unique description of the funding
// transaction. It is constructed with the most significant 3 bytes
// as the block height, the next 3 bytes indicating the transaction
// index within the block, and the least significant two bytes
// indicating the output index which pays to the channel.
ShortChannelID ShortChannelID
// NodeSignature is the signature which contains the signed channel
// announcement message. With this signature we prove that we possess
// the node pub key, creating the reference node_key -> bitcoin_key.
NodeSignature Sig
// BitcoinSignature is the signature which contains the signed node
// public key. With this signature we prove that we possess the
// bitcoin key, creating the reverse reference bitcoin_key -> node_key.
BitcoinSignature Sig
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData ExtraOpaqueData
}
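// As a sketch of the ShortChannelID packing described above: a channel in
// block 700_000, transaction index 1, output 0 serializes to the uint64
// (700_000 << 40) | (1 << 16) | 0.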
// A compile time check to ensure AnnounceSignatures1 implements the
// lnwire.Message interface.
var _ Message = (*AnnounceSignatures1)(nil)
// A compile time check to ensure AnnounceSignatures1 implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*AnnounceSignatures1)(nil)
// A compile time check to ensure AnnounceSignatures1 implements the
// lnwire.AnnounceSignatures interface.
var _ AnnounceSignatures = (*AnnounceSignatures1)(nil)
// Decode deserializes a serialized AnnounceSignatures1 stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures1) Decode(r io.Reader, _ uint32) error {
return ReadElements(r,
&a.ChannelID,
&a.ShortChannelID,
&a.NodeSignature,
&a.BitcoinSignature,
&a.ExtraOpaqueData,
)
}
// Encode serializes the target AnnounceSignatures1 into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures1) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, a.ChannelID); err != nil {
return err
}
if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
return err
}
if err := WriteSig(w, a.NodeSignature); err != nil {
return err
}
if err := WriteSig(w, a.BitcoinSignature); err != nil {
return err
}
return WriteBytes(w, a.ExtraOpaqueData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures1) MsgType() MessageType {
return MsgAnnounceSignatures
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (a *AnnounceSignatures1) SerializedSize() (uint32, error) {
return MessageSerializedSize(a)
}
// SCID returns the ShortChannelID of the channel.
//
// This is part of the lnwire.AnnounceSignatures interface.
func (a *AnnounceSignatures1) SCID() ShortChannelID {
return a.ShortChannelID
}
// ChanID returns the ChannelID identifying the channel.
//
// This is part of the lnwire.AnnounceSignatures interface.
func (a *AnnounceSignatures1) ChanID() ChannelID {
return a.ChannelID
}
package lnwire
import (
"bytes"
"io"
)
// AnnounceSignatures2 is a direct message between two endpoints of a
// channel and serves as an opt-in mechanism to allow the announcement of
// a taproot channel to the rest of the network. It contains the necessary
// signatures by the sender to construct the channel_announcement_2 message.
type AnnounceSignatures2 struct {
// ChannelID is the unique description of the funding transaction.
// The channel ID is better for users and debugging, while the short
// channel ID is used for a quick existence test of the particular utxo
// within the blockchain, because it contains information about the
// block.
ChannelID ChannelID
// ShortChannelID is the unique description of the funding transaction.
// It is constructed with the most significant 3 bytes as the block
// height, the next 3 bytes indicating the transaction index within the
// block, and the least significant two bytes indicating the output
// index which pays to the channel.
ShortChannelID ShortChannelID
// PartialSignature is the combination of the partial Schnorr signature
// created for the node's bitcoin key with the partial signature created
// for the node's node ID key.
PartialSignature PartialSig
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData ExtraOpaqueData
}
// A compile time check to ensure AnnounceSignatures2 implements the
// lnwire.Message interface.
var _ Message = (*AnnounceSignatures2)(nil)
// A compile time check to ensure AnnounceSignatures2 implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*AnnounceSignatures2)(nil)
// Decode deserializes a serialized AnnounceSignatures2 stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures2) Decode(r io.Reader, _ uint32) error {
return ReadElements(r,
&a.ChannelID,
&a.ShortChannelID,
&a.PartialSignature,
&a.ExtraOpaqueData,
)
}
// Encode serializes the target AnnounceSignatures2 into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures2) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, a.ChannelID); err != nil {
return err
}
if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
return err
}
if err := WriteElement(w, a.PartialSignature); err != nil {
return err
}
return WriteBytes(w, a.ExtraOpaqueData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures2) MsgType() MessageType {
return MsgAnnounceSignatures2
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (a *AnnounceSignatures2) SerializedSize() (uint32, error) {
return MessageSerializedSize(a)
}
// SCID returns the ShortChannelID of the channel.
//
// NOTE: this is part of the AnnounceSignatures interface.
func (a *AnnounceSignatures2) SCID() ShortChannelID {
return a.ShortChannelID
}
// ChanID returns the ChannelID identifying the channel.
//
// NOTE: this is part of the AnnounceSignatures interface.
func (a *AnnounceSignatures2) ChanID() ChannelID {
return a.ChannelID
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// ChannelAnnouncement1 message is used to announce the existence of a channel
// between two peers in the overlay network; it is propagated by the discovery
// service over the broadcast handler.
type ChannelAnnouncement1 struct {
// These signatures are used by nodes in order to create cross
// references between the node's channel and the node itself. Requiring
// both nodes to sign indicates they are both willing to route other
// payments via this node.
NodeSig1 Sig
NodeSig2 Sig
// These signatures are used by nodes in order to create cross
// references between the node's channel and the node itself. Requiring
// the bitcoin signatures proves they control the channel.
BitcoinSig1 Sig
BitcoinSig2 Sig
// Features is the feature vector that encodes the features supported
// by the target node. This field can be used to signal the type of the
// channel, or modifications to the fields that would normally follow
// this vector.
Features *RawFeatureVector
// ChainHash denotes the target chain that this channel was opened
// within. This value should be the genesis hash of the target chain.
ChainHash chainhash.Hash
// ShortChannelID is the unique description of the funding transaction,
// or where exactly it's located within the target blockchain.
ShortChannelID ShortChannelID
// The public keys of the two nodes who are operating the channel,
// such that NodeID1 is numerically lesser than NodeID2 (ascending
// numerical order).
NodeID1 [33]byte
NodeID2 [33]byte
// Public keys which correspond to the keys that were declared in the
// multisig funding transaction output.
BitcoinKey1 [33]byte
BitcoinKey2 [33]byte
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData ExtraOpaqueData
}
// A compile time check to ensure ChannelAnnouncement implements the
// lnwire.Message interface.
var _ Message = (*ChannelAnnouncement1)(nil)
// A compile time check to ensure ChannelAnnouncement1 implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ChannelAnnouncement1)(nil)
// Decode deserializes a serialized ChannelAnnouncement stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement1) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&a.NodeSig1,
&a.NodeSig2,
&a.BitcoinSig1,
&a.BitcoinSig2,
&a.Features,
a.ChainHash[:],
&a.ShortChannelID,
&a.NodeID1,
&a.NodeID2,
&a.BitcoinKey1,
&a.BitcoinKey2,
&a.ExtraOpaqueData,
)
}
// Encode serializes the target ChannelAnnouncement into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement1) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteSig(w, a.NodeSig1); err != nil {
return err
}
if err := WriteSig(w, a.NodeSig2); err != nil {
return err
}
if err := WriteSig(w, a.BitcoinSig1); err != nil {
return err
}
if err := WriteSig(w, a.BitcoinSig2); err != nil {
return err
}
if err := WriteRawFeatureVector(w, a.Features); err != nil {
return err
}
if err := WriteBytes(w, a.ChainHash[:]); err != nil {
return err
}
if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
return err
}
if err := WriteBytes(w, a.NodeID1[:]); err != nil {
return err
}
if err := WriteBytes(w, a.NodeID2[:]); err != nil {
return err
}
if err := WriteBytes(w, a.BitcoinKey1[:]); err != nil {
return err
}
if err := WriteBytes(w, a.BitcoinKey2[:]); err != nil {
return err
}
return WriteBytes(w, a.ExtraOpaqueData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement1) MsgType() MessageType {
return MsgChannelAnnouncement
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (a *ChannelAnnouncement1) SerializedSize() (uint32, error) {
return MessageSerializedSize(a)
}
// DataToSign is used to retrieve part of the announcement message which should
// be signed.
func (a *ChannelAnnouncement1) DataToSign() ([]byte, error) {
// We should not include the signatures themselves.
b := make([]byte, 0, MaxMsgBody)
buf := bytes.NewBuffer(b)
if err := WriteRawFeatureVector(buf, a.Features); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.ChainHash[:]); err != nil {
return nil, err
}
if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.NodeID1[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.NodeID2[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.BitcoinKey1[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.BitcoinKey2[:]); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// Node1KeyBytes returns the bytes representing the public key of node 1 in the
// channel.
//
// NOTE: This is part of the ChannelAnnouncement interface.
func (a *ChannelAnnouncement1) Node1KeyBytes() [33]byte {
return a.NodeID1
}
// Node2KeyBytes returns the bytes representing the public key of node 2 in the
// channel.
//
// NOTE: This is part of the ChannelAnnouncement interface.
func (a *ChannelAnnouncement1) Node2KeyBytes() [33]byte {
return a.NodeID2
}
// GetChainHash returns the hash of the chain which this channel's funding
// transaction is confirmed in.
//
// NOTE: This is part of the ChannelAnnouncement interface.
func (a *ChannelAnnouncement1) GetChainHash() chainhash.Hash {
return a.ChainHash
}
// SCID returns the short channel ID of the channel.
//
// NOTE: This is part of the ChannelAnnouncement interface.
func (a *ChannelAnnouncement1) SCID() ShortChannelID {
return a.ShortChannelID
}
// A compile-time check to ensure that ChannelAnnouncement1 implements the
// ChannelAnnouncement interface.
var _ ChannelAnnouncement = (*ChannelAnnouncement1)(nil)
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/tlv"
)
// ChannelAnnouncement2 message is used to announce the existence of a taproot
// channel between two peers in the network.
type ChannelAnnouncement2 struct {
// Signature is a Schnorr signature over the TLV stream of the message.
Signature Sig
// ChainHash denotes the target chain that this channel was opened
// within. This value should be the genesis hash of the target chain.
ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash]
// Features is the feature vector that encodes the features supported
// by the target node. This field can be used to signal the type of the
// channel, or modifications to the fields that would normally follow
// this vector.
Features tlv.RecordT[tlv.TlvType2, RawFeatureVector]
// ShortChannelID is the unique description of the funding transaction,
// or where exactly it's located within the target blockchain.
ShortChannelID tlv.RecordT[tlv.TlvType4, ShortChannelID]
// Capacity is the number of satoshis of the capacity of this channel.
// It must be less than or equal to the value of the on-chain funding
// output.
Capacity tlv.RecordT[tlv.TlvType6, uint64]
// NodeID1 is the numerically-lesser public key ID of one of the channel
// operators.
NodeID1 tlv.RecordT[tlv.TlvType8, [33]byte]
// NodeID2 is the numerically-greater public key ID of one of the
// channel operators.
NodeID2 tlv.RecordT[tlv.TlvType10, [33]byte]
// BitcoinKey1 is the public key of the key used by Node1 in the
// construction of the on-chain funding transaction. This is an optional
// field and only needs to be set if the 4-of-4 MuSig construction was
// used in the creation of the message signature.
BitcoinKey1 tlv.OptionalRecordT[tlv.TlvType12, [33]byte]
// BitcoinKey2 is the public key of the key used by Node2 in the
// construction of the on-chain funding transaction. This is an optional
// field and only needs to be set if the 4-of-4 MuSig construction was
// used in the creation of the message signature.
BitcoinKey2 tlv.OptionalRecordT[tlv.TlvType14, [33]byte]
// MerkleRootHash is the hash used to create the optional tweak in the
// funding output. If this is not set but the bitcoin keys are, then
// the funding output is a pure 2-of-2 MuSig aggregate public key.
MerkleRootHash tlv.OptionalRecordT[tlv.TlvType16, [32]byte]
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData ExtraOpaqueData
}
// Decode deserializes a serialized ChannelAnnouncement2 stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error {
err := ReadElement(r, &c.Signature)
if err != nil {
return err
}
c.Signature.ForceSchnorr()
return c.DecodeTLVRecords(r)
}
// DecodeTLVRecords decodes only the TLV section of the message.
func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error {
// First extract into extra opaque data.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var (
chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]()
btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]()
btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]()
merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]()
)
typeMap, err := tlvRecords.ExtractRecords(
&chainHash, &c.Features, &c.ShortChannelID, &c.Capacity,
&c.NodeID1, &c.NodeID2, &btcKey1, &btcKey2, &merkleRootHash,
)
if err != nil {
return err
}
// By default, the chain-hash is the bitcoin mainnet genesis block hash.
c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash
if _, ok := typeMap[c.ChainHash.TlvType()]; ok {
c.ChainHash.Val = chainHash.Val
}
if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok {
c.BitcoinKey1 = tlv.SomeRecordT(btcKey1)
}
if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok {
c.BitcoinKey2 = tlv.SomeRecordT(btcKey2)
}
if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok {
c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash)
}
if len(tlvRecords) != 0 {
c.ExtraOpaqueData = tlvRecords
}
return nil
}
// Encode serializes the target ChannelAnnouncement2 into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error {
_, err := w.Write(c.Signature.RawBytes())
if err != nil {
return err
}
_, err = c.DataToSign()
if err != nil {
return err
}
return WriteBytes(w, c.ExtraOpaqueData)
}
// DataToSign encodes the data to be signed into the ExtraOpaqueData member and
// returns it.
func (c *ChannelAnnouncement2) DataToSign() ([]byte, error) {
// The chain-hash record is only included if it is _not_ equal to the
// bitcoin mainnet genesis block hash.
var recordProducers []tlv.RecordProducer
if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) {
hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]()
hash.Val = c.ChainHash.Val
recordProducers = append(recordProducers, &hash)
}
recordProducers = append(recordProducers,
&c.Features, &c.ShortChannelID, &c.Capacity, &c.NodeID1,
&c.NodeID2,
)
c.BitcoinKey1.WhenSome(func(key tlv.RecordT[tlv.TlvType12, [33]byte]) {
recordProducers = append(recordProducers, &key)
})
c.BitcoinKey2.WhenSome(func(key tlv.RecordT[tlv.TlvType14, [33]byte]) {
recordProducers = append(recordProducers, &key)
})
c.MerkleRootHash.WhenSome(
func(hash tlv.RecordT[tlv.TlvType16, [32]byte]) {
recordProducers = append(recordProducers, &hash)
},
)
err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...)
if err != nil {
return nil, err
}
return c.ExtraOpaqueData, nil
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *ChannelAnnouncement2) MsgType() MessageType {
return MsgChannelAnnouncement2
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *ChannelAnnouncement2) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// A compile time check to ensure ChannelAnnouncement2 implements the
// lnwire.Message interface.
var _ Message = (*ChannelAnnouncement2)(nil)
// A compile time check to ensure ChannelAnnouncement2 implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ChannelAnnouncement2)(nil)
// Node1KeyBytes returns the bytes representing the public key of node 1 in the
// channel.
//
// NOTE: This is part of the ChannelAnnouncement interface.
func (c *ChannelAnnouncement2) Node1KeyBytes() [33]byte {
return c.NodeID1.Val
}
// Node2KeyBytes returns the bytes representing the public key of node 2 in the
// channel.
//
// NOTE: This is part of the ChannelAnnouncement interface.
func (c *ChannelAnnouncement2) Node2KeyBytes() [33]byte {
return c.NodeID2.Val
}
// GetChainHash returns the hash of the chain which this channel's funding
// transaction is confirmed in.
//
// NOTE: This is part of the ChannelAnnouncement interface.
func (c *ChannelAnnouncement2) GetChainHash() chainhash.Hash {
return c.ChainHash.Val
}
// SCID returns the short channel ID of the channel.
//
// NOTE: This is part of the ChannelAnnouncement interface.
func (c *ChannelAnnouncement2) SCID() ShortChannelID {
return c.ShortChannelID.Val
}
// A compile-time check to ensure that ChannelAnnouncement2 implements the
// ChannelAnnouncement interface.
var _ ChannelAnnouncement = (*ChannelAnnouncement2)(nil)
package lnwire
import (
"encoding/binary"
"encoding/hex"
"math"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
)
const (
// MaxFundingTxOutputs is the maximum number of allowed outputs on a
// funding transaction within the protocol. This is due to the fact
// that we use 2-bytes to encode the index within the funding output
// during the funding workflow. Funding transactions with more outputs
// than this are considered invalid within the protocol.
MaxFundingTxOutputs = math.MaxUint16
)
// ChannelID is a series of 32-bytes that uniquely identifies all channels
// within the network. The ChannelID is computed using the outpoint of the
// funding transaction (the txid, and output index). Given a funding output,
// the ChannelID can be calculated by XOR'ing the final 2 bytes of the txid
// with the 2-byte big-endian serialization of the output index.
type ChannelID [32]byte
// ConnectionWideID is an all-zero ChannelID, which is used to represent a
// message intended for all channels to a specific peer.
var ConnectionWideID = ChannelID{}
// String returns the string representation of the ChannelID. This is just the
// hex string encoding of the ChannelID itself.
func (c ChannelID) String() string {
return hex.EncodeToString(c[:])
}
// NewChanIDFromOutPoint converts a target OutPoint into a ChannelID that is
// usable within the network. In order to convert the OutPoint into a ChannelID,
// we XOR the lower 2-bytes of the txid within the OutPoint with the big-endian
// serialization of the Index of the OutPoint, truncated to 2-bytes.
func NewChanIDFromOutPoint(op wire.OutPoint) ChannelID {
// First we'll copy the txid of the outpoint into our channel ID slice.
var cid ChannelID
copy(cid[:], op.Hash[:])
// With the txid copied over, we'll now XOR the lower 2-bytes of the
// partial channelID with the big-endian serialization of the output index.
xorTxid(&cid, uint16(op.Index))
return cid
}
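// As an illustrative sketch (fundingTxid is hypothetical), only the final
// two bytes of the txid can differ from the resulting channel ID:
//
//	op := wire.OutPoint{Hash: fundingTxid, Index: 5}
//	cid := NewChanIDFromOutPoint(op)
//	// cid[:30] equals fundingTxid[:30], while cid[30:] is
//	// fundingTxid[30:] XOR'd with the big-endian bytes 0x00 0x05.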
// xorTxid performs the transformation needed to transform an OutPoint into a
// ChannelID. To do this, we expect the cid parameter to contain the txid
// unaltered and the outputIndex to be the output index of the outpoint.
func xorTxid(cid *ChannelID, outputIndex uint16) {
var buf [2]byte
binary.BigEndian.PutUint16(buf[:], outputIndex)
cid[30] ^= buf[0]
cid[31] ^= buf[1]
}
// GenPossibleOutPoints generates all the possible outpoints given a channel
// ID.
// In order to generate these possible outpoints, we perform a brute-force
// search through the candidate output index space, performing a reverse
// mapping from channelID back to OutPoint.
func (c *ChannelID) GenPossibleOutPoints() [MaxFundingTxOutputs]wire.OutPoint {
var possiblePoints [MaxFundingTxOutputs]wire.OutPoint
for i := uint16(0); i < MaxFundingTxOutputs; i++ {
cidCopy := *c
xorTxid(&cidCopy, i)
possiblePoints[i] = wire.OutPoint{
Hash: chainhash.Hash(cidCopy),
Index: uint32(i),
}
}
return possiblePoints
}
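// As a usage sketch (fundingTxid is hypothetical), a caller that knows the
// funding txid can recover the exact outpoint by scanning the candidates:
//
//	for _, op := range cid.GenPossibleOutPoints() {
//		if op.Hash == fundingTxid {
//			// op.Index is the funding output index.
//		}
//	}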
// IsChanPoint returns true if the OutPoint passed corresponds to the target
// ChannelID.
func (c ChannelID) IsChanPoint(op *wire.OutPoint) bool {
candidateCid := NewChanIDFromOutPoint(*op)
return candidateCid == c
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/tlv"
)
// ChannelReady is the message that both parties to a new channel creation
// send once they have observed the funding transaction being confirmed on the
// blockchain. ChannelReady contains the signatures necessary for the channel
// participants to advertise the existence of the channel to the rest of the
// network.
type ChannelReady struct {
// ChanID is the channel ID derived from the outpoint of the channel's
// funding transaction. This can be used to query for the channel in
// the database.
ChanID ChannelID
// NextPerCommitmentPoint is the secret that can be used to revoke the
// next commitment transaction for the channel.
NextPerCommitmentPoint *btcec.PublicKey
// AliasScid is an alias ShortChannelID used to refer to the underlying
// channel. It can be used instead of the confirmed on-chain
// ShortChannelID for forwarding.
AliasScid *ShortChannelID
// NextLocalNonce is an optional field that stores a local musig2 nonce.
// This will only be populated if the simple taproot channels type was
// negotiated. This is the local nonce that will be used by the sender
// to accept a new commitment state transition.
NextLocalNonce OptMusig2NonceTLV
// AnnouncementNodeNonce is an optional field that stores a public
// nonce that will be used along with the node's ID key during signing
// of the ChannelAnnouncement2 message.
AnnouncementNodeNonce tlv.OptionalRecordT[tlv.TlvType0, Musig2Nonce]
// AnnouncementBitcoinNonce is an optional field that stores a public
// nonce that will be used along with the node's bitcoin key during
// signing of the ChannelAnnouncement2 message.
AnnouncementBitcoinNonce tlv.OptionalRecordT[tlv.TlvType2, Musig2Nonce]
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewChannelReady creates a new ChannelReady message, populating it with the
// necessary IDs and revocation secret.
func NewChannelReady(cid ChannelID, npcp *btcec.PublicKey) *ChannelReady {
return &ChannelReady{
ChanID: cid,
NextPerCommitmentPoint: npcp,
ExtraData: nil,
}
}
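// A minimal usage sketch (chanID, nextPoint and alias are hypothetical):
//
//	msg := NewChannelReady(chanID, nextPoint)
//	msg.AliasScid = &alias // Optionally advertise an alias SCID.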
// A compile time check to ensure ChannelReady implements the lnwire.Message
// interface.
var _ Message = (*ChannelReady)(nil)
// A compile time check to ensure ChannelReady implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ChannelReady)(nil)
// Decode deserializes the serialized ChannelReady message stored in the
// passed io.Reader into the target ChannelReady using the deserialization
// rules defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ChannelReady) Decode(r io.Reader, _ uint32) error {
// Read all the mandatory fields in the message.
err := ReadElements(r,
&c.ChanID,
&c.NextPerCommitmentPoint,
)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
// Next we'll parse out the set of known records: the alias SCID, the
// local musig2 nonce, and the two announcement nonces.
var (
aliasScid ShortChannelID
localNonce = c.NextLocalNonce.Zero()
nodeNonce = tlv.ZeroRecordT[tlv.TlvType0, Musig2Nonce]()
btcNonce = tlv.ZeroRecordT[tlv.TlvType2, Musig2Nonce]()
)
typeMap, err := tlvRecords.ExtractRecords(
&btcNonce, &aliasScid, &nodeNonce, &localNonce,
)
if err != nil {
return err
}
// We'll only set AliasScid if the corresponding TLV type was included
// in the stream.
if val, ok := typeMap[AliasScidRecordType]; ok && val == nil {
c.AliasScid = &aliasScid
}
if val, ok := typeMap[c.NextLocalNonce.TlvType()]; ok && val == nil {
c.NextLocalNonce = tlv.SomeRecordT(localNonce)
}
val, ok := typeMap[c.AnnouncementBitcoinNonce.TlvType()]
if ok && val == nil {
c.AnnouncementBitcoinNonce = tlv.SomeRecordT(btcNonce)
}
val, ok = typeMap[c.AnnouncementNodeNonce.TlvType()]
if ok && val == nil {
c.AnnouncementNodeNonce = tlv.SomeRecordT(nodeNonce)
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target ChannelReady message into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WritePublicKey(w, c.NextPerCommitmentPoint); err != nil {
return err
}
// We'll only encode the AliasScid in a TLV segment if it exists.
recordProducers := make([]tlv.RecordProducer, 0, 4)
if c.AliasScid != nil {
recordProducers = append(recordProducers, c.AliasScid)
}
c.NextLocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, &localNonce)
})
c.AnnouncementBitcoinNonce.WhenSome(
func(nonce tlv.RecordT[tlv.TlvType2, Musig2Nonce]) {
recordProducers = append(recordProducers, &nonce)
},
)
c.AnnouncementNodeNonce.WhenSome(
func(nonce tlv.RecordT[tlv.TlvType0, Musig2Nonce]) {
recordProducers = append(recordProducers, &nonce)
},
)
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the uint32 code which uniquely identifies this message as a
// ChannelReady message on the wire.
//
// This is part of the lnwire.Message interface.
func (c *ChannelReady) MsgType() MessageType {
return MsgChannelReady
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *ChannelReady) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
const (
CRDynHeight tlv.Type = 20
)
// DynHeight is a newtype wrapper to get the proper RecordProducer instance
// to smoothly integrate with the ChannelReestablish Message instance.
type DynHeight uint64
// Record implements the RecordProducer interface, allowing a full tlv.Record
// object to be constructed from a DynHeight.
func (d *DynHeight) Record() tlv.Record {
return tlv.MakePrimitiveRecord(CRDynHeight, (*uint64)(d))
}
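// As a rough sketch (msg is a hypothetical *ChannelReestablish), a caller
// advertises a dynamic commitment height by setting the option before
// encoding; Encode will then emit it as TLV record 20:
//
//	msg.DynHeight = fn.Some(DynHeight(3))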
// ChannelReestablish is a message sent between peers that have an existing
// open channel upon connection reestablishment. This message allows both sides
// to report their local state, and their current knowledge of the state of the
// remote commitment chain. If a deviation is detected and can be recovered
// from, then the necessary messages will be retransmitted. If the level of
// desynchronization is irreconcilable, then the channel will be force closed.
type ChannelReestablish struct {
// ChanID is the channel ID of the channel state we're attempting to
// synchronize with the remote party.
ChanID ChannelID
// NextLocalCommitHeight is the next local commitment height of the
// sending party. If the height of the sender's commitment chain from
// the receiver's PoV is one less than this number, then the sender
// should re-send the *exact* same proposed commitment.
//
// In other words, the receiver should re-send their last sent
// commitment iff:
//
// * NextLocalCommitHeight == remoteCommitChain.Height
//
// This covers the case of a lost commitment which was sent by the
// sender of this message, but never received by the receiver of this
// message.
NextLocalCommitHeight uint64
// RemoteCommitTailHeight is the height of the receiving party's
// unrevoked commitment from the PoV of the sender of this message. If
// the height of the receiver's commitment is *one more* than this
// value, then their prior RevokeAndAck message should be
// retransmitted.
//
// In other words, the receiver should re-send their last sent
// RevokeAndAck message iff:
//
// * localCommitChain.tail().Height == RemoteCommitTailHeight + 1
//
// This covers the case of a lost revocation, wherein the receiver of
// the message sent a revocation for a prior state, but the sender of
// the message never fully processed it.
RemoteCommitTailHeight uint64
// LastRemoteCommitSecret is the last commitment secret that the
// receiving node has sent to the sending party. This will be the
// secret of the last revoked commitment transaction. Including this
// provides proof that the sending node at least knows of this state,
// as they couldn't have produced it if it wasn't sent, as the value
// can be authenticated by querying the shachain or the receiving
// party.
LastRemoteCommitSecret [32]byte
// LocalUnrevokedCommitPoint is the commitment point used in the
// current un-revoked commitment transaction of the sending party.
LocalUnrevokedCommitPoint *btcec.PublicKey
// LocalNonce is an optional field that stores a local musig2 nonce.
// This will only be populated if the simple taproot channels type was
// negotiated.
LocalNonce OptMusig2NonceTLV
// DynHeight is an optional field that stores the dynamic commitment
// negotiation height that is incremented upon successful completion of
// a dynamic commitment negotiation.
DynHeight fn.Option[DynHeight]
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure ChannelReestablish implements the
// lnwire.Message interface.
var _ Message = (*ChannelReestablish)(nil)
// A compile time check to ensure ChannelReestablish implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ChannelReestablish)(nil)
// Encode serializes the target ChannelReestablish into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteChannelID(w, a.ChanID); err != nil {
return err
}
if err := WriteUint64(w, a.NextLocalCommitHeight); err != nil {
return err
}
if err := WriteUint64(w, a.RemoteCommitTailHeight); err != nil {
return err
}
// If the commit point wasn't sent, then we won't write out any of the
// remaining fields as they're optional.
if a.LocalUnrevokedCommitPoint == nil {
// However, we'll still write out the extra data if it's
// present.
//
// NOTE: This is here primarily for the quickcheck tests; in
// practice, we'll always populate this field.
return WriteBytes(w, a.ExtraData)
}
// Otherwise, we'll write out the remaining elements.
if err := WriteBytes(w, a.LastRemoteCommitSecret[:]); err != nil {
return err
}
if err := WritePublicKey(w, a.LocalUnrevokedCommitPoint); err != nil {
return err
}
recordProducers := make([]tlv.RecordProducer, 0, 1)
a.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, &localNonce)
})
a.DynHeight.WhenSome(func(h DynHeight) {
recordProducers = append(recordProducers, &h)
})
err := EncodeMessageExtraData(&a.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, a.ExtraData)
}
// Decode deserializes a serialized ChannelReestablish stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
&a.ChanID,
&a.NextLocalCommitHeight,
&a.RemoteCommitTailHeight,
)
if err != nil {
return err
}
// This message has a number of currently defined optional fields. As
// a result, we'll only proceed if there are still bytes remaining
// within the reader.
//
// We'll manually parse out the optional fields in order to be able to
// still utilize the io.Reader interface.
// We'll first attempt to read the optional commit secret, if we're at
// the EOF, then this means the field wasn't included so we can exit
// early.
var buf [32]byte
_, err = io.ReadFull(r, buf[:32])
if err == io.EOF {
// If there aren't any more bytes, then we'll emplace an empty
// extra data to make our quickcheck tests happy.
a.ExtraData = make([]byte, 0)
return nil
} else if err != nil {
return err
}
// If the field is present, then we'll copy it over and proceed.
copy(a.LastRemoteCommitSecret[:], buf[:])
// We'll conclude by parsing out the commitment point. Since the sender
// included the commit secret, they MUST also include the commit point.
if err = ReadElement(r, &a.LocalUnrevokedCommitPoint); err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var (
dynHeight DynHeight
localNonce = a.LocalNonce.Zero()
)
typeMap, err := tlvRecords.ExtractRecords(
&localNonce, &dynHeight,
)
if err != nil {
return err
}
if val, ok := typeMap[a.LocalNonce.TlvType()]; ok && val == nil {
a.LocalNonce = tlv.SomeRecordT(localNonce)
}
if val, ok := typeMap[CRDynHeight]; ok && val == nil {
a.DynHeight = fn.Some(dynHeight)
}
if len(tlvRecords) != 0 {
a.ExtraData = tlvRecords
}
return nil
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *ChannelReestablish) MsgType() MessageType {
return MsgChannelReestablish
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (a *ChannelReestablish) SerializedSize() (uint32, error) {
return MessageSerializedSize(a)
}
package lnwire
import (
"io"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// ChannelTypeRecordType is the type of the experimental record used
// to denote which channel type is being negotiated.
ChannelTypeRecordType tlv.Type = 1
)
// ChannelType represents a specific channel type as a set of feature bits that
// comprise it.
type ChannelType RawFeatureVector
// featureBitLen returns the length in bytes of the encoded feature bits.
func (c ChannelType) featureBitLen() uint64 {
fv := RawFeatureVector(c)
return fv.sizeFunc()
}
// Record returns a TLV record that can be used to encode/decode the channel
// type from a given TLV stream.
func (c *ChannelType) Record() tlv.Record {
return tlv.MakeDynamicRecord(
ChannelTypeRecordType, c, c.featureBitLen, channelTypeEncoder,
channelTypeDecoder,
)
}
// channelTypeEncoder is a custom TLV encoder for the ChannelType record.
func channelTypeEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*ChannelType); ok {
fv := RawFeatureVector(*v)
return rawFeatureEncoder(w, &fv, buf)
}
return tlv.NewTypeForEncodingErr(val, "*lnwire.ChannelType")
}
// channelTypeDecoder is a custom TLV decoder for the ChannelType record.
func channelTypeDecoder(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*ChannelType); ok {
fv := NewRawFeatureVector()
if err := rawFeatureDecoder(r, fv, buf, l); err != nil {
return err
}
*v = ChannelType(*fv)
return nil
}
return tlv.NewTypeForDecodingErr(val, "*lnwire.ChannelType", l, l)
}
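// As an illustrative sketch, a ChannelType is just a converted feature
// vector, e.g. one built from an assumed feature bit:
//
//	fv := NewRawFeatureVector(StaticRemoteKeyRequired)
//	chanType := ChannelType(*fv)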
package lnwire
import (
"bytes"
"fmt"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// ChanUpdateMsgFlags is a bitfield that signals whether optional fields are
// present in the ChannelUpdate.
type ChanUpdateMsgFlags uint8
const (
// ChanUpdateRequiredMaxHtlc is a bit that indicates whether the
// required htlc_maximum_msat field is present in this ChannelUpdate.
ChanUpdateRequiredMaxHtlc ChanUpdateMsgFlags = 1 << iota
)
// String returns the bitfield flags as a string.
func (c ChanUpdateMsgFlags) String() string {
return fmt.Sprintf("%08b", c)
}
// HasMaxHtlc returns true if the htlc_maximum_msat option bit is set in the
// message flags.
func (c ChanUpdateMsgFlags) HasMaxHtlc() bool {
return c&ChanUpdateRequiredMaxHtlc != 0
}
// ChanUpdateChanFlags is a bitfield that signals various options concerning a
// particular channel edge. Each bit is to be examined in order to determine
// how the ChannelUpdate message is to be interpreted.
type ChanUpdateChanFlags uint8
const (
// ChanUpdateDirection indicates the direction of a channel update. This
// bit is set to 0 if Node1 (the node with the "smaller" Node ID) is
// updating the channel, and to 1 otherwise.
ChanUpdateDirection ChanUpdateChanFlags = 1 << iota
// ChanUpdateDisabled is a bit that indicates if the channel edge
// selected by the ChanUpdateDirection bit is to be treated as being
// disabled.
ChanUpdateDisabled
)
// IsDisabled determines whether the channel flags has the disabled bit set.
func (c ChanUpdateChanFlags) IsDisabled() bool {
return c&ChanUpdateDisabled == ChanUpdateDisabled
}
// String returns the bitfield flags as a string.
func (c ChanUpdateChanFlags) String() string {
return fmt.Sprintf("%08b", c)
}
// ChannelUpdate1 message is used after a channel has been initially
// announced. Each side independently announces its fees and minimum expiry
// for HTLCs and other parameters. This message is also used to redeclare
// initially set channel parameters.
type ChannelUpdate1 struct {
// Signature is used to validate the announced data and prove the
// ownership of node id.
Signature Sig
// ChainHash denotes the target chain that this channel was opened
// within. This value should be the genesis hash of the target chain.
// Along with the short channel ID, this uniquely identifies the
// channel globally in a blockchain.
ChainHash chainhash.Hash
// ShortChannelID is the unique description of the funding transaction.
ShortChannelID ShortChannelID
// Timestamp allows ordering in the case of multiple announcements. We
// should ignore the message if timestamp is not greater than
// the last-received.
Timestamp uint32
// MessageFlags is a bitfield that describes whether optional fields
// are present in this update. Currently, the least-significant bit
// must be set to 1 if the optional field MaxHtlc is present.
MessageFlags ChanUpdateMsgFlags
// ChannelFlags is a bitfield that describes additional meta-data
// concerning how the update is to be interpreted. Currently, the
// least-significant bit must be set to 0 if the creating node
// corresponds to the first node in the previously sent channel
// announcement and 1 otherwise. If the second bit is set, then the
// channel is set to be disabled.
ChannelFlags ChanUpdateChanFlags
// TimeLockDelta is the minimum number of blocks this node requires to
// be added to the expiry of HTLCs. This is a security parameter
// determined by the node operator. This value represents the required
// gap between the time locks of the incoming and outgoing HTLC's set
// to this node.
TimeLockDelta uint16
// HtlcMinimumMsat is the minimum HTLC value which will be accepted.
HtlcMinimumMsat MilliSatoshi
// BaseFee is the base fee that must be used for incoming HTLC's to
// this particular channel. This value will be tacked onto the required
// fee for a payment, independent of the size of the payment.
BaseFee uint32
// FeeRate is the fee rate that will be charged per millionth of a
// satoshi.
FeeRate uint32
// HtlcMaximumMsat is the maximum HTLC value which will be accepted.
HtlcMaximumMsat MilliSatoshi
// ExtraOpaqueData is the set of data that was appended to this message
// to fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraOpaqueData ExtraOpaqueData
}
// A compile time check to ensure ChannelUpdate implements the lnwire.Message
// interface.
var _ Message = (*ChannelUpdate1)(nil)
// A compile time check to ensure ChannelUpdate1 implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ChannelUpdate1)(nil)
// Decode deserializes a serialized ChannelUpdate stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *ChannelUpdate1) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
&a.Signature,
a.ChainHash[:],
&a.ShortChannelID,
&a.Timestamp,
&a.MessageFlags,
&a.ChannelFlags,
&a.TimeLockDelta,
&a.HtlcMinimumMsat,
&a.BaseFee,
&a.FeeRate,
)
if err != nil {
return err
}
// Now check whether the max HTLC field is present and read it if so.
if a.MessageFlags.HasMaxHtlc() {
if err := ReadElements(r, &a.HtlcMaximumMsat); err != nil {
return err
}
}
return a.ExtraOpaqueData.Decode(r)
}
// Encode serializes the target ChannelUpdate into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *ChannelUpdate1) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteSig(w, a.Signature); err != nil {
return err
}
if err := WriteBytes(w, a.ChainHash[:]); err != nil {
return err
}
if err := WriteShortChannelID(w, a.ShortChannelID); err != nil {
return err
}
if err := WriteUint32(w, a.Timestamp); err != nil {
return err
}
if err := WriteChanUpdateMsgFlags(w, a.MessageFlags); err != nil {
return err
}
if err := WriteChanUpdateChanFlags(w, a.ChannelFlags); err != nil {
return err
}
if err := WriteUint16(w, a.TimeLockDelta); err != nil {
return err
}
if err := WriteMilliSatoshi(w, a.HtlcMinimumMsat); err != nil {
return err
}
if err := WriteUint32(w, a.BaseFee); err != nil {
return err
}
if err := WriteUint32(w, a.FeeRate); err != nil {
return err
}
// Now append optional fields if they are set. Currently, the only
// optional field is max HTLC.
if a.MessageFlags.HasMaxHtlc() {
err := WriteMilliSatoshi(w, a.HtlcMaximumMsat)
if err != nil {
return err
}
}
// Finally, append any extra opaque data.
return WriteBytes(w, a.ExtraOpaqueData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *ChannelUpdate1) MsgType() MessageType {
return MsgChannelUpdate
}
// DataToSign is used to retrieve part of the announcement message which should
// be signed.
func (a *ChannelUpdate1) DataToSign() ([]byte, error) {
// We should not include the signatures itself.
b := make([]byte, 0, MaxMsgBody)
buf := bytes.NewBuffer(b)
if err := WriteBytes(buf, a.ChainHash[:]); err != nil {
return nil, err
}
if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil {
return nil, err
}
if err := WriteUint32(buf, a.Timestamp); err != nil {
return nil, err
}
if err := WriteChanUpdateMsgFlags(buf, a.MessageFlags); err != nil {
return nil, err
}
if err := WriteChanUpdateChanFlags(buf, a.ChannelFlags); err != nil {
return nil, err
}
if err := WriteUint16(buf, a.TimeLockDelta); err != nil {
return nil, err
}
if err := WriteMilliSatoshi(buf, a.HtlcMinimumMsat); err != nil {
return nil, err
}
if err := WriteUint32(buf, a.BaseFee); err != nil {
return nil, err
}
if err := WriteUint32(buf, a.FeeRate); err != nil {
return nil, err
}
// Now append optional fields if they are set. Currently, the only
// optional field is max HTLC.
if a.MessageFlags.HasMaxHtlc() {
err := WriteMilliSatoshi(buf, a.HtlcMaximumMsat)
if err != nil {
return nil, err
}
}
// Finally, append any extra opaque data.
if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// SCID returns the ShortChannelID of the channel that the update applies to.
//
// NOTE: this is part of the ChannelUpdate interface.
func (a *ChannelUpdate1) SCID() ShortChannelID {
return a.ShortChannelID
}
// IsNode1 is true if the update was produced by node 1 of the channel peers.
// Node 1 is the node with the lexicographically smaller public key.
//
// NOTE: this is part of the ChannelUpdate interface.
func (a *ChannelUpdate1) IsNode1() bool {
return a.ChannelFlags&ChanUpdateDirection == 0
}
// IsDisabled is true if the update is announcing that the channel should be
// considered disabled.
//
// NOTE: this is part of the ChannelUpdate interface.
func (a *ChannelUpdate1) IsDisabled() bool {
return a.ChannelFlags&ChanUpdateDisabled == ChanUpdateDisabled
}
// GetChainHash returns the hash of the chain that the message is referring to.
//
// NOTE: this is part of the ChannelUpdate interface.
func (a *ChannelUpdate1) GetChainHash() chainhash.Hash {
return a.ChainHash
}
// ForwardingPolicy returns the set of forwarding constraints of the update.
//
// NOTE: this is part of the ChannelUpdate interface.
func (a *ChannelUpdate1) ForwardingPolicy() *ForwardingPolicy {
return &ForwardingPolicy{
TimeLockDelta: a.TimeLockDelta,
BaseFee: MilliSatoshi(a.BaseFee),
FeeRate: MilliSatoshi(a.FeeRate),
MinHTLC: a.HtlcMinimumMsat,
HasMaxHTLC: a.MessageFlags.HasMaxHtlc(),
MaxHTLC: a.HtlcMaximumMsat,
}
}
// CmpAge can be used to determine if the update is older or newer than the
// passed update. It returns 1 if this update is newer, -1 if it is older, and
// 0 if they are the same age.
//
// NOTE: this is part of the ChannelUpdate interface.
func (a *ChannelUpdate1) CmpAge(update ChannelUpdate) (CompareResult, error) {
other, ok := update.(*ChannelUpdate1)
if !ok {
return 0, fmt.Errorf("expected *ChannelUpdate1, got: %T",
update)
}
switch {
case a.Timestamp > other.Timestamp:
return GreaterThan, nil
case a.Timestamp < other.Timestamp:
return LessThan, nil
default:
return EqualTo, nil
}
}
// SetDisabledFlag can be used to adjust the disabled flag of an update.
//
// NOTE: this is part of the ChannelUpdate interface.
func (a *ChannelUpdate1) SetDisabledFlag(disabled bool) {
if disabled {
a.ChannelFlags |= ChanUpdateDisabled
} else {
a.ChannelFlags &= ^ChanUpdateDisabled
}
}
// SetSCID can be used to overwrite the SCID of the update.
//
// NOTE: this is part of the ChannelUpdate interface.
func (a *ChannelUpdate1) SetSCID(scid ShortChannelID) {
a.ShortChannelID = scid
}
// A compile time assertion to ensure ChannelUpdate1 implements the
// ChannelUpdate interface.
var _ ChannelUpdate = (*ChannelUpdate1)(nil)
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (a *ChannelUpdate1) SerializedSize() (uint32, error) {
return MessageSerializedSize(a)
}
package lnwire
import (
"bytes"
"fmt"
"io"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/tlv"
)
const (
defaultCltvExpiryDelta = uint16(80)
defaultHtlcMinMsat = MilliSatoshi(1)
defaultFeeBaseMsat = uint32(1000)
defaultFeeProportionalMillionths = uint32(1)
)
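// These defaults let ChannelUpdate2 omit the corresponding TLV records on
// the wire whenever a field matches its default value; see DataToSign and
// DecodeTLVRecords below.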
// ChannelUpdate2 message is used after a taproot channel has been initially
// announced. Each side independently announces its fees and minimum expiry for
// HTLCs and other parameters. This message is also used to redeclare initially
// set channel parameters.
type ChannelUpdate2 struct {
// Signature is used to validate the announced data and prove the
// ownership of node id.
Signature Sig
// ChainHash denotes the target chain that this channel was opened
// within. This value should be the genesis hash of the target chain.
// Along with the short channel ID, this uniquely identifies the
// channel globally in a blockchain.
ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash]
// ShortChannelID is the unique description of the funding transaction.
ShortChannelID tlv.RecordT[tlv.TlvType2, ShortChannelID]
// BlockHeight allows ordering in the case of multiple announcements. We
// should ignore the message if block height is not greater than the
// last-received. The block height must always be greater or equal to
// the block height that the channel funding transaction was confirmed
// in.
BlockHeight tlv.RecordT[tlv.TlvType4, uint32]
// DisabledFlags is an optional bitfield that describes various reasons
// that the node is communicating that the channel should be considered
// disabled.
DisabledFlags tlv.RecordT[tlv.TlvType6, ChanUpdateDisableFlags]
// SecondPeer is used to indicate which node created and signed this
// message. If this field is present, the update is from node 2;
// otherwise it is from node 1.
SecondPeer tlv.OptionalRecordT[tlv.TlvType8, TrueBoolean]
// CLTVExpiryDelta is the minimum number of blocks this node requires to
// be added to the expiry of HTLCs. This is a security parameter
// determined by the node operator. This value represents the required
// gap between the time locks of the incoming and outgoing HTLC's set
// to this node.
CLTVExpiryDelta tlv.RecordT[tlv.TlvType10, uint16]
// HTLCMinimumMsat is the minimum HTLC value which will be accepted.
HTLCMinimumMsat tlv.RecordT[tlv.TlvType12, MilliSatoshi]
// HtlcMaximumMsat is the maximum HTLC value which will be accepted.
HTLCMaximumMsat tlv.RecordT[tlv.TlvType14, MilliSatoshi]
// FeeBaseMsat is the base fee that must be used for incoming HTLC's to
// this particular channel. This value will be tacked onto the required
// fee for a payment, independent of the size of the payment.
FeeBaseMsat tlv.RecordT[tlv.TlvType16, uint32]
// FeeProportionalMillionths is the fee rate that will be charged per
// millionth of a satoshi.
FeeProportionalMillionths tlv.RecordT[tlv.TlvType18, uint32]
// ExtraOpaqueData is the set of data that was appended to this message
// to fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraOpaqueData ExtraOpaqueData
}
// Decode deserializes a serialized ChannelUpdate2 stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ChannelUpdate2) Decode(r io.Reader, _ uint32) error {
err := ReadElement(r, &c.Signature)
if err != nil {
return err
}
c.Signature.ForceSchnorr()
return c.DecodeTLVRecords(r)
}
// DecodeTLVRecords decodes only the TLV section of the message.
func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error {
// First extract into extra opaque data.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var (
chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]()
secondPeer = tlv.ZeroRecordT[tlv.TlvType8, TrueBoolean]()
)
typeMap, err := tlvRecords.ExtractRecords(
&chainHash, &c.ShortChannelID, &c.BlockHeight, &c.DisabledFlags,
&secondPeer, &c.CLTVExpiryDelta, &c.HTLCMinimumMsat,
&c.HTLCMaximumMsat, &c.FeeBaseMsat,
&c.FeeProportionalMillionths,
)
if err != nil {
return err
}
// By default, the chain-hash is the bitcoin mainnet genesis block hash.
c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash
if _, ok := typeMap[c.ChainHash.TlvType()]; ok {
c.ChainHash.Val = chainHash.Val
}
// The presence of the second_peer tlv type indicates "true".
if _, ok := typeMap[c.SecondPeer.TlvType()]; ok {
c.SecondPeer = tlv.SomeRecordT(secondPeer)
}
// If the CLTV expiry delta was not encoded, then set it to the default
// value.
if _, ok := typeMap[c.CLTVExpiryDelta.TlvType()]; !ok {
c.CLTVExpiryDelta.Val = defaultCltvExpiryDelta
}
// If the HTLC Minimum msat was not encoded, then set it to the default
// value.
if _, ok := typeMap[c.HTLCMinimumMsat.TlvType()]; !ok {
c.HTLCMinimumMsat.Val = defaultHtlcMinMsat
}
// If the base fee was not encoded, then set it to the default value.
if _, ok := typeMap[c.FeeBaseMsat.TlvType()]; !ok {
c.FeeBaseMsat.Val = defaultFeeBaseMsat
}
// If the proportional fee was not encoded, then set it to the default
// value.
if _, ok := typeMap[c.FeeProportionalMillionths.TlvType()]; !ok {
c.FeeProportionalMillionths.Val = defaultFeeProportionalMillionths //nolint:ll
}
if len(tlvRecords) != 0 {
c.ExtraOpaqueData = tlvRecords
}
return nil
}
// Encode serializes the target ChannelUpdate2 into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error {
_, err := w.Write(c.Signature.RawBytes())
if err != nil {
return err
}
_, err = c.DataToSign()
if err != nil {
return err
}
return WriteBytes(w, c.ExtraOpaqueData)
}
// DataToSign is used to retrieve part of the announcement message which should
// be signed. For the ChannelUpdate2 message, this includes the serialised TLV
// records.
func (c *ChannelUpdate2) DataToSign() ([]byte, error) {
// The chain-hash record is only included if it is _not_ equal to the
// bitcoin mainnet genesis block hash.
var recordProducers []tlv.RecordProducer
if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) {
hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]()
hash.Val = c.ChainHash.Val
recordProducers = append(recordProducers, &hash)
}
recordProducers = append(recordProducers,
&c.ShortChannelID, &c.BlockHeight,
)
// Only include the disable flags if any bit is set.
if !c.DisabledFlags.Val.IsEnabled() {
recordProducers = append(recordProducers, &c.DisabledFlags)
}
// We only need to encode the second peer boolean if it is true.
c.SecondPeer.WhenSome(func(r tlv.RecordT[tlv.TlvType8, TrueBoolean]) {
recordProducers = append(recordProducers, &r)
})
// We only encode the cltv expiry delta if it is not equal to the
// default.
if c.CLTVExpiryDelta.Val != defaultCltvExpiryDelta {
recordProducers = append(recordProducers, &c.CLTVExpiryDelta)
}
if c.HTLCMinimumMsat.Val != defaultHtlcMinMsat {
recordProducers = append(recordProducers, &c.HTLCMinimumMsat)
}
recordProducers = append(recordProducers, &c.HTLCMaximumMsat)
if c.FeeBaseMsat.Val != defaultFeeBaseMsat {
recordProducers = append(recordProducers, &c.FeeBaseMsat)
}
if c.FeeProportionalMillionths.Val != defaultFeeProportionalMillionths {
recordProducers = append(
recordProducers, &c.FeeProportionalMillionths,
)
}
err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...)
if err != nil {
return nil, err
}
return c.ExtraOpaqueData, nil
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *ChannelUpdate2) MsgType() MessageType {
return MsgChannelUpdate2
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *ChannelUpdate2) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
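// ExtraData returns the set of raw extra TLV bytes appended to the message.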
func (c *ChannelUpdate2) ExtraData() ExtraOpaqueData {
return c.ExtraOpaqueData
}
// A compile time check to ensure ChannelUpdate2 implements the
// lnwire.Message interface.
var _ Message = (*ChannelUpdate2)(nil)
// SCID returns the ShortChannelID of the channel that the update applies to.
//
// NOTE: this is part of the ChannelUpdate interface.
func (c *ChannelUpdate2) SCID() ShortChannelID {
return c.ShortChannelID.Val
}
// IsNode1 is true if the update was produced by node 1 of the channel peers.
// Node 1 is the node with the lexicographically smaller public key.
//
// NOTE: this is part of the ChannelUpdate interface.
func (c *ChannelUpdate2) IsNode1() bool {
return c.SecondPeer.IsNone()
}
// IsDisabled is true if the update is announcing that the channel should be
// considered disabled.
//
// NOTE: this is part of the ChannelUpdate interface.
func (c *ChannelUpdate2) IsDisabled() bool {
return !c.DisabledFlags.Val.IsEnabled()
}
// GetChainHash returns the hash of the chain that the message is referring to.
//
// NOTE: this is part of the ChannelUpdate interface.
func (c *ChannelUpdate2) GetChainHash() chainhash.Hash {
return c.ChainHash.Val
}
// ForwardingPolicy returns the set of forwarding constraints of the update.
//
// NOTE: this is part of the ChannelUpdate interface.
func (c *ChannelUpdate2) ForwardingPolicy() *ForwardingPolicy {
return &ForwardingPolicy{
TimeLockDelta: c.CLTVExpiryDelta.Val,
BaseFee: MilliSatoshi(c.FeeBaseMsat.Val),
FeeRate: MilliSatoshi(c.FeeProportionalMillionths.Val),
MinHTLC: c.HTLCMinimumMsat.Val,
HasMaxHTLC: true,
MaxHTLC: c.HTLCMaximumMsat.Val,
}
}
// CmpAge can be used to determine if the update is older or newer than the
// passed update. It returns 1 if this update is newer, -1 if it is older, and
// 0 if they are the same age.
//
// NOTE: this is part of the ChannelUpdate interface.
func (c *ChannelUpdate2) CmpAge(update ChannelUpdate) (CompareResult, error) {
other, ok := update.(*ChannelUpdate2)
if !ok {
return 0, fmt.Errorf("expected *ChannelUpdate2, got: %T",
update)
}
switch {
case c.BlockHeight.Val > other.BlockHeight.Val:
return GreaterThan, nil
case c.BlockHeight.Val < other.BlockHeight.Val:
return LessThan, nil
default:
return EqualTo, nil
}
}
// SetDisabledFlag can be used to adjust the disabled flag of an update.
//
// NOTE: this is part of the ChannelUpdate interface.
func (c *ChannelUpdate2) SetDisabledFlag(disabled bool) {
if disabled {
c.DisabledFlags.Val |= ChanUpdateDisableIncoming
c.DisabledFlags.Val |= ChanUpdateDisableOutgoing
} else {
c.DisabledFlags.Val &^= ChanUpdateDisableIncoming
c.DisabledFlags.Val &^= ChanUpdateDisableOutgoing
}
}
// SetSCID can be used to overwrite the SCID of the update.
//
// NOTE: this is part of the ChannelUpdate interface.
func (c *ChannelUpdate2) SetSCID(scid ShortChannelID) {
c.ShortChannelID.Val = scid
}
// A compile time check to ensure ChannelUpdate2 implements the
// lnwire.ChannelUpdate interface.
var _ ChannelUpdate = (*ChannelUpdate2)(nil)
// ChanUpdateDisableFlags is a bit vector that can be used to indicate various
// reasons for the channel being marked as disabled.
type ChanUpdateDisableFlags uint8
const (
// ChanUpdateDisableIncoming is a bit that indicates that a channel is
// disabled in the inbound direction, meaning that the node broadcasting
// the update is communicating that it cannot receive funds.
ChanUpdateDisableIncoming ChanUpdateDisableFlags = 1 << iota
// ChanUpdateDisableOutgoing is a bit that indicates that a channel is
// disabled in the outbound direction, meaning that the node broadcasting
// the update is communicating that it cannot send or route funds.
ChanUpdateDisableOutgoing
)
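// For example, a fully disabled channel carries both bits
// (ChanUpdateDisableIncoming|ChanUpdateDisableOutgoing == 0b00000011), while
// the zero value means the channel is enabled in both directions.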
// IncomingDisabled returns true if the ChanUpdateDisableIncoming bit is set.
func (c ChanUpdateDisableFlags) IncomingDisabled() bool {
return c&ChanUpdateDisableIncoming == ChanUpdateDisableIncoming
}
// OutgoingDisabled returns true if the ChanUpdateDisableOutgoing bit is set.
func (c ChanUpdateDisableFlags) OutgoingDisabled() bool {
return c&ChanUpdateDisableOutgoing == ChanUpdateDisableOutgoing
}
// IsEnabled returns true if none of the disable bits are set.
func (c ChanUpdateDisableFlags) IsEnabled() bool {
return c == 0
}
// String returns the bitfield flags as a string.
func (c ChanUpdateDisableFlags) String() string {
return fmt.Sprintf("%08b", c)
}
// Record returns the tlv record for the disable flags.
func (c *ChanUpdateDisableFlags) Record() tlv.Record {
return tlv.MakeStaticRecord(0, c, 1, encodeDisableFlags,
decodeDisableFlags)
}
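// encodeDisableFlags is a custom TLV encoder for the ChanUpdateDisableFlags
// bit vector.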
func encodeDisableFlags(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*ChanUpdateDisableFlags); ok {
flagsInt := uint8(*v)
return tlv.EUint8(w, &flagsInt, buf)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.ChanUpdateDisableFlags")
}
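// decodeDisableFlags is a custom TLV decoder for the ChanUpdateDisableFlags
// bit vector.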
func decodeDisableFlags(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*ChanUpdateDisableFlags); ok {
var flagsInt uint8
err := tlv.DUint8(r, &flagsInt, buf, l)
if err != nil {
return err
}
*v = ChanUpdateDisableFlags(flagsInt)
return nil
}
return tlv.NewTypeForDecodingErr(val, "lnwire.ChanUpdateDisableFlags",
l, l)
}
// TrueBoolean is a record that indicates true or false using the presence of
// the record. If the record is absent, it indicates false. If it is present,
// it indicates true.
type TrueBoolean struct{}
// Record returns the tlv record for the boolean entry.
func (b *TrueBoolean) Record() tlv.Record {
return tlv.MakeStaticRecord(
0, b, 0, booleanEncoder, booleanDecoder,
)
}
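// booleanEncoder is a custom TLV encoder for the TrueBoolean record; the
// record's presence alone carries the value, so no bytes are written.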
func booleanEncoder(_ io.Writer, val interface{}, _ *[8]byte) error {
if _, ok := val.(*TrueBoolean); ok {
return nil
}
return tlv.NewTypeForEncodingErr(val, "TrueBoolean")
}
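// booleanDecoder is a custom TLV decoder for the TrueBoolean record; it
// treats the mere presence of the (zero or one byte) record as true.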
func booleanDecoder(_ io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if _, ok := val.(*TrueBoolean); ok && (l == 0 || l == 1) {
return nil
}
return tlv.NewTypeForDecodingErr(val, "TrueBoolean", l, 1)
}
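// As an illustrative sketch (upd is a hypothetical *ChannelUpdate2), marking
// an update as originating from node 2 means making the zero-length record
// present:
//
//	upd.SecondPeer = tlv.SomeRecordT(
//		tlv.ZeroRecordT[tlv.TlvType8, TrueBoolean](),
//	)
//
// Leaving SecondPeer unset omits the record, which decodes as false.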
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/tlv"
)
// ClosingSigs houses the 3 possible signatures that can be sent when
// attempting to complete a cooperative channel closure. A signature will
// either include both outputs, or only one of the outputs from either side.
type ClosingSigs struct {
// CloserNoClosee is a signature that excludes the output of the
// closee.
CloserNoClosee tlv.OptionalRecordT[tlv.TlvType1, Sig]
// NoCloserClosee is a signature that excludes the output of the
// closer.
NoCloserClosee tlv.OptionalRecordT[tlv.TlvType2, Sig]
// CloserAndClosee is a signature that includes both outputs.
CloserAndClosee tlv.OptionalRecordT[tlv.TlvType3, Sig]
}
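// As a rough protocol sketch: the common case with both outputs present uses
// CloserAndClosee; if the closee's output would be dust, only CloserNoClosee
// is sent, and if the closer's own output would be dust, only NoCloserClosee.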
// ClosingComplete is sent by either side to kick off the process of obtaining
// a valid signature on a co-operative channel closure of their choice.
type ClosingComplete struct {
// ChannelID serves to identify which channel is to be closed.
ChannelID ChannelID
// CloserScript is the script to which the channel funds will be paid
// for the closer (the person sending the ClosingComplete message).
CloserScript DeliveryAddress
// CloseeScript is the script to which the channel funds will be paid
// for the closee (the person receiving the ClosingComplete message).
CloseeScript DeliveryAddress
// FeeSatoshis is the total fee in satoshis that the party to the
// channel would like to propose for the close transaction.
FeeSatoshis btcutil.Amount
// LockTime is the locktime number to be used in the input spending the
// funding transaction.
LockTime uint32
// ClosingSigs houses the 3 possible signatures that can be sent.
ClosingSigs
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// decodeClosingSigs decodes the closing sig TLV records in the passed
// ExtraOpaqueData.
func decodeClosingSigs(c *ClosingSigs, tlvRecords ExtraOpaqueData) error {
sig1 := c.CloserNoClosee.Zero()
sig2 := c.NoCloserClosee.Zero()
sig3 := c.CloserAndClosee.Zero()
typeMap, err := tlvRecords.ExtractRecords(&sig1, &sig2, &sig3)
if err != nil {
return err
}
// TODO(roasbeef): helper func to make decode of the optional vals
// easier?
if val, ok := typeMap[c.CloserNoClosee.TlvType()]; ok && val == nil {
c.CloserNoClosee = tlv.SomeRecordT(sig1)
}
if val, ok := typeMap[c.NoCloserClosee.TlvType()]; ok && val == nil {
c.NoCloserClosee = tlv.SomeRecordT(sig2)
}
if val, ok := typeMap[c.CloserAndClosee.TlvType()]; ok && val == nil {
c.CloserAndClosee = tlv.SomeRecordT(sig3)
}
return nil
}
// Decode deserializes a serialized ClosingComplete message stored in the
// passed io.Reader.
func (c *ClosingComplete) Decode(r io.Reader, _ uint32) error {
// First, read out all the fields that are hard coded into the message.
err := ReadElements(
r, &c.ChannelID, &c.CloserScript, &c.CloseeScript,
&c.FeeSatoshis, &c.LockTime,
)
if err != nil {
return err
}
// With the hard coded messages read, we'll now read out the TLV fields
// of the message.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
if err := decodeClosingSigs(&c.ClosingSigs, tlvRecords); err != nil {
return err
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// closingSigRecords returns the set of records that encode the closing sigs,
// if present.
func closingSigRecords(c *ClosingSigs) []tlv.RecordProducer {
recordProducers := make([]tlv.RecordProducer, 0, 3)
c.CloserNoClosee.WhenSome(func(sig tlv.RecordT[tlv.TlvType1, Sig]) {
recordProducers = append(recordProducers, &sig)
})
c.NoCloserClosee.WhenSome(func(sig tlv.RecordT[tlv.TlvType2, Sig]) {
recordProducers = append(recordProducers, &sig)
})
c.CloserAndClosee.WhenSome(func(sig tlv.RecordT[tlv.TlvType3, Sig]) {
recordProducers = append(recordProducers, &sig)
})
return recordProducers
}
// Encode serializes the target ClosingComplete into the passed io.Writer.
func (c *ClosingComplete) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, c.ChannelID); err != nil {
return err
}
if err := WriteDeliveryAddress(w, c.CloserScript); err != nil {
return err
}
if err := WriteDeliveryAddress(w, c.CloseeScript); err != nil {
return err
}
if err := WriteSatoshi(w, c.FeeSatoshis); err != nil {
return err
}
if err := WriteUint32(w, c.LockTime); err != nil {
return err
}
recordProducers := closingSigRecords(&c.ClosingSigs)
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the uint32 code which uniquely identifies this message as a
// ClosingComplete message on the wire.
//
// This is part of the lnwire.Message interface.
func (c *ClosingComplete) MsgType() MessageType {
return MsgClosingComplete
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *ClosingComplete) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// A compile time check to ensure ClosingComplete implements the lnwire.Message
// interface.
var _ Message = (*ClosingComplete)(nil)
// A compile time check to ensure ClosingComplete implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ClosingComplete)(nil)
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcutil"
)
// ClosingSig is sent in response to a ClosingComplete message. It carries the
// signatures of the closee to the closer.
type ClosingSig struct {
// ChannelID serves to identify which channel is to be closed.
ChannelID ChannelID
// CloserScript is the script to which the channel funds will be paid
// for the closer (the person sending the ClosingComplete message).
CloserScript DeliveryAddress
// CloseeScript is the script to which the channel funds will be paid
// for the closee (the person receiving the ClosingComplete message).
CloseeScript DeliveryAddress
// FeeSatoshis is the total fee in satoshis that the party to the
// channel proposed for the close transaction.
FeeSatoshis btcutil.Amount
// LockTime is the locktime number to be used in the input spending the
// funding transaction.
LockTime uint32
// ClosingSigs houses the 3 possible signatures that can be sent.
ClosingSigs
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// Decode deserializes a serialized ClosingSig message stored in the passed
// io.Reader.
func (c *ClosingSig) Decode(r io.Reader, _ uint32) error {
// First, read out all the fields that are hard coded into the message.
err := ReadElements(
r, &c.ChannelID, &c.CloserScript, &c.CloseeScript,
&c.FeeSatoshis, &c.LockTime,
)
if err != nil {
return err
}
// With the hard coded messages read, we'll now read out the TLV fields
// of the message.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
if err := decodeClosingSigs(&c.ClosingSigs, tlvRecords); err != nil {
return err
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target ClosingSig into the passed io.Writer.
func (c *ClosingSig) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, c.ChannelID); err != nil {
return err
}
if err := WriteDeliveryAddress(w, c.CloserScript); err != nil {
return err
}
if err := WriteDeliveryAddress(w, c.CloseeScript); err != nil {
return err
}
if err := WriteSatoshi(w, c.FeeSatoshis); err != nil {
return err
}
if err := WriteUint32(w, c.LockTime); err != nil {
return err
}
recordProducers := closingSigRecords(&c.ClosingSigs)
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the uint32 code which uniquely identifies this message as a
// ClosingSig message on the wire.
//
// This is part of the lnwire.Message interface.
func (c *ClosingSig) MsgType() MessageType {
return MsgClosingSig
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *ClosingSig) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// A compile time check to ensure ClosingSig implements the lnwire.Message
// interface.
var _ Message = (*ClosingSig)(nil)
// A compile time check to ensure ClosingSig implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ClosingSig)(nil)
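// Illustrative usage sketch (added for exposition; not part of the
// original source). It exercises the Encode/Decode round trip defined
// above with placeholder values; the scripts and channel ID are
// arbitrary, and "fmt" is assumed to be imported alongside this file's
// existing imports.
func exampleClosingSigRoundTrip() error {
	msg := &ClosingSig{
		ChannelID:    ChannelID{0x01},
		CloserScript: DeliveryAddress{0x00, 0x14}, // placeholder script bytes
		CloseeScript: DeliveryAddress{0x00, 0x14},
		FeeSatoshis:  btcutil.Amount(1_000),
		LockTime:     0,
	}

	// Serialize the message as it would appear on the wire.
	var buf bytes.Buffer
	if err := msg.Encode(&buf, 0); err != nil {
		return err
	}

	// Deserialize it back; the hard coded fields and any TLV records
	// are recovered by Decode.
	var decoded ClosingSig
	if err := decoded.Decode(&buf, 0); err != nil {
		return err
	}

	fmt.Printf("fee after round trip: %v\n", decoded.FeeSatoshis)
	return nil
}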
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/tlv"
)
// ClosingSigned is sent by both parties to a channel once the channel is clear
// of HTLCs, and is primarily concerned with negotiating fees for the close
// transaction. Each party provides a signature for a transaction with a fee
// that they believe is fair. The process terminates when both sides agree on
// the same fee, or when one side force closes the channel.
//
// NOTE: The responder is able to send a signature without any additional
// messages as all transactions are assembled observing BIP 69 which defines a
// canonical ordering for input/outputs. Therefore, both sides are able to
// arrive at an identical closure transaction as they know the order of the
// inputs/outputs.
type ClosingSigned struct {
// ChannelID serves to identify which channel is to be closed.
ChannelID ChannelID
// FeeSatoshis is the total fee in satoshis that the party to the
// channel would like to propose for the close transaction.
FeeSatoshis btcutil.Amount
// Signature is for the proposed channel close transaction.
Signature Sig
// PartialSig is used to transmit a musig2 extended partial signature
// that signs the latest fee offer. The nonce isn't sent alongside, as
// that has already been sent in the initial shutdown message.
//
// NOTE: This field is only populated if a musig2 taproot channel is
// being signed for. In this case, the above Sig type MUST be blank.
PartialSig OptPartialSigTLV
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewClosingSigned creates a new empty ClosingSigned message.
func NewClosingSigned(cid ChannelID, fs btcutil.Amount,
sig Sig) *ClosingSigned {
return &ClosingSigned{
ChannelID: cid,
FeeSatoshis: fs,
Signature: sig,
}
}
// A compile time check to ensure ClosingSigned implements the lnwire.Message
// interface.
var _ Message = (*ClosingSigned)(nil)
// A compile time check to ensure ClosingSigned implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ClosingSigned)(nil)
// Decode deserializes a serialized ClosingSigned message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error {
err := ReadElements(
r, &c.ChannelID, &c.FeeSatoshis, &c.Signature,
)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
partialSig := c.PartialSig.Zero()
typeMap, err := tlvRecords.ExtractRecords(&partialSig)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil {
c.PartialSig = tlv.SomeRecordT(partialSig)
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target ClosingSigned into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1)
c.PartialSig.WhenSome(func(sig PartialSigTLV) {
recordProducers = append(recordProducers, &sig)
})
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
if err := WriteChannelID(w, c.ChannelID); err != nil {
return err
}
if err := WriteSatoshi(w, c.FeeSatoshis); err != nil {
return err
}
if err := WriteSig(w, c.Signature); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *ClosingSigned) MsgType() MessageType {
return MsgClosingSigned
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *ClosingSigned) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
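// Illustrative usage sketch (added for exposition; not part of the
// original source). It shows that the optional musig2 PartialSig TLV is
// only populated when it actually appears on the wire. The zero Sig is
// a placeholder, "fmt" is assumed imported, and IsSome is assumed
// available via the fn.Option embedded in OptPartialSigTLV.
func exampleClosingSignedRoundTrip() error {
	msg := NewClosingSigned(ChannelID{0x02}, 5_000, Sig{})

	var buf bytes.Buffer
	if err := msg.Encode(&buf, 0); err != nil {
		return err
	}

	var decoded ClosingSigned
	if err := decoded.Decode(&buf, 0); err != nil {
		return err
	}

	// No partial sig was attached above, so the option stays empty.
	fmt.Println("have partial sig:", decoded.PartialSig.IsSome())
	return nil
}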
package lnwire
import (
"bytes"
"io"
"github.com/lightningnetwork/lnd/tlv"
)
// CommitSig is sent by either side to stage any pending HTLC's in the
// receiver's pending set into a new commitment state. Implicitly, the new
// commitment transaction constructed, which is signed by CommitSig,
// includes all HTLC's in the remote node's pending set. A CommitSig message
// may be sent after a series of UpdateAddHTLC/UpdateFulfillHTLC messages in
// order to batch add several HTLC's with a single signature covering all
// implicitly accepted HTLC's.
type CommitSig struct {
// ChanID uniquely identifies to which currently active channel this
// CommitSig applies to.
ChanID ChannelID
// CommitSig is Alice's signature for Bob's new commitment transaction.
// Alice is able to send this signature without requesting any
// additional data due to the piggybacking of Bob's next revocation
// hash in his prior RevokeAndAck message, as well as the canonical
// ordering used for all inputs/outputs within commitment transactions.
// If initiating a new commitment state, this signature should ONLY
// cover all of the sending party's pending log updates, and the log
// updates of the remote party that have been ACK'd.
CommitSig Sig
// HtlcSigs is a signature for each relevant HTLC output within the
// created commitment. The order of the signatures is expected to be
// identical to the placement of the HTLC's within the BIP 69 sorted
// commitment transaction. For each outgoing HTLC (from the PoV of the
// sender of this message), a signature for an HTLC timeout transaction
// should be signed, for each incoming HTLC the HTLC success
// transaction should be signed.
HtlcSigs []Sig
// PartialSig is used to transmit a musig2 extended partial signature
// that also carries along the public nonce of the signer.
//
// NOTE: This field is only populated if a musig2 taproot channel is
// being signed for. In this case, the above Sig type MUST be blank.
PartialSig OptPartialSigWithNonceTLV
// CustomRecords maps TLV types to byte slices, storing arbitrary data
// intended for inclusion in the ExtraData field.
CustomRecords CustomRecords
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewCommitSig creates a new empty CommitSig message.
func NewCommitSig() *CommitSig {
return &CommitSig{}
}
// A compile time check to ensure CommitSig implements the lnwire.Message
// interface.
var _ Message = (*CommitSig)(nil)
// A compile time check to ensure CommitSig implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*CommitSig)(nil)
// Decode deserializes a serialized CommitSig message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *CommitSig) Decode(r io.Reader, pver uint32) error {
// msgExtraData is a temporary variable used to read the message extra
// data field from the reader.
var msgExtraData ExtraOpaqueData
err := ReadElements(r,
&c.ChanID,
&c.CommitSig,
&c.HtlcSigs,
&msgExtraData,
)
if err != nil {
return err
}
// Extract TLV records from the extra data field.
partialSig := c.PartialSig.Zero()
customRecords, parsed, extraData, err := ParseAndExtractCustomRecords(
msgExtraData, &partialSig,
)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if _, ok := parsed[partialSig.TlvType()]; ok {
c.PartialSig = tlv.SomeRecordT(partialSig)
}
c.CustomRecords = customRecords
c.ExtraData = extraData
return nil
}
// Encode serializes the target CommitSig into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1)
c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) {
recordProducers = append(recordProducers, &sig)
})
extraData, err := MergeAndEncode(
recordProducers, c.ExtraData, c.CustomRecords,
)
if err != nil {
return err
}
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteSig(w, c.CommitSig); err != nil {
return err
}
if err := WriteSigs(w, c.HtlcSigs); err != nil {
return err
}
return WriteBytes(w, extraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *CommitSig) MsgType() MessageType {
return MsgCommitSig
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *CommitSig) TargetChanID() ChannelID {
return c.ChanID
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *CommitSig) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
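// Illustrative usage sketch (added for exposition; not part of the
// original source). It attaches an application-defined custom record to
// a CommitSig and recovers it after a wire round trip; the TLV type
// 65537 is an arbitrary value in the custom range and "fmt" is assumed
// imported.
func exampleCommitSigCustomRecords() error {
	msg := NewCommitSig()
	msg.ChanID = ChannelID{0x03}
	msg.CustomRecords = CustomRecords{
		65537: []byte("hello"),
	}

	var buf bytes.Buffer
	if err := msg.Encode(&buf, 0); err != nil {
		return err
	}

	// On decode, records in the custom TLV range are routed back into
	// CustomRecords rather than ExtraData.
	var decoded CommitSig
	if err := decoded.Decode(&buf, 0); err != nil {
		return err
	}

	fmt.Printf("custom record: %s\n", decoded.CustomRecords[65537])
	return nil
}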
package lnwire
import (
"bytes"
"fmt"
"io"
"sync"
)
// CustomTypeStart is the start of the custom type range for peer messages as
// defined in BOLT 01.
const CustomTypeStart MessageType = 32768
var (
// customTypeOverride contains a set of message types < CustomTypeStart
// that lnd allows to be treated as custom messages. This allows us to
// override messages reserved for the protocol level and treat them as
// custom messages. This set of message types is stored as a global so
// that we do not need to pass around state when accounting for this
// set of messages in message creation.
//
// Note: This global is protected by the customTypeOverrideMtx mutex.
customTypeOverride map[MessageType]struct{}
// customTypeOverrideMtx manages concurrent access to
// customTypeOverride.
customTypeOverrideMtx sync.RWMutex
)
// SetCustomOverrides validates that the set of override types are outside of
// the custom message range (there's no reason to override messages that are
// already within the range), and updates the customTypeOverride global to hold
// this set of message types. Note that this function will completely overwrite
// the set of overrides, so should be called with the full set of types.
func SetCustomOverrides(overrideTypes []uint16) error {
customTypeOverrideMtx.Lock()
defer customTypeOverrideMtx.Unlock()
customTypeOverride = make(map[MessageType]struct{}, len(overrideTypes))
for _, t := range overrideTypes {
msgType := MessageType(t)
if msgType >= CustomTypeStart {
return fmt.Errorf("can't override type: %v, already "+
"in custom range", t)
}
customTypeOverride[msgType] = struct{}{}
}
return nil
}
// IsCustomOverride returns a bool indicating whether the message type is one
// of the protocol messages that we override for custom use.
func IsCustomOverride(t MessageType) bool {
customTypeOverrideMtx.RLock()
defer customTypeOverrideMtx.RUnlock()
_, ok := customTypeOverride[t]
return ok
}
// Custom represents an application-defined wire message.
type Custom struct {
Type MessageType
Data []byte
}
// A compile time check to ensure Custom implements the lnwire.Message
// interface.
var _ Message = (*Custom)(nil)
// A compile time check to ensure Custom implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Custom)(nil)
// NewCustom instantiates a new custom message.
func NewCustom(msgType MessageType, data []byte) (*Custom, error) {
if msgType < CustomTypeStart && !IsCustomOverride(msgType) {
return nil, fmt.Errorf("msg type: %d not in custom range: %v "+
"and not overridden", msgType, CustomTypeStart)
}
return &Custom{
Type: msgType,
Data: data,
}, nil
}
// Encode serializes the target Custom message into the passed io.Writer
// implementation.
//
// This is part of the lnwire.Message interface.
func (c *Custom) Encode(b *bytes.Buffer, pver uint32) error {
_, err := b.Write(c.Data)
return err
}
// Decode deserializes the serialized Custom message stored in the passed
// io.Reader into the target Custom message.
//
// This is part of the lnwire.Message interface.
func (c *Custom) Decode(r io.Reader, pver uint32) error {
var b bytes.Buffer
if _, err := io.Copy(&b, r); err != nil {
return err
}
c.Data = b.Bytes()
return nil
}
// MsgType returns the uint32 code which uniquely identifies this message as a
// Custom message on the wire.
//
// This is part of the lnwire.Message interface.
func (c *Custom) MsgType() MessageType {
return c.Type
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *Custom) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
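// Illustrative usage sketch (added for exposition; not part of the
// original source). It registers a protocol-range message type as a
// custom override and then constructs a Custom message with it; the
// type 513 is an arbitrary example below CustomTypeStart.
func exampleCustomOverride() error {
	// Allow message type 513 to be treated as custom even though it is
	// below CustomTypeStart. Note that this overwrites any previously
	// registered overrides.
	if err := SetCustomOverrides([]uint16{513}); err != nil {
		return err
	}

	// Without the override this constructor would fail, since 513 is
	// outside the custom range.
	msg, err := NewCustom(MessageType(513), []byte{0x01, 0x02})
	if err != nil {
		return err
	}

	fmt.Println("custom msg type:", msg.MsgType())
	return nil
}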
package lnwire
import (
"bytes"
"fmt"
"io"
"maps"
"sort"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// MinCustomRecordsTlvType is the minimum custom records TLV type as
// defined in BOLT 01.
MinCustomRecordsTlvType = 65536
)
// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types
// which must be greater than or equal to MinCustomRecordsTlvType.
type CustomRecords map[uint64][]byte
// NewCustomRecords creates a new CustomRecords instance from a
// tlv.TypeMap.
func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) {
// Make comparisons in unit tests easy by returning nil if the map is
// empty.
if len(tlvMap) == 0 {
return nil, nil
}
customRecords := make(CustomRecords, len(tlvMap))
for k, v := range tlvMap {
customRecords[uint64(k)] = v
}
// Validate the custom records.
err := customRecords.Validate()
if err != nil {
return nil, fmt.Errorf("custom records from tlv map "+
"validation error: %w", err)
}
return customRecords, nil
}
// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob.
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {
return ParseCustomRecordsFrom(bytes.NewReader(b))
}
// ParseCustomRecordsFrom creates a new CustomRecords instance from a reader.
func ParseCustomRecordsFrom(r io.Reader) (CustomRecords, error) {
typeMap, err := DecodeRecords(r)
if err != nil {
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
}
return NewCustomRecords(typeMap)
}
// Validate checks that all custom records are in the custom type range.
func (c CustomRecords) Validate() error {
if c == nil {
return nil
}
for key := range c {
if key < MinCustomRecordsTlvType {
return fmt.Errorf("custom records entry with TLV "+
"type below min: %d", MinCustomRecordsTlvType)
}
}
return nil
}
// Copy returns a copy of the custom records.
func (c CustomRecords) Copy() CustomRecords {
if c == nil {
return nil
}
customRecords := make(CustomRecords, len(c))
for k, v := range c {
customRecords[k] = v
}
return customRecords
}
// MergedCopy creates a copy of the records and merges them with the given
// records. If the same key is present in both sets, the value from the other
// records will be used.
func (c CustomRecords) MergedCopy(other CustomRecords) CustomRecords {
copiedRecords := make(CustomRecords, len(c))
maps.Copy(copiedRecords, c)
maps.Copy(copiedRecords, other)
return copiedRecords
}
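// Illustrative usage sketch (added for exposition; not part of the
// original source). It merges two CustomRecords sets via MergedCopy,
// where values from the argument win on key collisions; the TLV types
// are arbitrary values at or above MinCustomRecordsTlvType.
func exampleCustomRecordsMerge() error {
	base := CustomRecords{
		MinCustomRecordsTlvType:     []byte{0x01},
		MinCustomRecordsTlvType + 1: []byte{0x02},
	}
	override := CustomRecords{
		MinCustomRecordsTlvType: []byte{0xff},
	}

	merged := base.MergedCopy(override)
	if err := merged.Validate(); err != nil {
		return err
	}

	// Prints [255]: the overriding set takes precedence for the
	// colliding type.
	fmt.Println(merged[MinCustomRecordsTlvType])
	return nil
}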
// ExtendRecordProducers extends the given records slice with the custom
// records. The resultant records slice is sorted by TLV type.
func (c CustomRecords) ExtendRecordProducers(
producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) {
// If the custom records are nil or empty, there is nothing to do.
if len(c) == 0 {
return producers, nil
}
// Validate the custom records.
err := c.Validate()
if err != nil {
return nil, err
}
// Ensure that the existing records slice TLV types are not also present
// in the custom records. If they are, the resultant extended records
// slice would erroneously contain duplicate TLV types.
for _, rp := range producers {
record := rp.Record()
recordTlvType := uint64(record.Type())
_, foundDuplicateTlvType := c[recordTlvType]
if foundDuplicateTlvType {
return nil, fmt.Errorf("custom records contains a TLV "+
"type that is already present in the "+
"existing records: %d", recordTlvType)
}
}
// Convert the custom records map to a TLV record producer slice and
// append them to the existing records slice.
customRecordProducers := RecordsAsProducers(tlv.MapToRecords(c))
producers = append(producers, customRecordProducers...)
// Sort the extended records slice to ensure that the result is ordered
// correctly by TLV type.
SortProducers(producers)
return producers, nil
}
// RecordProducers returns a slice of record producers for the custom records.
func (c CustomRecords) RecordProducers() []tlv.RecordProducer {
// If the custom records are nil or empty, return an empty slice.
if len(c) == 0 {
return nil
}
// Convert the custom records map to a TLV record producer slice.
records := tlv.MapToRecords(c)
return RecordsAsProducers(records)
}
// Serialize serializes the custom records into a byte slice.
func (c CustomRecords) Serialize() ([]byte, error) {
records := tlv.MapToRecords(c)
return EncodeRecords(records)
}
// SerializeTo serializes the custom records into the given writer.
func (c CustomRecords) SerializeTo(w io.Writer) error {
records := tlv.MapToRecords(c)
return EncodeRecordsTo(w, records)
}
// ProduceRecordsSorted converts a slice of record producers into a slice of
// records and then sorts it by type.
func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record {
records := fn.Map(
recordProducers,
func(producer tlv.RecordProducer) tlv.Record {
return producer.Record()
},
)
// Ensure that the set of records is sorted before we attempt to encode
// or decode them as a stream, to ensure they're canonical.
tlv.SortRecords(records)
return records
}
// SortProducers sorts the given record producers by their type.
func SortProducers(producers []tlv.RecordProducer) {
sort.Slice(producers, func(i, j int) bool {
recordI := producers[i].Record()
recordJ := producers[j].Record()
return recordI.Type() < recordJ.Type()
})
}
// TlvMapToRecords converts a TLV map into a slice of records.
func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
tlvMapGeneric := make(map[uint64][]byte)
for k, v := range tlvMap {
tlvMapGeneric[uint64(k)] = v
}
return tlv.MapToRecords(tlvMapGeneric)
}
// RecordsAsProducers converts a slice of records into a slice of record
// producers.
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer {
return fn.Map(records, func(record tlv.Record) tlv.RecordProducer {
return &record
})
}
// EncodeRecords encodes the given records into a byte slice.
func EncodeRecords(records []tlv.Record) ([]byte, error) {
var buf bytes.Buffer
if err := EncodeRecordsTo(&buf, records); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// EncodeRecordsTo encodes the given records into the given writer.
func EncodeRecordsTo(w io.Writer, records []tlv.Record) error {
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
// DecodeRecords decodes data from the given reader into the given records and
// returns the remaining unparsed types as a TLV type map.
func DecodeRecords(r io.Reader,
records ...tlv.Record) (tlv.TypeMap, error) {
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}
return tlvStream.DecodeWithParsedTypes(r)
}
// DecodeRecordsP2P decodes data from the given reader into the given records
// and returns the remaining unparsed types as a TLV type map. This function is
// identical to DecodeRecords except that the record size is capped at 65535
// bytes.
func DecodeRecordsP2P(r *bytes.Reader,
records ...tlv.Record) (tlv.TypeMap, error) {
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}
return tlvStream.DecodeWithParsedTypesP2P(r)
}
// AssertUniqueTypes asserts that the given records have unique types.
func AssertUniqueTypes(r []tlv.Record) error {
seen := make(fn.Set[tlv.Type], len(r))
for _, record := range r {
t := record.Type()
if seen.Contains(t) {
return fmt.Errorf("duplicate record type: %d", t)
}
seen.Add(t)
}
return nil
}
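// Illustrative usage sketch (added for exposition; not part of the
// original source). It round trips a single primitive record through
// EncodeRecords and DecodeRecords; the TLV type 1 and value 42 are
// arbitrary.
func exampleRecordRoundTrip() error {
	val := uint64(42)
	rec := tlv.MakePrimitiveRecord(tlv.Type(1), &val)

	encoded, err := EncodeRecords([]tlv.Record{rec})
	if err != nil {
		return err
	}

	// Decode into a fresh record that targets the same TLV type.
	var decodedVal uint64
	decodedRec := tlv.MakePrimitiveRecord(tlv.Type(1), &decodedVal)
	_, err = DecodeRecords(bytes.NewReader(encoded), decodedRec)
	if err != nil {
		return err
	}

	fmt.Println("decoded value:", decodedVal) // 42
	return nil
}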
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// DALocalMusig2Pubnonce is the TLV type number that identifies the
// musig2 public nonce that we need to verify the commitment transaction
// signature.
DALocalMusig2Pubnonce tlv.Type = 0
)
// DynAck is the message used to accept the parameters of a dynamic commitment
// negotiation. Additional optional parameters will need to be present depending
// on the details of the dynamic commitment upgrade.
type DynAck struct {
// ChanID is the ChannelID of the channel that is currently undergoing
// a dynamic commitment negotiation.
ChanID ChannelID
// LocalNonce is an optional field that is transmitted when accepting
// a dynamic commitment upgrade to Taproot Channels. This nonce will be
// used to verify the first commitment transaction signature. This will
// only be populated if the DynPropose we are responding to specifies
// taproot channels in the ChannelType field.
LocalNonce fn.Option[Musig2Nonce]
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure DynAck implements the lnwire.Message
// interface.
var _ Message = (*DynAck)(nil)
// A compile time check to ensure DynAck implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*DynAck)(nil)
// Encode serializes the target DynAck into the passed io.Writer. Serialization
// will observe the rules defined by the passed protocol version.
//
// This is a part of the lnwire.Message interface.
func (da *DynAck) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, da.ChanID); err != nil {
return err
}
var tlvRecords []tlv.Record
da.LocalNonce.WhenSome(func(nonce Musig2Nonce) {
tlvRecords = append(
tlvRecords, tlv.MakeStaticRecord(
DALocalMusig2Pubnonce, &nonce,
musig2.PubNonceSize, nonceTypeEncoder,
nonceTypeDecoder,
),
)
})
tlv.SortRecords(tlvRecords)
tlvStream, err := tlv.NewStream(tlvRecords...)
if err != nil {
return err
}
var extraBytesWriter bytes.Buffer
if err := tlvStream.Encode(&extraBytesWriter); err != nil {
return err
}
da.ExtraData = ExtraOpaqueData(extraBytesWriter.Bytes())
return WriteBytes(w, da.ExtraData)
}
// Decode deserializes the serialized DynAck stored in the passed io.Reader into
// the target DynAck using the deserialization rules defined by the passed
// protocol version.
//
// This is a part of the lnwire.Message interface.
func (da *DynAck) Decode(r io.Reader, _ uint32) error {
// Parse out main message.
if err := ReadElements(r, &da.ChanID); err != nil {
return err
}
// Parse out TLV records.
var tlvRecords ExtraOpaqueData
if err := ReadElement(r, &tlvRecords); err != nil {
return err
}
// Prepare receiving buffers to be filled by TLV extraction.
var localNonceScratch Musig2Nonce
localNonce := tlv.MakeStaticRecord(
DALocalMusig2Pubnonce, &localNonceScratch, musig2.PubNonceSize,
nonceTypeEncoder, nonceTypeDecoder,
)
// Create set of Records to read TLV bytestream into.
records := []tlv.Record{localNonce}
tlv.SortRecords(records)
// Read TLV stream into record set.
extraBytesReader := bytes.NewReader(tlvRecords)
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
typeMap, err := tlvStream.DecodeWithParsedTypesP2P(extraBytesReader)
if err != nil {
return err
}
// Check the results of the TLV Stream decoding and appropriately set
// message fields.
if val, ok := typeMap[DALocalMusig2Pubnonce]; ok && val == nil {
da.LocalNonce = fn.Some(localNonceScratch)
}
if len(tlvRecords) != 0 {
da.ExtraData = tlvRecords
}
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as a DynAck on the wire.
//
// This is part of the lnwire.Message interface.
func (da *DynAck) MsgType() MessageType {
return MsgDynAck
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (da *DynAck) SerializedSize() (uint32, error) {
return MessageSerializedSize(da)
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// DPDustLimitSatoshis is the TLV type number that identifies the record
// for DynPropose.DustLimit.
DPDustLimitSatoshis tlv.Type = 0
// DPMaxHtlcValueInFlightMsat is the TLV type number that identifies the
// record for DynPropose.MaxValueInFlight.
DPMaxHtlcValueInFlightMsat tlv.Type = 1
// DPChannelReserveSatoshis is the TLV type number that identifies the record
// for DynPropose.ChannelReserve.
DPChannelReserveSatoshis tlv.Type = 2
// DPToSelfDelay is the TLV type number that identifies the record for
// DynPropose.CsvDelay.
DPToSelfDelay tlv.Type = 3
// DPMaxAcceptedHtlcs is the TLV type number that identifies the record
// for DynPropose.MaxAcceptedHTLCs.
DPMaxAcceptedHtlcs tlv.Type = 4
// DPFundingPubkey is the TLV type number that identifies the record for
// DynPropose.FundingKey.
DPFundingPubkey tlv.Type = 5
// DPChannelType is the TLV type number that identifies the record for
// DynPropose.ChannelType.
DPChannelType tlv.Type = 6
// DPKickoffFeerate is the TLV type number that identifies the record
// for DynPropose.KickoffFeerate.
DPKickoffFeerate tlv.Type = 7
)
// DynPropose is a message that is sent during a dynamic commitments negotiation
// process. It is sent by both parties to propose new channel parameters.
type DynPropose struct {
// ChanID identifies the channel whose parameters we are trying to
// re-negotiate.
ChanID ChannelID
// Initiator indicates whether this message was sent as the initiator
// (true) or the responder (false) of a dynamic commitment negotiation.
Initiator bool
// DustLimit, if not nil, proposes a change to the dust_limit_satoshis
// for the sender's commitment transaction.
DustLimit fn.Option[btcutil.Amount]
// MaxValueInFlight, if not nil, proposes a change to the
// max_htlc_value_in_flight_msat limit of the sender.
MaxValueInFlight fn.Option[MilliSatoshi]
// ChannelReserve, if not nil, proposes a change to the
// channel_reserve_satoshis requirement of the recipient.
ChannelReserve fn.Option[btcutil.Amount]
// CsvDelay, if not nil, proposes a change to the to_self_delay
// requirement of the recipient.
CsvDelay fn.Option[uint16]
// MaxAcceptedHTLCs, if not nil, proposes a change to the
// max_accepted_htlcs limit of the sender.
MaxAcceptedHTLCs fn.Option[uint16]
// FundingKey, if not nil, proposes a change to the funding_pubkey
// parameter of the sender.
FundingKey fn.Option[btcec.PublicKey]
// ChannelType, if not nil, proposes a change to the channel_type
// parameter.
ChannelType fn.Option[ChannelType]
// KickoffFeerate proposes the fee rate in satoshis per kw that the
// sender is offering for a ChannelType conversion that requires a kickoff
// transaction.
KickoffFeerate fn.Option[chainfee.SatPerKWeight]
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
//
// NOTE: Since the fields in this structure are part of the TLV stream,
// ExtraData will contain all TLV records _except_ the ones that are
// present in earlier parts of this structure.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure DynPropose implements the lnwire.Message
// interface.
var _ Message = (*DynPropose)(nil)
// A compile time check to ensure DynPropose implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*DynPropose)(nil)
// Encode serializes the target DynPropose into the passed io.Writer.
// Serialization will observe the rules defined by the passed protocol version.
//
// This is a part of the lnwire.Message interface.
func (dp *DynPropose) Encode(w *bytes.Buffer, _ uint32) error {
var tlvRecords []tlv.Record
dp.DustLimit.WhenSome(func(dl btcutil.Amount) {
protoSats := uint64(dl)
tlvRecords = append(
tlvRecords, tlv.MakePrimitiveRecord(
DPDustLimitSatoshis, &protoSats,
),
)
})
dp.MaxValueInFlight.WhenSome(func(max MilliSatoshi) {
protoSats := uint64(max)
tlvRecords = append(
tlvRecords, tlv.MakePrimitiveRecord(
DPMaxHtlcValueInFlightMsat, &protoSats,
),
)
})
dp.ChannelReserve.WhenSome(func(min btcutil.Amount) {
channelReserve := uint64(min)
tlvRecords = append(
tlvRecords, tlv.MakePrimitiveRecord(
DPChannelReserveSatoshis, &channelReserve,
),
)
})
dp.CsvDelay.WhenSome(func(wait uint16) {
tlvRecords = append(
tlvRecords, tlv.MakePrimitiveRecord(
DPToSelfDelay, &wait,
),
)
})
dp.MaxAcceptedHTLCs.WhenSome(func(max uint16) {
tlvRecords = append(
tlvRecords, tlv.MakePrimitiveRecord(
DPMaxAcceptedHtlcs, &max,
),
)
})
dp.FundingKey.WhenSome(func(key btcec.PublicKey) {
keyScratch := &key
tlvRecords = append(
tlvRecords, tlv.MakePrimitiveRecord(
DPFundingPubkey, &keyScratch,
),
)
})
dp.ChannelType.WhenSome(func(ty ChannelType) {
tlvRecords = append(
tlvRecords, tlv.MakeDynamicRecord(
DPChannelType, &ty,
ty.featureBitLen,
channelTypeEncoder, channelTypeDecoder,
),
)
})
dp.KickoffFeerate.WhenSome(func(kickoffFeerate chainfee.SatPerKWeight) {
protoSats := uint32(kickoffFeerate)
tlvRecords = append(
tlvRecords, tlv.MakePrimitiveRecord(
DPKickoffFeerate, &protoSats,
),
)
})
tlv.SortRecords(tlvRecords)
tlvStream, err := tlv.NewStream(tlvRecords...)
if err != nil {
return err
}
var extraBytesWriter bytes.Buffer
if err := tlvStream.Encode(&extraBytesWriter); err != nil {
return err
}
dp.ExtraData = ExtraOpaqueData(extraBytesWriter.Bytes())
if err := WriteChannelID(w, dp.ChanID); err != nil {
return err
}
if err := WriteBool(w, dp.Initiator); err != nil {
return err
}
return WriteBytes(w, dp.ExtraData)
}
// Decode deserializes the serialized DynPropose stored in the passed io.Reader
// into the target DynPropose using the deserialization rules defined by the
// passed protocol version.
//
// This is a part of the lnwire.Message interface.
func (dp *DynPropose) Decode(r io.Reader, _ uint32) error {
// Parse out the only required field.
if err := ReadElements(r, &dp.ChanID, &dp.Initiator); err != nil {
return err
}
// Parse out TLV stream.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
// Prepare receiving buffers to be filled by TLV extraction.
var dustLimitScratch uint64
dustLimit := tlv.MakePrimitiveRecord(
DPDustLimitSatoshis, &dustLimitScratch,
)
var maxValueScratch uint64
maxValue := tlv.MakePrimitiveRecord(
DPMaxHtlcValueInFlightMsat, &maxValueScratch,
)
var reserveScratch uint64
reserve := tlv.MakePrimitiveRecord(
DPChannelReserveSatoshis, &reserveScratch,
)
var csvDelayScratch uint16
csvDelay := tlv.MakePrimitiveRecord(DPToSelfDelay, &csvDelayScratch)
var maxHtlcsScratch uint16
maxHtlcs := tlv.MakePrimitiveRecord(
DPMaxAcceptedHtlcs, &maxHtlcsScratch,
)
var fundingKeyScratch *btcec.PublicKey
fundingKey := tlv.MakePrimitiveRecord(
DPFundingPubkey, &fundingKeyScratch,
)
var chanTypeScratch ChannelType
chanType := tlv.MakeDynamicRecord(
DPChannelType, &chanTypeScratch, chanTypeScratch.featureBitLen,
channelTypeEncoder, channelTypeDecoder,
)
var kickoffFeerateScratch uint32
kickoffFeerate := tlv.MakePrimitiveRecord(
DPKickoffFeerate, &kickoffFeerateScratch,
)
// Create set of Records to read TLV bytestream into.
records := []tlv.Record{
dustLimit, maxValue, reserve, csvDelay, maxHtlcs, fundingKey,
chanType, kickoffFeerate,
}
tlv.SortRecords(records)
// Read TLV stream into record set.
extraBytesReader := bytes.NewReader(tlvRecords)
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
typeMap, err := tlvStream.DecodeWithParsedTypesP2P(extraBytesReader)
if err != nil {
return err
}
// Check the results of the TLV Stream decoding and appropriately set
// message fields.
if val, ok := typeMap[DPDustLimitSatoshis]; ok && val == nil {
dp.DustLimit = fn.Some(btcutil.Amount(dustLimitScratch))
}
if val, ok := typeMap[DPMaxHtlcValueInFlightMsat]; ok && val == nil {
dp.MaxValueInFlight = fn.Some(MilliSatoshi(maxValueScratch))
}
if val, ok := typeMap[DPChannelReserveSatoshis]; ok && val == nil {
dp.ChannelReserve = fn.Some(btcutil.Amount(reserveScratch))
}
if val, ok := typeMap[DPToSelfDelay]; ok && val == nil {
dp.CsvDelay = fn.Some(csvDelayScratch)
}
if val, ok := typeMap[DPMaxAcceptedHtlcs]; ok && val == nil {
dp.MaxAcceptedHTLCs = fn.Some(maxHtlcsScratch)
}
if val, ok := typeMap[DPFundingPubkey]; ok && val == nil {
dp.FundingKey = fn.Some(*fundingKeyScratch)
}
if val, ok := typeMap[DPChannelType]; ok && val == nil {
dp.ChannelType = fn.Some(chanTypeScratch)
}
if val, ok := typeMap[DPKickoffFeerate]; ok && val == nil {
dp.KickoffFeerate = fn.Some(
chainfee.SatPerKWeight(kickoffFeerateScratch),
)
}
if len(tlvRecords) != 0 {
dp.ExtraData = tlvRecords
}
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as a DynPropose on the wire.
//
// This is part of the lnwire.Message interface.
func (dp *DynPropose) MsgType() MessageType {
return MsgDynPropose
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (dp *DynPropose) SerializedSize() (uint32, error) {
return MessageSerializedSize(dp)
}
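// Illustrative usage sketch (added for exposition; not part of the
// original source). Only the populated fn.Option fields of a DynPropose
// are serialized as TLV records by Encode; the dust limit and CSV delay
// values are arbitrary and "fmt" is assumed imported.
func exampleDynProposeRoundTrip() error {
	msg := &DynPropose{
		ChanID:    ChannelID{0x04},
		Initiator: true,
		DustLimit: fn.Some(btcutil.Amount(354)),
		CsvDelay:  fn.Some(uint16(144)),
	}

	var buf bytes.Buffer
	if err := msg.Encode(&buf, 0); err != nil {
		return err
	}

	var decoded DynPropose
	if err := decoded.Decode(&buf, 0); err != nil {
		return err
	}

	// Options that were never set, e.g. MaxValueInFlight, remain
	// fn.None after the round trip.
	fmt.Println("dust limit set:", decoded.DustLimit.IsSome())
	return nil
}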
package lnwire
import (
"bytes"
"io"
)
// DynReject is a message that is sent during a dynamic commitments negotiation
// process. It is sent by a party to reject the channel parameters proposed in
// a DynPropose message.
type DynReject struct {
// ChanID identifies the channel whose parameters we are trying to
// re-negotiate.
ChanID ChannelID
// UpdateRejections is a bit vector that specifies which of the
// DynPropose parameters we wish to call out as being unacceptable.
UpdateRejections RawFeatureVector
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
//
// NOTE: Since the fields in this structure are part of the TLV stream,
// ExtraData will contain all TLV records _except_ the ones that are
// present in earlier parts of this structure.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure DynReject implements the lnwire.Message
// interface.
var _ Message = (*DynReject)(nil)
// A compile time check to ensure DynReject implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*DynReject)(nil)
// Encode serializes the target DynReject into the passed io.Writer.
// Serialization will observe the rules defined by the passed protocol version.
//
// This is a part of the lnwire.Message interface.
func (dr *DynReject) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, dr.ChanID); err != nil {
return err
}
if err := WriteRawFeatureVector(w, &dr.UpdateRejections); err != nil {
return err
}
return WriteBytes(w, dr.ExtraData)
}
// Decode deserializes the serialized DynReject stored in the passed io.Reader
// into the target DynReject using the deserialization rules defined by the
// passed protocol version.
//
// This is a part of the lnwire.Message interface.
func (dr *DynReject) Decode(r io.Reader, _ uint32) error {
var extra ExtraOpaqueData
if err := ReadElements(
r, &dr.ChanID, &dr.UpdateRejections, &extra,
); err != nil {
return err
}
if len(extra) != 0 {
dr.ExtraData = extra
}
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as a DynReject on the wire.
//
// This is part of the lnwire.Message interface.
func (dr *DynReject) MsgType() MessageType {
return MsgDynReject
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (dr *DynReject) SerializedSize() (uint32, error) {
return MessageSerializedSize(dr)
}
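// Illustrative usage sketch (added for exposition; not part of the
// original source). It builds a DynReject whose UpdateRejections vector
// flags the dust limit proposal as unacceptable; mapping rejection bits
// to the DynPropose TLV type numbers is an assumption made here purely
// for illustration, as is the availability of NewRawFeatureVector.
func exampleDynReject() *DynReject {
	rejections := NewRawFeatureVector(
		FeatureBit(DPDustLimitSatoshis),
	)

	return &DynReject{
		ChanID:           ChannelID{0x05},
		UpdateRejections: *rejections,
	}
}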
package lnwire
import "github.com/lightningnetwork/lnd/tlv"
// QueryEncoding is an enum-like type that represents exactly how a set of data is
// encoded on the wire.
type QueryEncoding uint8
const (
// EncodingSortedPlain signals that the set of data is encoded using the
// regular encoding, in a sorted order.
EncodingSortedPlain QueryEncoding = 0
// EncodingSortedZlib signals that the set of data is encoded by first
// sorting the set of channel ID's, and then compressing them using zlib.
//
// NOTE: this should no longer be used or accepted.
EncodingSortedZlib QueryEncoding = 1
)
// recordProducer is a simple helper struct that implements the
// tlv.RecordProducer interface.
type recordProducer struct {
record tlv.Record
}
// Record returns the underlying record.
func (r *recordProducer) Record() tlv.Record {
return r.record
}
// Ensure that recordProducer implements the tlv.RecordProducer interface.
var _ tlv.RecordProducer = (*recordProducer)(nil)
package lnwire
import (
"bytes"
"fmt"
"io"
)
// FundingError represents a set of errors that can be encountered and sent
// during the funding workflow.
type FundingError uint8
const (
// ErrMaxPendingChannels is returned by remote peer when the number of
// active pending channels exceeds their maximum policy limit.
ErrMaxPendingChannels FundingError = 1
// ErrChanTooLarge is returned by a remote peer that receives a
// FundingOpen request for a channel that is above their current
// soft-limit.
ErrChanTooLarge FundingError = 2
)
// String returns a human readable version of the target FundingError.
func (e FundingError) String() string {
switch e {
case ErrMaxPendingChannels:
return "Number of pending channels exceed maximum"
case ErrChanTooLarge:
return "channel too large"
default:
return "unknown error"
}
}
// Error returns the human readable version of the target FundingError.
//
// NOTE: Satisfies the Error interface.
func (e FundingError) Error() string {
return e.String()
}
// ErrorData is a set of bytes associated with a particular sent error. A
// receiving node SHOULD only print out data verbatim if the string is composed
// solely of printable ASCII characters. For reference, the printable character
// set includes byte values 32 through 127 inclusive.
type ErrorData []byte
// Error represents a generic error bound to an exact channel. The message
// format is purposefully general in order to allow expression of a wide array
// of possible errors. Each Error message is directed at a particular open
// channel referenced by ChannelPoint.
type Error struct {
// ChanID references the active channel in which the error occurred. If
// the ChanID is all zeros, then this error applies to the
// entire established connection.
ChanID ChannelID
// Data is the attached error data that describes the exact failure
// which caused the error message to be sent.
Data ErrorData
}
// NewError creates a new Error message.
func NewError() *Error {
return &Error{}
}
// A compile time check to ensure Error implements the lnwire.Message
// interface.
var _ Message = (*Error)(nil)
// A compile time check to ensure Error implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Error)(nil)
// Error returns the string representation of the Error.
//
// NOTE: Satisfies the error interface.
func (c *Error) Error() string {
errMsg := "non-ascii data"
if isASCII(c.Data) {
errMsg = string(c.Data)
}
return fmt.Sprintf("chan_id=%v, err=%v", c.ChanID, errMsg)
}
// Decode deserializes a serialized Error message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *Error) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.Data,
)
}
// Encode serializes the target Error into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *Error) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteBytes(w, c.ChanID[:]); err != nil {
return err
}
return WriteErrorData(w, c.Data)
}
// MsgType returns the integer uniquely identifying an Error message on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *Error) MsgType() MessageType {
return MsgError
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *Error) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// isASCII is a helper method that checks whether all bytes in `data` would be
// printable ASCII characters if interpreted as a string.
func isASCII(data []byte) bool {
for _, c := range data {
if c < 32 || c > 126 {
return false
}
}
return true
}
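// Illustrative usage sketch (added for exposition; not part of the
// original source). It shows how Error renders printable versus
// non-printable data, per the ErrorData guidance above.
func exampleErrorFormatting() {
	printable := &Error{Data: ErrorData("channel closing")}
	fmt.Println(printable.Error()) // ... err=channel closing

	// Bytes outside the printable ASCII range are not echoed verbatim.
	binary := &Error{Data: ErrorData{0x00, 0x01}}
	fmt.Println(binary.Error()) // ... err=non-ascii data
}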
package lnwire
import (
"bytes"
"fmt"
"io"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
// ExtraOpaqueData is the set of data that was appended to this message, some
// of which we may not actually know how to iterate or parse. By holding onto
// this data, we ensure that we're able to properly validate the set of
// signatures that cover these new fields, and ensure we're able to make
// upgrades to the network in a forwards compatible manner.
type ExtraOpaqueData []byte
// NewExtraOpaqueData creates a new ExtraOpaqueData instance from a tlv.TypeMap.
func NewExtraOpaqueData(tlvMap tlv.TypeMap) (ExtraOpaqueData, error) {
// If the tlv map is empty, we'll want to mirror the behavior of
// decoding an empty extra opaque data field (see Decode method).
if len(tlvMap) == 0 {
return make([]byte, 0), nil
}
// Convert the TLV map into a slice of records.
records := TlvMapToRecords(tlvMap)
// Encode the records into the extra data byte slice.
return EncodeRecords(records)
}
// Encode attempts to encode the raw extra bytes into the passed io.Writer.
func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error {
eBytes := []byte((*e)[:])
if err := WriteBytes(w, eBytes); err != nil {
return err
}
return nil
}
// Decode attempts to unpack the raw bytes encoded in the passed-in io.Reader as
// a set of extra opaque data.
func (e *ExtraOpaqueData) Decode(r io.Reader) error {
// First, we'll attempt to read a set of bytes contained within the
// passed io.Reader (if any exist).
rawBytes, err := io.ReadAll(r)
if err != nil {
return err
}
// If we _do_ have some bytes, then we'll swap out our backing pointer.
// This ensures that any struct that embeds this type will properly
// store the bytes once this method exits.
if len(rawBytes) > 0 {
*e = rawBytes
} else {
*e = make([]byte, 0)
}
return nil
}
// PackRecords attempts to encode the set of tlv records into the target
// ExtraOpaqueData instance. The records will be encoded as a raw TLV stream
// and stored within the backing slice pointer.
func (e *ExtraOpaqueData) PackRecords(
recordProducers ...tlv.RecordProducer) error {
// Assemble all the records passed in series, then encode them.
records := ProduceRecordsSorted(recordProducers...)
encoded, err := EncodeRecords(records)
if err != nil {
return err
}
*e = encoded
return nil
}
// ExtractRecords attempts to decode any types in the internal raw bytes as if
// it were a tlv stream. The set of raw parsed types is returned, and any
// passed records (if found in the stream) will be parsed into the proper
// tlv.Record.
func (e *ExtraOpaqueData) ExtractRecords(
recordProducers ...tlv.RecordProducer) (tlv.TypeMap, error) {
// First, assemble all the records passed in series.
records := ProduceRecordsSorted(recordProducers...)
extraBytesReader := bytes.NewReader(*e)
// Since ExtraOpaqueData is provided by a potentially malicious peer,
// pass it into the P2P decoding variant.
return DecodeRecordsP2P(extraBytesReader, records...)
}
// RecordProducers parses ExtraOpaqueData into a slice of TLV record producers
// by interpreting it as a TLV map.
func (e *ExtraOpaqueData) RecordProducers() ([]tlv.RecordProducer, error) {
var recordProducers []tlv.RecordProducer
// If the instance is nil or empty, return an empty slice.
if e == nil || len(*e) == 0 {
return recordProducers, nil
}
// Parse the extra opaque data as a TLV map.
tlvMap, err := e.ExtractRecords()
if err != nil {
return nil, err
}
// Convert the TLV map into a slice of record producers.
records := TlvMapToRecords(tlvMap)
return RecordsAsProducers(records), nil
}
// EncodeMessageExtraData encodes the given recordProducers into the given
// extraData.
func EncodeMessageExtraData(extraData *ExtraOpaqueData,
recordProducers ...tlv.RecordProducer) error {
// Treat extraData as a mutable reference.
if extraData == nil {
return fmt.Errorf("extra data cannot be nil")
}
// Pack in the series of TLV records into this message. The order we
// pass them in doesn't matter, as the method will ensure that things
// are all properly sorted.
return extraData.PackRecords(recordProducers...)
}
// ParseAndExtractCustomRecords parses the given extra data into the passed-in
// records, then returns any remaining records split into custom records and
// extra data.
func ParseAndExtractCustomRecords(allExtraData ExtraOpaqueData,
knownRecords ...tlv.RecordProducer) (CustomRecords,
fn.Set[tlv.Type], ExtraOpaqueData, error) {
extraDataTlvMap, err := allExtraData.ExtractRecords(knownRecords...)
if err != nil {
return nil, nil, nil, err
}
// Remove the known and now extracted records from the leftover extra
// data map.
parsedKnownRecords := make(fn.Set[tlv.Type], len(knownRecords))
for _, producer := range knownRecords {
r := producer.Record()
// Only remove the record if it was parsed (remainder is nil).
// We'll just store the type so we can tell the caller which
// records were actually parsed fully.
val, ok := extraDataTlvMap[r.Type()]
if ok && val == nil {
parsedKnownRecords.Add(r.Type())
delete(extraDataTlvMap, r.Type())
}
}
// Any records from the extra data TLV map which are in the custom
// records TLV type range will be included in the custom records field
// and removed from the extra data field.
customRecordsTlvMap := make(tlv.TypeMap, len(extraDataTlvMap))
for k, v := range extraDataTlvMap {
// Skip records that are not in the custom records TLV type
// range.
if k < MinCustomRecordsTlvType {
continue
}
// Include the record in the custom records map.
customRecordsTlvMap[k] = v
// Now that the record is included in the custom records map,
// we can remove it from the extra data TLV map.
delete(extraDataTlvMap, k)
}
// Set the custom records field to the custom records specific TLV
// record map.
customRecords, err := NewCustomRecords(customRecordsTlvMap)
if err != nil {
return nil, nil, nil, err
}
// Encode the remaining records back into the extra data field. These
// records are not in the custom records TLV type range and do not
// have associated fields in the struct that produced the records.
extraData, err := NewExtraOpaqueData(extraDataTlvMap)
if err != nil {
return nil, nil, nil, err
}
// Help with unit testing where we might have the empty value (nil) for
// the extra data instead of the default that's returned by the
// constructor (empty slice).
if len(extraData) == 0 {
extraData = nil
}
return customRecords, parsedKnownRecords, extraData, nil
}
// MergeAndEncode merges the known records with the extra data and custom
// records, then encodes the merged records into raw bytes.
func MergeAndEncode(knownRecords []tlv.RecordProducer,
extraData ExtraOpaqueData, customRecords CustomRecords) ([]byte,
error) {
// Construct a slice of all the records that we should include in the
// message extra data field. We will start by including any records from
// the extra data field.
mergedRecords, err := extraData.RecordProducers()
if err != nil {
return nil, err
}
// Merge the known and extra data records.
mergedRecords = append(mergedRecords, knownRecords...)
// Include custom records in the extra data wire field if they are
// present. Ensure that the custom records are validated before encoding
// them.
if err := customRecords.Validate(); err != nil {
return nil, fmt.Errorf("custom records validation error: %w",
err)
}
// Extend the message extra data records slice with TLV records from the
// custom records field.
mergedRecords = append(
mergedRecords, customRecords.RecordProducers()...,
)
// Now we can sort the records and make sure there are no records with
// the same type that would collide when encoding.
sortedRecords := ProduceRecordsSorted(mergedRecords...)
if err := AssertUniqueTypes(sortedRecords); err != nil {
return nil, err
}
return EncodeRecords(sortedRecords)
}
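// Illustrative usage sketch (added for exposition; not part of the
// original source). It round trips through MergeAndEncode and
// ParseAndExtractCustomRecords with no known records, showing that
// custom-range types land back in CustomRecords while anything below
// the custom range would stay in ExtraOpaqueData; the type 65540 is
// arbitrary.
func exampleMergeAndParse() error {
	custom := CustomRecords{65540: []byte{0xaa}}

	encoded, err := MergeAndEncode(nil, nil, custom)
	if err != nil {
		return err
	}

	customOut, _, extraOut, err := ParseAndExtractCustomRecords(encoded)
	if err != nil {
		return err
	}

	fmt.Printf("custom: %v, extra: %v\n", customOut, extraOut)
	return nil
}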
package lnwire
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// ErrFeaturePairExists signals an error in feature vector construction
// where the opposing bit in a feature pair has already been set.
ErrFeaturePairExists = errors.New("feature pair exists")
// ErrFeatureStandard is returned when attempts to modify LND's known
// set of features are made.
ErrFeatureStandard = errors.New("feature is used in standard " +
"protocol set")
// ErrFeatureBitMaximum is returned when a feature bit exceeds the
// maximum allowable value.
ErrFeatureBitMaximum = errors.New("feature bit exceeds allowed maximum")
)
// FeatureBit represents a feature that can be enabled in either a local or
// global feature vector at a specific bit position. Feature bits follow the
// "it's OK to be odd" rule, where features at even bit positions must be known
// to a node receiving them from a peer while odd bits do not. In accordance,
// feature bits are usually assigned in pairs, first being assigned an odd bit
// position which may later be changed to the preceding even position once
// knowledge of the feature becomes required on the network.
type FeatureBit uint16
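// Illustrative helper sketch (added for exposition; not part of the
// original source). Under the "it's OK to be odd" rule described above,
// even bits are required and odd bits are optional, so requiredness can
// be derived from the bit's parity.
func isRequiredFeature(bit FeatureBit) bool {
	return bit%2 == 0
}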
const (
// DataLossProtectRequired is a feature bit that indicates that a peer
// *requires* the other party know about the data-loss-protect optional
// feature. If the remote peer does not know of such a feature, then
// the sending peer SHOULD disconnect them. The data-loss-protect
// feature allows a peer that's lost partial data to recover their
// settled funds of the latest commitment state.
DataLossProtectRequired FeatureBit = 0
// DataLossProtectOptional is an optional feature bit that indicates
// that the sending peer knows of this new feature and can activate
// it. The data-loss-protect feature allows a peer that's lost partial
// data to recover their settled funds of the latest commitment state.
DataLossProtectOptional FeatureBit = 1
// InitialRoutingSync is a local feature bit meaning that the receiving
// node should send a complete dump of routing information when a new
// connection is established.
InitialRoutingSync FeatureBit = 3
// UpfrontShutdownScriptRequired is a feature bit which indicates that a
// peer *requires* that the remote peer accept an upfront shutdown script to
// which payout is enforced on cooperative closes.
UpfrontShutdownScriptRequired FeatureBit = 4
// UpfrontShutdownScriptOptional is an optional feature bit which indicates
// that the peer will accept an upfront shutdown script to which payout is
// enforced on cooperative closes.
UpfrontShutdownScriptOptional FeatureBit = 5
// GossipQueriesRequired is a feature bit that indicates that the
// receiving peer MUST know of the set of features that allows nodes to
// more efficiently query the network view of peers on the network for
// reconciliation purposes.
GossipQueriesRequired FeatureBit = 6
// GossipQueriesOptional is an optional feature bit that signals that
// the setting peer knows of the set of features that allows more
// efficient network view reconciliation.
GossipQueriesOptional FeatureBit = 7
// TLVOnionPayloadRequired is a feature bit that indicates a node is
// able to decode the new TLV information included in the onion packet.
TLVOnionPayloadRequired FeatureBit = 8
// TLVOnionPayloadOptional is an optional feature bit that indicates a
// node is able to decode the new TLV information included in the onion
// packet.
TLVOnionPayloadOptional FeatureBit = 9
// StaticRemoteKeyRequired is a required feature bit that signals that
// within one's commitment transaction, the key used for the remote
// party's non-delay output should not be tweaked.
StaticRemoteKeyRequired FeatureBit = 12
// StaticRemoteKeyOptional is an optional feature bit that signals that
// within one's commitment transaction, the key used for the remote
// party's non-delay output should not be tweaked.
StaticRemoteKeyOptional FeatureBit = 13
// PaymentAddrRequired is a required feature bit that signals that a
// node requires payment addresses, which are used to mitigate probing
// attacks on the receiver of a payment.
PaymentAddrRequired FeatureBit = 14
// PaymentAddrOptional is an optional feature bit that signals that a
// node supports payment addresses, which are used to mitigate probing
// attacks on the receiver of a payment.
PaymentAddrOptional FeatureBit = 15
// MPPRequired is a required feature bit that signals that the receiver
// of a payment requires settlement of an invoice with more than one
// HTLC.
MPPRequired FeatureBit = 16
// MPPOptional is an optional feature bit that signals that the receiver
// of a payment supports settlement of an invoice with more than one
// HTLC.
MPPOptional FeatureBit = 17
// WumboChannelsRequired is a required feature bit that signals that a
// node is willing to accept channels larger than 2^24 satoshis.
WumboChannelsRequired FeatureBit = 18
// WumboChannelsOptional is an optional feature bit that signals that a
// node is willing to accept channels larger than 2^24 satoshis.
WumboChannelsOptional FeatureBit = 19
// AnchorsRequired is a required feature bit that signals that the node
// requires channels to be made using commitments having anchor
// outputs.
AnchorsRequired FeatureBit = 20
// AnchorsOptional is an optional feature bit that signals that the
// node supports channels to be made using commitments having anchor
// outputs.
AnchorsOptional FeatureBit = 21
// AnchorsZeroFeeHtlcTxRequired is a required feature bit that signals
// that the node requires channels having zero-fee second-level HTLC
// transactions, which also imply anchor commitments.
AnchorsZeroFeeHtlcTxRequired FeatureBit = 22
// AnchorsZeroFeeHtlcTxOptional is an optional feature bit that signals
// that the node supports channels having zero-fee second-level HTLC
// transactions, which also imply anchor commitments.
AnchorsZeroFeeHtlcTxOptional FeatureBit = 23
// RouteBlindingRequired is a required feature bit that signals that
// the node supports blinded payments.
RouteBlindingRequired FeatureBit = 24
// RouteBlindingOptional is an optional feature bit that signals that
// the node supports blinded payments.
RouteBlindingOptional FeatureBit = 25
// ShutdownAnySegwitRequired is a required feature bit that signals
// that the sender is able to properly handle/parse segwit witness
// programs up to version 16. This enables utilization of Taproot
// addresses for cooperative closure addresses.
ShutdownAnySegwitRequired FeatureBit = 26
// ShutdownAnySegwitOptional is an optional feature bit that signals
// that the sender is able to properly handle/parse segwit witness
// programs up to version 16. This enables utilization of Taproot
// addresses for cooperative closure addresses.
ShutdownAnySegwitOptional FeatureBit = 27
// AMPRequired is a required feature bit that signals that the receiver
// of a payment accepts spontaneous payments, i.e.
// sender-generated preimages according to BOLT XX.
AMPRequired FeatureBit = 30
// AMPOptional is an optional feature bit that signals that the receiver
// of a payment accepts spontaneous payments, i.e.
// sender-generated preimages according to BOLT XX.
AMPOptional FeatureBit = 31
// QuiescenceRequired is a required feature bit that denotes that a
// connection established with this node must support the quiescence
// protocol if it wants to have a channel relationship.
QuiescenceRequired FeatureBit = 34
// QuiescenceOptional is an optional feature bit that denotes that a
// connection established with this node is permitted to use the
// quiescence protocol.
QuiescenceOptional FeatureBit = 35
// ExplicitChannelTypeRequired is a required bit that denotes that a
// connection established with this node is to use explicit channel
// commitment types for negotiation instead of the existing implicit
// negotiation methods. With this bit, there is no longer a "default"
// implicit channel commitment type, allowing a connection to
// open/maintain types of several channels over its lifetime.
ExplicitChannelTypeRequired = 44
// ExplicitChannelTypeOptional is an optional bit that denotes that a
// connection established with this node is to use explicit channel
// commitment types for negotiation instead of the existing implicit
// negotiation methods. With this bit, there is no longer a "default"
// implicit channel commitment type, allowing a connection to
// open/maintain types of several channels over its lifetime.
ExplicitChannelTypeOptional = 45
// ScidAliasRequired is a required feature bit that signals that the
// node requires understanding of ShortChannelID aliases in the TLV
// segment of the channel_ready message.
ScidAliasRequired FeatureBit = 46
// ScidAliasOptional is an optional feature bit that signals that the
// node understands ShortChannelID aliases in the TLV segment of the
// channel_ready message.
ScidAliasOptional FeatureBit = 47
// PaymentMetadataRequired is a required bit that denotes that if an
// invoice contains metadata, it must be passed along with the payment
// htlc(s).
PaymentMetadataRequired = 48
// PaymentMetadataOptional is an optional bit that denotes that if an
// invoice contains metadata, it may be passed along with the payment
// htlc(s).
PaymentMetadataOptional = 49
// ZeroConfRequired is a required feature bit that signals that the
// node requires understanding of the zero-conf channel_type.
ZeroConfRequired FeatureBit = 50
// ZeroConfOptional is an optional feature bit that signals that the
// node understands the zero-conf channel type.
ZeroConfOptional FeatureBit = 51
// KeysendRequired is a required bit that indicates that the node is
// able and willing to accept keysend payments.
KeysendRequired = 54
// KeysendOptional is an optional bit that indicates that the node is
// able and willing to accept keysend payments.
KeysendOptional = 55
// RbfCoopCloseRequired is a required feature bit that signals that
// the new RBF-based co-op close protocol is supported.
RbfCoopCloseRequired = 60
// RbfCoopCloseOptional is an optional feature bit that signals that the
// new RBF-based co-op close protocol is supported.
RbfCoopCloseOptional = 61
// RbfCoopCloseRequiredStaging is a required feature bit that signals
// that the new RBF-based co-op close protocol is supported. This is the
// staging variant of the bit, used in the wild while the feature is
// being finalized.
RbfCoopCloseRequiredStaging = 160
// RbfCoopCloseOptionalStaging is an optional feature bit that signals
// that the new RBF-based co-op close protocol is supported. This is the
// staging variant of the bit, used in the wild while the feature is
// being finalized.
RbfCoopCloseOptionalStaging = 161
// ScriptEnforcedLeaseRequired is a required feature bit that signals
// that the node requires channels having zero-fee second-level HTLC
// transactions, which also imply anchor commitments, along with an
// additional CLTV constraint of a channel lease's expiration height
// applied to all outputs that pay directly to the channel initiator.
//
// TODO: Decide on actual feature bit value.
ScriptEnforcedLeaseRequired FeatureBit = 2022
// ScriptEnforcedLeaseOptional is an optional feature bit that signals
// that the node supports channels having zero-fee second-level HTLC
// transactions, which also imply anchor commitments, along with an
// additional CLTV constraint of a channel lease's expiration height
// applied to all outputs that pay directly to the channel initiator.
//
// TODO: Decide on actual feature bit value.
ScriptEnforcedLeaseOptional FeatureBit = 2023
// SimpleTaprootChannelsRequiredFinal is a required bit that indicates
// the node is able to create taproot-native channels. This is the
// final feature bit to be used once the channel type is finalized.
SimpleTaprootChannelsRequiredFinal = 80
// SimpleTaprootChannelsOptionalFinal is an optional bit that indicates
// the node is able to create taproot-native channels. This is the
// final feature bit to be used once the channel type is finalized.
SimpleTaprootChannelsOptionalFinal = 81
// SimpleTaprootChannelsRequiredStaging is a required bit that indicates
// the node is able to create taproot-native channels. This is a
// feature bit used in the wild while the channel type is still being
// finalized.
SimpleTaprootChannelsRequiredStaging = 180
// SimpleTaprootChannelsOptionalStaging is an optional bit that
// indicates the node is able to create taproot-native channels. This
// is a feature bit used in the wild while the channel type is still
// being finalized.
SimpleTaprootChannelsOptionalStaging = 181
// ExperimentalEndorsementRequired is a required feature bit that
// indicates that the node will relay experimental endorsement signals.
ExperimentalEndorsementRequired FeatureBit = 260
// ExperimentalEndorsementOptional is an optional feature bit that
// indicates that the node will relay experimental endorsement signals.
ExperimentalEndorsementOptional FeatureBit = 261
// Bolt11BlindedPathsRequired is a required feature bit that indicates
// that the node is able to understand the blinded path tagged field in
// a BOLT 11 invoice.
Bolt11BlindedPathsRequired = 262
// Bolt11BlindedPathsOptional is an optional feature bit that indicates
// that the node is able to understand the blinded path tagged field in
// a BOLT 11 invoice.
Bolt11BlindedPathsOptional = 263
// SimpleTaprootOverlayChansOptional is an optional bit that indicates
// support for the special custom taproot overlay channel.
SimpleTaprootOverlayChansOptional = 2025
// SimpleTaprootOverlayChansRequired is a required bit that indicates
// support for the special custom taproot overlay channel.
SimpleTaprootOverlayChansRequired = 2026
// MaxBolt11Feature is the maximum feature bit value allowed in bolt 11
// invoices.
//
// The base 32 encoded tagged fields in invoices are limited to 10 bits
// to express the length of the field's data.
//nolint:ll
// See: https://github.com/lightning/bolts/blob/master/11-payment-encoding.md#tagged-fields
//
// With a maximum length field of 1023 (2^10 - 1) and 5 bit encoding,
// the highest feature bit that can be expressed is:
// 1023 * 5 - 1 = 5114.
MaxBolt11Feature = 5114
)
// IsRequired returns true if the feature bit is even, and false otherwise.
func (b FeatureBit) IsRequired() bool {
return b&0x01 == 0x00
}
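// To make the even/odd pairing convention concrete: for each feature pair
// defined above, the even bit is the required variant and the odd bit is the
// optional one. A minimal sketch:
//
//	PaymentAddrRequired.IsRequired() // true: bit 14 is even.
//	PaymentAddrOptional.IsRequired() // false: bit 15 is odd.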
// Features is a mapping of known feature bits to a descriptive name. All known
// feature bits must be assigned a name in this mapping, and feature bit pairs
// must be assigned together for correct behavior.
var Features = map[FeatureBit]string{
DataLossProtectRequired: "data-loss-protect",
DataLossProtectOptional: "data-loss-protect",
InitialRoutingSync: "initial-routing-sync",
UpfrontShutdownScriptRequired: "upfront-shutdown-script",
UpfrontShutdownScriptOptional: "upfront-shutdown-script",
GossipQueriesRequired: "gossip-queries",
GossipQueriesOptional: "gossip-queries",
TLVOnionPayloadRequired: "tlv-onion",
TLVOnionPayloadOptional: "tlv-onion",
StaticRemoteKeyOptional: "static-remote-key",
StaticRemoteKeyRequired: "static-remote-key",
PaymentAddrOptional: "payment-addr",
PaymentAddrRequired: "payment-addr",
MPPOptional: "multi-path-payments",
MPPRequired: "multi-path-payments",
AnchorsRequired: "anchor-commitments",
AnchorsOptional: "anchor-commitments",
AnchorsZeroFeeHtlcTxRequired: "anchors-zero-fee-htlc-tx",
AnchorsZeroFeeHtlcTxOptional: "anchors-zero-fee-htlc-tx",
WumboChannelsRequired: "wumbo-channels",
WumboChannelsOptional: "wumbo-channels",
AMPRequired: "amp",
AMPOptional: "amp",
QuiescenceRequired: "quiescence",
QuiescenceOptional: "quiescence",
PaymentMetadataOptional: "payment-metadata",
PaymentMetadataRequired: "payment-metadata",
ExplicitChannelTypeOptional: "explicit-commitment-type",
ExplicitChannelTypeRequired: "explicit-commitment-type",
KeysendOptional: "keysend",
KeysendRequired: "keysend",
ScriptEnforcedLeaseRequired: "script-enforced-lease",
ScriptEnforcedLeaseOptional: "script-enforced-lease",
ScidAliasRequired: "scid-alias",
ScidAliasOptional: "scid-alias",
ZeroConfRequired: "zero-conf",
ZeroConfOptional: "zero-conf",
RouteBlindingRequired: "route-blinding",
RouteBlindingOptional: "route-blinding",
ShutdownAnySegwitRequired: "shutdown-any-segwit",
ShutdownAnySegwitOptional: "shutdown-any-segwit",
SimpleTaprootChannelsRequiredFinal: "simple-taproot-chans",
SimpleTaprootChannelsOptionalFinal: "simple-taproot-chans",
SimpleTaprootChannelsRequiredStaging: "simple-taproot-chans-x",
SimpleTaprootChannelsOptionalStaging: "simple-taproot-chans-x",
SimpleTaprootOverlayChansOptional: "taproot-overlay-chans",
SimpleTaprootOverlayChansRequired: "taproot-overlay-chans",
ExperimentalEndorsementRequired: "endorsement-x",
ExperimentalEndorsementOptional: "endorsement-x",
Bolt11BlindedPathsOptional: "bolt-11-blinded-paths",
Bolt11BlindedPathsRequired: "bolt-11-blinded-paths",
RbfCoopCloseOptional: "rbf-coop-close",
RbfCoopCloseRequired: "rbf-coop-close",
RbfCoopCloseOptionalStaging: "rbf-coop-close-x",
RbfCoopCloseRequiredStaging: "rbf-coop-close-x",
}
// RawFeatureVector represents a set of feature bits as defined in BOLT-09. A
// RawFeatureVector itself just stores a set of bit flags but can be used to
// construct a FeatureVector which binds meaning to each bit. Feature vectors
// can be serialized and deserialized to/from a byte representation that is
// transmitted in Lightning network messages.
type RawFeatureVector struct {
features map[FeatureBit]struct{}
}
// NewRawFeatureVector creates a feature vector with all of the feature bits
// given as arguments enabled.
func NewRawFeatureVector(bits ...FeatureBit) *RawFeatureVector {
fv := &RawFeatureVector{features: make(map[FeatureBit]struct{})}
for _, bit := range bits {
fv.Set(bit)
}
return fv
}
// IsEmpty returns true if the feature vector does not contain any feature
// bits.
func (fv RawFeatureVector) IsEmpty() bool {
return len(fv.features) == 0
}
// OnlyContains determines whether only the specified feature bits are found.
func (fv RawFeatureVector) OnlyContains(bits ...FeatureBit) bool {
if len(bits) != len(fv.features) {
return false
}
for _, bit := range bits {
if !fv.IsSet(bit) {
return false
}
}
return true
}
// Equals determines whether two feature vectors contain exactly the same
// features.
func (fv RawFeatureVector) Equals(other *RawFeatureVector) bool {
if len(fv.features) != len(other.features) {
return false
}
for bit := range fv.features {
if _, ok := other.features[bit]; !ok {
return false
}
}
return true
}
// Merge sets all feature bits in other on the receiver's feature vector.
func (fv *RawFeatureVector) Merge(other *RawFeatureVector) error {
for bit := range other.features {
err := fv.SafeSet(bit)
if err != nil {
return err
}
}
return nil
}
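// A minimal sketch of Merge's conflict handling: because Merge delegates to
// SafeSet, pulling in the opposing half of a pair that is already set fails
// rather than silently producing a vector that is both optional and required.
//
//	a := NewRawFeatureVector(PaymentAddrOptional)
//	b := NewRawFeatureVector(PaymentAddrRequired)
//	err := a.Merge(b) // returns ErrFeaturePairExists.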
// ValidateUpdate checks whether a feature vector can safely be updated to the
// new feature vector provided, checking that it does not alter any of the
// "standard" features that are defined by LND. The new feature vector should
// be inclusive of all features in the original vector that it still wants to
// advertise, setting and unsetting updates as desired. Features in the vector
// are also checked against a maximum inclusive value, as feature vectors in
// different contexts have different maximum values.
func (fv *RawFeatureVector) ValidateUpdate(other *RawFeatureVector,
maximumValue FeatureBit) error {
// Run through the new set of features and check that we're not adding
// any feature bits that are defined but not set in LND.
for feature := range other.features {
if fv.IsSet(feature) {
continue
}
if feature > maximumValue {
return fmt.Errorf("can't set feature bit %d: %w %v",
feature, ErrFeatureBitMaximum,
maximumValue)
}
if name, known := Features[feature]; known {
return fmt.Errorf("can't set feature "+
"bit %d (%v): %w", feature, name,
ErrFeatureStandard)
}
}
// Check that the new feature vector for this set does not unset any
// features that are standard in LND by comparing the features in our
// current set to the omitted values in the new set.
for feature := range fv.features {
if other.IsSet(feature) {
continue
}
if name, known := Features[feature]; known {
return fmt.Errorf("can't unset feature "+
"bit %d (%v): %w", feature, name,
ErrFeatureStandard)
}
}
return nil
}
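// A minimal sketch of ValidateUpdate, assuming a caller that layers a custom
// odd feature bit (2121 here is purely hypothetical) on top of the standard
// set:
//
//	current := NewRawFeatureVector(PaymentAddrOptional)
//	updated := current.Clone()
//	updated.Set(FeatureBit(2121)) // hypothetical custom bit, unknown to LND.
//
//	// Allowed: the standard bit is preserved and 2121 is not standard.
//	err := current.ValidateUpdate(updated, MaxBolt11Feature)
//
//	// Rejected with ErrFeatureStandard: the empty update unsets payment-addr.
//	err = current.ValidateUpdate(NewRawFeatureVector(), MaxBolt11Feature)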
// ValidatePairs checks each feature bit in a raw vector to ensure that the
// opposing bit is not set, validating that the vector has either the optional
// or required bit set, not both.
func (fv *RawFeatureVector) ValidatePairs() error {
for feature := range fv.features {
if _, ok := fv.features[feature^1]; ok {
return ErrFeaturePairExists
}
}
return nil
}
// Clone makes a copy of a feature vector.
func (fv *RawFeatureVector) Clone() *RawFeatureVector {
newFeatures := NewRawFeatureVector()
for bit := range fv.features {
newFeatures.Set(bit)
}
return newFeatures
}
// IsSet returns whether a particular feature bit is enabled in the vector.
func (fv *RawFeatureVector) IsSet(feature FeatureBit) bool {
_, ok := fv.features[feature]
return ok
}
// Set marks a feature as enabled in the vector.
func (fv *RawFeatureVector) Set(feature FeatureBit) {
fv.features[feature] = struct{}{}
}
// SafeSet sets the chosen feature bit in the feature vector, but returns an
// error if the opposing feature bit is already set. This ensures both that we
// are creating properly structured feature vectors, and in some cases, that
// peers are sending properly encoded ones, i.e. it can't be both optional and
// required.
func (fv *RawFeatureVector) SafeSet(feature FeatureBit) error {
if _, ok := fv.features[feature^1]; ok {
return ErrFeaturePairExists
}
fv.Set(feature)
return nil
}
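// For example, a vector that already advertises the optional half of a pair
// rejects the required half:
//
//	fv := NewRawFeatureVector(MPPOptional)
//	err := fv.SafeSet(MPPRequired) // returns ErrFeaturePairExists.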
// Unset marks a feature as disabled in the vector.
func (fv *RawFeatureVector) Unset(feature FeatureBit) {
delete(fv.features, feature)
}
// SerializeSize returns the number of bytes needed to represent the feature
// vector in byte format.
func (fv *RawFeatureVector) SerializeSize() int {
// We calculate byte-length via the largest bit index.
return fv.serializeSize(8)
}
// SerializeSize32 returns the number of bytes needed to represent the
// feature vector in base32 format.
func (fv *RawFeatureVector) SerializeSize32() int {
// We calculate base32-length via the largest bit index.
return fv.serializeSize(5)
}
// serializeSize returns the number of bytes required to encode the feature
// vector using at most width bits per encoded byte.
func (fv *RawFeatureVector) serializeSize(width int) int {
// Find the largest feature bit index
max := -1
for feature := range fv.features {
index := int(feature)
if index > max {
max = index
}
}
if max == -1 {
return 0
}
return max/width + 1
}
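// As a worked example of the computation above: if the highest set bit is
// feature 15, byte encoding (width=8) needs 15/8 + 1 = 2 bytes, while base32
// encoding (width=5) needs 15/5 + 1 = 4 five-bit groups.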
// Encode writes the feature vector in byte representation. Every feature is
// encoded as a bit, and the bit vector is serialized using the least number of
// bytes. Since the bit vector length is variable, the first two bytes of the
// serialization represent the length.
func (fv *RawFeatureVector) Encode(w io.Writer) error {
// Write length of feature vector.
var l [2]byte
length := fv.SerializeSize()
binary.BigEndian.PutUint16(l[:], uint16(length))
if _, err := w.Write(l[:]); err != nil {
return err
}
return fv.encode(w, length, 8)
}
// EncodeBase256 writes the feature vector in base256 representation. Every
// feature is encoded as a bit, and the bit vector is serialized using the least
// number of bytes.
func (fv *RawFeatureVector) EncodeBase256(w io.Writer) error {
length := fv.SerializeSize()
return fv.encode(w, length, 8)
}
// EncodeBase32 writes the feature vector in base32 representation. Every feature
// is encoded as a bit, and the bit vector is serialized using the least number of
// bytes.
func (fv *RawFeatureVector) EncodeBase32(w io.Writer) error {
length := fv.SerializeSize32()
return fv.encode(w, length, 5)
}
// encode writes the feature vector to w using the given length in bytes,
// with width feature bits encoded per byte.
func (fv *RawFeatureVector) encode(w io.Writer, length, width int) error {
// Generate the data and write it.
data := make([]byte, length)
for feature := range fv.features {
byteIndex := int(feature) / width
bitIndex := int(feature) % width
data[length-byteIndex-1] |= 1 << uint(bitIndex)
}
_, err := w.Write(data)
return err
}
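// A worked example of the layout: encoding a vector whose only set bit is
// feature 9 with width=8 produces two bytes, [0x02, 0x00]. byteIndex is
// 9/8 = 1 and bitIndex is 9%8 = 1, so bit 1 of data[2-1-1] = data[0] is set;
// feature 0 would land in the lowest bit of the final byte. The vector is
// therefore big-endian at the byte level.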
// Decode reads the feature vector from its byte representation. Every feature
// is encoded as a bit, and the bit vector is serialized using the least number
// of bytes. Since the bit vector length is variable, the first two bytes of the
// serialization represent the length.
func (fv *RawFeatureVector) Decode(r io.Reader) error {
// Read the length of the feature vector.
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
length := binary.BigEndian.Uint16(l[:])
return fv.decode(r, int(length), 8)
}
// DecodeBase256 reads the feature vector from its base256 representation. Every
// feature is encoded as a bit, and the bit vector is serialized using the least
// number of bytes.
func (fv *RawFeatureVector) DecodeBase256(r io.Reader, length int) error {
return fv.decode(r, length, 8)
}
// DecodeBase32 reads the feature vector from its base32 representation. Every
// feature is encoded as a bit, and the bit vector is serialized using the least
// number of bytes.
func (fv *RawFeatureVector) DecodeBase32(r io.Reader, length int) error {
return fv.decode(r, length, 5)
}
// decode reads a feature vector from the next length bytes of the io.Reader,
// assuming width feature bits are encoded in each byte.
func (fv *RawFeatureVector) decode(r io.Reader, length, width int) error {
// Read the feature vector data.
data := make([]byte, length)
if _, err := io.ReadFull(r, data); err != nil {
return err
}
// Set feature bits from parsed data.
bitsNumber := len(data) * width
for i := 0; i < bitsNumber; i++ {
byteIndex := int(i / width)
bitIndex := uint(i % width)
if (data[length-byteIndex-1]>>bitIndex)&1 == 1 {
fv.Set(FeatureBit(i))
}
}
return nil
}
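// A minimal round-trip sketch pairing Encode with Decode (errors elided for
// brevity):
//
//	var buf bytes.Buffer
//	in := NewRawFeatureVector(TLVOnionPayloadOptional, PaymentAddrOptional)
//	_ = in.Encode(&buf)
//
//	out := NewRawFeatureVector()
//	_ = out.Decode(&buf)
//	// out.Equals(in) is now true.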
// sizeFunc returns the length required to encode the feature vector.
func (fv *RawFeatureVector) sizeFunc() uint64 {
return uint64(fv.SerializeSize())
}
// Record returns a TLV record that can be used to encode/decode raw feature
// vectors. Note that the length of the feature vector is not included, because
// it is covered by the TLV record's length field.
func (fv *RawFeatureVector) Record() tlv.Record {
return tlv.MakeDynamicRecord(
0, fv, fv.sizeFunc, rawFeatureEncoder, rawFeatureDecoder,
)
}
// rawFeatureEncoder is a custom TLV encoder for raw feature vectors.
func rawFeatureEncoder(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*RawFeatureVector); ok {
// Encode the feature bits as a byte slice without its length
// prepended, as that's already taken care of by the TLV record.
fv := *v
return fv.encode(w, fv.SerializeSize(), 8)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.RawFeatureVector")
}
// rawFeatureDecoder is a custom TLV decoder for raw feature vectors.
func rawFeatureDecoder(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*RawFeatureVector); ok {
fv := NewRawFeatureVector()
if err := fv.decode(r, int(l), 8); err != nil {
return err
}
*v = *fv
return nil
}
return tlv.NewTypeForEncodingErr(val, "lnwire.RawFeatureVector")
}
// FeatureVector represents a set of enabled features. The set stores
// information on enabled flags and metadata about the feature names. A feature
// vector is serializable to a compact byte representation that is included in
// Lightning network messages.
type FeatureVector struct {
*RawFeatureVector
featureNames map[FeatureBit]string
}
// NewFeatureVector constructs a new FeatureVector from a raw feature vector
// and mapping of feature definitions. If the feature vector argument is nil, a
// new one will be constructed with no enabled features.
func NewFeatureVector(featureVector *RawFeatureVector,
featureNames map[FeatureBit]string) *FeatureVector {
if featureVector == nil {
featureVector = NewRawFeatureVector()
}
return &FeatureVector{
RawFeatureVector: featureVector,
featureNames: featureNames,
}
}
// EmptyFeatureVector returns a feature vector with no bits set.
func EmptyFeatureVector() *FeatureVector {
return NewFeatureVector(nil, Features)
}
// Record implements the RecordProducer interface for FeatureVector. Note that
// a zero-value type is used to produce the record, as we expect this
// type value to be overwritten when used in generic TLV record production.
// This allows a single Record function to serve in the many different contexts
// in which feature vectors are encoded. This record wraps the encoding/
// decoding for our raw feature vectors so that we can directly parse fully
// formed feature vector types.
func (fv *FeatureVector) Record() tlv.Record {
return tlv.MakeDynamicRecord(0, fv, fv.sizeFunc,
func(w io.Writer, val interface{}, buf *[8]byte) error {
if f, ok := val.(*FeatureVector); ok {
return rawFeatureEncoder(
w, f.RawFeatureVector, buf,
)
}
return tlv.NewTypeForEncodingErr(
val, "*lnwire.FeatureVector",
)
},
func(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if f, ok := val.(*FeatureVector); ok {
features := NewFeatureVector(nil, Features)
err := rawFeatureDecoder(
r, features.RawFeatureVector, buf, l,
)
if err != nil {
return err
}
*f = *features
return nil
}
return tlv.NewTypeForDecodingErr(
val, "*lnwire.FeatureVector", l, l,
)
},
)
}
// HasFeature returns whether a particular feature is included in the set. The
// feature can be seen as set either if the bit is set directly OR the queried
// bit has the same meaning as its corresponding even/odd bit, which is set
// instead. The second case is because feature bits are generally assigned in
// pairs where both the even and odd position represent the same feature.
func (fv *FeatureVector) HasFeature(feature FeatureBit) bool {
return fv.IsSet(feature) ||
(fv.isFeatureBitPair(feature) && fv.IsSet(feature^1))
}
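// For example, a vector that only advertises the required half of a known
// pair still reports the optional half as present:
//
//	fv := NewFeatureVector(NewRawFeatureVector(MPPRequired), Features)
//	fv.HasFeature(MPPOptional) // true: bits 16 and 17 name the same feature.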
// RequiresFeature returns true if the referenced feature vector *requires*
// that the given required bit be set. This method can be used with both
// optional and required feature bits as a parameter.
func (fv *FeatureVector) RequiresFeature(feature FeatureBit) bool {
// If we weren't passed a required feature bit, then we'll flip the
// lowest bit to query for the required version of the feature. This
// lets callers pass in both the optional and required bits.
if !feature.IsRequired() {
feature ^= 1
}
return fv.IsSet(feature)
}
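// Continuing the sketch above (fv advertises only MPPRequired), both forms of
// the query agree:
//
//	fv.RequiresFeature(MPPRequired) // true: bit 16 is set.
//	fv.RequiresFeature(MPPOptional) // also true: 17^1 == 16, which is set.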
// UnknownRequiredFeatures returns a list of feature bits set in the vector
// that are unknown and in an even bit position. Feature bits with an even
// index must be known to a node receiving the feature vector in a message.
func (fv *FeatureVector) UnknownRequiredFeatures() []FeatureBit {
var unknown []FeatureBit
for feature := range fv.features {
if feature%2 == 0 && !fv.IsKnown(feature) {
unknown = append(unknown, feature)
}
}
return unknown
}
// UnknownFeatures returns true if the feature vector contains *any* unknown
// features (even if they are odd).
func (fv *FeatureVector) UnknownFeatures() bool {
for feature := range fv.features {
if !fv.IsKnown(feature) {
return true
}
}
return false
}
// Name returns a string identifier for the feature represented by this bit. If
// the bit does not represent a known feature, this returns a string indicating
// as such.
func (fv *FeatureVector) Name(bit FeatureBit) string {
name, known := fv.featureNames[bit]
if !known {
return "unknown"
}
return name
}
// IsKnown returns whether this feature bit represents a known feature.
func (fv *FeatureVector) IsKnown(bit FeatureBit) bool {
_, known := fv.featureNames[bit]
return known
}
// isFeatureBitPair returns whether this feature bit and its corresponding
// even/odd bit both represent the same feature. This may often be the case as
// bits are generally assigned in pairs, first being assigned an odd bit
// position then being promoted to an even bit position once the network is
// ready.
func (fv *FeatureVector) isFeatureBitPair(bit FeatureBit) bool {
name1, known1 := fv.featureNames[bit]
name2, known2 := fv.featureNames[bit^1]
return known1 && known2 && name1 == name2
}
// Features returns the set of raw features contained in the feature vector.
func (fv *FeatureVector) Features() map[FeatureBit]struct{} {
fs := make(map[FeatureBit]struct{}, len(fv.RawFeatureVector.features))
for b := range fv.RawFeatureVector.features {
fs[b] = struct{}{}
}
return fs
}
// Clone copies a feature vector, carrying over its feature bits. The feature
// names are not copied.
func (fv *FeatureVector) Clone() *FeatureVector {
features := fv.RawFeatureVector.Clone()
return NewFeatureVector(features, fv.featureNames)
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/tlv"
)
// FundingCreated is sent from Alice (the initiator) to Bob (the responder),
// once Alice receives Bob's contributions as well as his channel constraints.
// Once Bob receives this message, he'll gain access to an immediately
// broadcastable commitment transaction and will reply with a signature for
// Alice's version of the commitment transaction.
type FundingCreated struct {
// PendingChannelID serves to uniquely identify the future channel
// created by the initiated single funder workflow.
PendingChannelID [32]byte
// FundingPoint is the outpoint of the funding transaction created by
// Alice. With this, Bob is able to generate both his version and
// Alice's version of the commitment transaction.
FundingPoint wire.OutPoint
// CommitSig is Alice's signature for Bob's version of the commitment
// transaction.
CommitSig Sig
// PartialSig is used to transmit a musig2 extended partial signature
// that also carries along the public nonce of the signer.
//
// NOTE: This field is only populated if a musig2 taproot channel is
// being signed for. In this case, the above Sig type MUST be blank.
PartialSig OptPartialSigWithNonceTLV
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure FundingCreated implements the lnwire.Message
// interface.
var _ Message = (*FundingCreated)(nil)
// A compile time check to ensure FundingCreated implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*FundingCreated)(nil)
// Encode serializes the target FundingCreated into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
//
// This is part of the lnwire.Message interface.
func (f *FundingCreated) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1)
f.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) {
recordProducers = append(recordProducers, &sig)
})
err := EncodeMessageExtraData(&f.ExtraData, recordProducers...)
if err != nil {
return err
}
if err := WriteBytes(w, f.PendingChannelID[:]); err != nil {
return err
}
if err := WriteOutPoint(w, f.FundingPoint); err != nil {
return err
}
if err := WriteSig(w, f.CommitSig); err != nil {
return err
}
return WriteBytes(w, f.ExtraData)
}
// Decode deserializes the serialized FundingCreated stored in the passed
// io.Reader into the target FundingCreated using the deserialization rules
// defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (f *FundingCreated) Decode(r io.Reader, pver uint32) error {
err := ReadElements(
r, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig,
)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
partialSig := f.PartialSig.Zero()
typeMap, err := tlvRecords.ExtractRecords(&partialSig)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[f.PartialSig.TlvType()]; ok && val == nil {
f.PartialSig = tlv.SomeRecordT(partialSig)
}
if len(tlvRecords) != 0 {
f.ExtraData = tlvRecords
}
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as a FundingCreated on the wire.
//
// This is part of the lnwire.Message interface.
func (f *FundingCreated) MsgType() MessageType {
return MsgFundingCreated
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (f *FundingCreated) SerializedSize() (uint32, error) {
return MessageSerializedSize(f)
}
package lnwire
import (
"bytes"
"io"
"github.com/lightningnetwork/lnd/tlv"
)
// FundingSigned is sent from Bob (the responder) to Alice (the initiator)
// after receiving the funding outpoint and her signature for Bob's version of
// the commitment transaction.
type FundingSigned struct {
// ChanID is the particular active channel that this
// FundingSigned is bound to.
ChanID ChannelID
// CommitSig is Bob's signature for Alice's version of the commitment
// transaction.
CommitSig Sig
// PartialSig is used to transmit a musig2 extended partial signature
// that also carries along the public nonce of the signer.
//
// NOTE: This field is only populated if a musig2 taproot channel is
// being signed for. In this case, the above Sig type MUST be blank.
PartialSig OptPartialSigWithNonceTLV
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure FundingSigned implements the lnwire.Message
// interface.
var _ Message = (*FundingSigned)(nil)
// A compile time check to ensure FundingSigned implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*FundingSigned)(nil)
// Encode serializes the target FundingSigned into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
//
// This is part of the lnwire.Message interface.
func (f *FundingSigned) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1)
f.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) {
recordProducers = append(recordProducers, &sig)
})
err := EncodeMessageExtraData(&f.ExtraData, recordProducers...)
if err != nil {
return err
}
if err := WriteChannelID(w, f.ChanID); err != nil {
return err
}
if err := WriteSig(w, f.CommitSig); err != nil {
return err
}
return WriteBytes(w, f.ExtraData)
}
// Decode deserializes the serialized FundingSigned stored in the passed
// io.Reader into the target FundingSigned using the deserialization rules
// defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (f *FundingSigned) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r, &f.ChanID, &f.CommitSig)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
partialSig := f.PartialSig.Zero()
typeMap, err := tlvRecords.ExtractRecords(&partialSig)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[f.PartialSig.TlvType()]; ok && val == nil {
f.PartialSig = tlv.SomeRecordT(partialSig)
}
if len(tlvRecords) != 0 {
f.ExtraData = tlvRecords
}
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as a FundingSigned on the wire.
//
// This is part of the lnwire.Message interface.
func (f *FundingSigned) MsgType() MessageType {
return MsgFundingSigned
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (f *FundingSigned) SerializedSize() (uint32, error) {
return MessageSerializedSize(f)
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/tlv"
)
// GossipTimestampRange is a message that allows the sender to restrict the set
// of future gossip announcements sent by the receiver. Nodes should send this
// if they have the gossip-queries feature bit active. Nodes are able to send
// new GossipTimestampRange messages to replace the prior window.
type GossipTimestampRange struct {
// ChainHash denotes the chain that the sender wishes to restrict the
// set of received announcements of.
ChainHash chainhash.Hash
// FirstTimestamp is the timestamp of the earliest announcement message
// that should be sent by the receiver. This is only to be used for
// querying messages of gossip 1.0, which are timestamped using Unix
// timestamps. FirstBlockHeight and BlockRange should be used to
// query for announcement messages timestamped using block heights.
FirstTimestamp uint32
// TimestampRange is the horizon beyond the FirstTimestamp that any
// announcement messages should be sent for. The receiving node MUST
// NOT send any announcements that have a timestamp greater than
// FirstTimestamp + TimestampRange. This is used together with
// FirstTimestamp to query for gossip 1.0 messages timestamped with
// Unix timestamps.
TimestampRange uint32
// FirstBlockHeight is the height of the earliest announcement message
// that should be sent by the receiver. This is used only for querying
// announcement messages that use block heights as a timestamp.
FirstBlockHeight tlv.OptionalRecordT[tlv.TlvType2, uint32]
// BlockRange is the horizon beyond FirstBlockHeight that any
// announcement messages should be sent for. The receiving node MUST NOT
// send any announcements that have a timestamp greater than
// FirstBlockHeight + BlockRange.
BlockRange tlv.OptionalRecordT[tlv.TlvType4, uint32]
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewGossipTimestampRange creates a new empty GossipTimestampRange message.
func NewGossipTimestampRange() *GossipTimestampRange {
return &GossipTimestampRange{}
}
// A compile time check to ensure GossipTimestampRange implements the
// lnwire.Message interface.
var _ Message = (*GossipTimestampRange)(nil)
// A compile time check to ensure GossipTimestampRange implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*GossipTimestampRange)(nil)
// Decode deserializes a serialized GossipTimestampRange message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (g *GossipTimestampRange) Decode(r io.Reader, _ uint32) error {
err := ReadElements(r,
g.ChainHash[:],
&g.FirstTimestamp,
&g.TimestampRange,
)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var (
firstBlock = tlv.ZeroRecordT[tlv.TlvType2, uint32]()
blockRange = tlv.ZeroRecordT[tlv.TlvType4, uint32]()
)
typeMap, err := tlvRecords.ExtractRecords(&firstBlock, &blockRange)
if err != nil {
return err
}
if val, ok := typeMap[g.FirstBlockHeight.TlvType()]; ok && val == nil {
g.FirstBlockHeight = tlv.SomeRecordT(firstBlock)
}
if val, ok := typeMap[g.BlockRange.TlvType()]; ok && val == nil {
g.BlockRange = tlv.SomeRecordT(blockRange)
}
if len(tlvRecords) != 0 {
g.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target GossipTimestampRange into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (g *GossipTimestampRange) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteBytes(w, g.ChainHash[:]); err != nil {
return err
}
if err := WriteUint32(w, g.FirstTimestamp); err != nil {
return err
}
if err := WriteUint32(w, g.TimestampRange); err != nil {
return err
}
recordProducers := make([]tlv.RecordProducer, 0, 2)
g.FirstBlockHeight.WhenSome(
func(height tlv.RecordT[tlv.TlvType2, uint32]) {
recordProducers = append(recordProducers, &height)
},
)
g.BlockRange.WhenSome(
func(blockRange tlv.RecordT[tlv.TlvType4, uint32]) {
recordProducers = append(recordProducers, &blockRange)
},
)
err := EncodeMessageExtraData(&g.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, g.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (g *GossipTimestampRange) MsgType() MessageType {
return MsgGossipTimestampRange
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (g *GossipTimestampRange) SerializedSize() (uint32, error) {
return MessageSerializedSize(g)
}
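// A minimal construction sketch for a block-height based range, reusing the
// optional-record helpers seen in Decode above (the Val field of tlv.RecordT
// is assumed here; block height 800_000 and a 144-block window are arbitrary):
//
//	g := NewGossipTimestampRange()
//
//	firstBlock := tlv.ZeroRecordT[tlv.TlvType2, uint32]()
//	firstBlock.Val = 800_000
//	g.FirstBlockHeight = tlv.SomeRecordT(firstBlock)
//
//	blockRange := tlv.ZeroRecordT[tlv.TlvType4, uint32]()
//	blockRange.Val = 144
//	g.BlockRange = tlv.SomeRecordT(blockRange)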
package lnwire
import (
"bytes"
"io"
)
// Init is the first message exchanged between peers and reveals the features
// supported or required by this node. Nodes wait for receipt of the other's
// features to simplify error diagnosis where features are incompatible. Each
// node MUST wait to receive init before sending any other messages.
type Init struct {
// GlobalFeatures is a legacy feature vector used for backwards
// compatibility with older nodes. Any features defined here should be
// merged with those presented in Features.
GlobalFeatures *RawFeatureVector
// Features is a feature vector containing the features supported by
// the remote node.
//
// NOTE: Older nodes may place some features in GlobalFeatures, but all
// new features are to be added in Features. When handling an Init
// message, any GlobalFeatures should be merged into the unified
// Features field.
Features *RawFeatureVector
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewInitMessage creates a new instance of the Init message.
func NewInitMessage(gf *RawFeatureVector, f *RawFeatureVector) *Init {
return &Init{
GlobalFeatures: gf,
Features: f,
ExtraData: make([]byte, 0),
}
}
// A compile time check to ensure Init implements the lnwire.Message
// interface.
var _ Message = (*Init)(nil)
// A compile time check to ensure Init implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Init)(nil)
// Decode deserializes a serialized Init message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (msg *Init) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&msg.GlobalFeatures,
&msg.Features,
&msg.ExtraData,
)
}
// Encode serializes the target Init into the passed io.Writer observing
// the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (msg *Init) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteRawFeatureVector(w, msg.GlobalFeatures); err != nil {
return err
}
if err := WriteRawFeatureVector(w, msg.Features); err != nil {
return err
}
return WriteBytes(w, msg.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (msg *Init) MsgType() MessageType {
return MsgInit
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (msg *Init) SerializedSize() (uint32, error) {
return MessageSerializedSize(msg)
}
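// A minimal sketch of the merge handling described in the Features NOTE
// above, folding any legacy GlobalFeatures into a single unified vector:
//
//	unified := msg.Features.Clone()
//	if err := unified.Merge(msg.GlobalFeatures); err != nil {
//		// The peer advertised conflicting optional/required halves.
//	}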
package lnwire
import (
"bytes"
"io"
)
// KickoffSig is the message used to transmit the signature for a kickoff
// transaction during the execution phase of a dynamic commitment negotiation
// that requires a reanchoring step.
type KickoffSig struct {
// ChanID identifies the channel id for which this signature is
// intended.
ChanID ChannelID
// Signature contains the ECDSA signature that signs the kickoff
// transaction.
Signature Sig
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure that KickoffSig implements the lnwire.Message
// interface.
var _ Message = (*KickoffSig)(nil)
// A compile time check to ensure KickoffSig implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*KickoffSig)(nil)
// Encode serializes the target KickoffSig into the passed bytes.Buffer
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (ks *KickoffSig) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, ks.ChanID); err != nil {
return err
}
if err := WriteSig(w, ks.Signature); err != nil {
return err
}
return WriteBytes(w, ks.ExtraData)
}
// Decode deserializes a serialized KickoffSig message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (ks *KickoffSig) Decode(r io.Reader, _ uint32) error {
return ReadElements(r, &ks.ChanID, &ks.Signature, &ks.ExtraData)
}
// MsgType returns the integer uniquely identifying KickoffSig on the wire.
//
// This is part of the lnwire.Message interface.
func (ks *KickoffSig) MsgType() MessageType { return MsgKickoffSig }
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (ks *KickoffSig) SerializedSize() (uint32, error) {
return MessageSerializedSize(ks)
}
package lnwire
import (
"bytes"
"encoding/binary"
"fmt"
"image/color"
"io"
"math"
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/tor"
)
const (
// MaxSliceLength is the maximum allowed length for any opaque byte
// slices in the wire protocol.
MaxSliceLength = 65535
// MaxMsgBody is the largest payload any message is allowed to provide.
// This is two less than the MaxSliceLength as each message has a 2
// byte type that precedes the message body.
MaxMsgBody = 65533
)
// PkScript is a simple type definition that represents a raw serialized
// public key script.
type PkScript []byte
// addressType specifies the network protocol and version that should be used
// when connecting to a node at a particular address.
type addressType uint8
const (
// noAddr denotes a blank address. An address of this type indicates
// that a node doesn't have any advertised addresses.
noAddr addressType = 0
// tcp4Addr denotes an IPv4 TCP address.
tcp4Addr addressType = 1
// tcp6Addr denotes an IPv6 TCP address.
tcp6Addr addressType = 2
// v2OnionAddr denotes a version 2 Tor onion service address.
v2OnionAddr addressType = 3
// v3OnionAddr denotes a version 3 Tor (prop224) onion service address.
v3OnionAddr addressType = 4
)
// AddrLen returns the number of bytes that it takes to encode the target
// address.
func (a addressType) AddrLen() uint16 {
switch a {
case noAddr:
return 0
case tcp4Addr:
return 6
case tcp6Addr:
return 18
case v2OnionAddr:
return 12
case v3OnionAddr:
return 37
default:
return 0
}
}
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for the wire protocol.
//
// TODO(yy): rm this method once we finish dereferencing it from other
// packages.
func WriteElement(w *bytes.Buffer, element interface{}) error {
switch e := element.(type) {
case NodeAlias:
if _, err := w.Write(e[:]); err != nil {
return err
}
case QueryEncoding:
var b [1]byte
b[0] = uint8(e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint8:
var b [1]byte
b[0] = e
if _, err := w.Write(b[:]); err != nil {
return err
}
case FundingFlag:
var b [1]byte
b[0] = uint8(e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint16:
var b [2]byte
binary.BigEndian.PutUint16(b[:], e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case ChanUpdateMsgFlags:
var b [1]byte
b[0] = uint8(e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case ChanUpdateChanFlags:
var b [1]byte
b[0] = uint8(e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case MilliSatoshi:
var b [8]byte
binary.BigEndian.PutUint64(b[:], uint64(e))
if _, err := w.Write(b[:]); err != nil {
return err
}
case btcutil.Amount:
var b [8]byte
binary.BigEndian.PutUint64(b[:], uint64(e))
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint32:
var b [4]byte
binary.BigEndian.PutUint32(b[:], e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint64:
var b [8]byte
binary.BigEndian.PutUint64(b[:], e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case *btcec.PublicKey:
if e == nil {
return fmt.Errorf("cannot write nil pubkey")
}
var b [33]byte
serializedPubkey := e.SerializeCompressed()
copy(b[:], serializedPubkey)
if _, err := w.Write(b[:]); err != nil {
return err
}
case []Sig:
var b [2]byte
numSigs := uint16(len(e))
binary.BigEndian.PutUint16(b[:], numSigs)
if _, err := w.Write(b[:]); err != nil {
return err
}
for _, sig := range e {
if err := WriteElement(w, sig); err != nil {
return err
}
}
case Sig:
// Write the raw signature bytes.
if _, err := w.Write(e.bytes[:]); err != nil {
return err
}
case PartialSig:
if err := e.Encode(w); err != nil {
return err
}
case PingPayload:
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(e)))
if _, err := w.Write(l[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case PongPayload:
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(e)))
if _, err := w.Write(l[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case WarningData:
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(e)))
if _, err := w.Write(l[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case ErrorData:
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(e)))
if _, err := w.Write(l[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case OpaqueReason:
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(e)))
if _, err := w.Write(l[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case [33]byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case []byte:
if _, err := w.Write(e[:]); err != nil {
return err
}
case PkScript:
// The largest script we'll accept is a p2wsh which is exactly
// 34 bytes long.
scriptLength := len(e)
if scriptLength > 34 {
return fmt.Errorf("'PkScript' too long")
}
if err := wire.WriteVarBytes(w, 0, e); err != nil {
return err
}
case *RawFeatureVector:
if e == nil {
return fmt.Errorf("cannot write nil feature vector")
}
if err := e.Encode(w); err != nil {
return err
}
case wire.OutPoint:
var h [32]byte
copy(h[:], e.Hash[:])
if _, err := w.Write(h[:]); err != nil {
return err
}
if e.Index > math.MaxUint16 {
return fmt.Errorf("index for outpoint (%v) is "+
"greater than max index of %v", e.Index,
math.MaxUint16)
}
var idx [2]byte
binary.BigEndian.PutUint16(idx[:], uint16(e.Index))
if _, err := w.Write(idx[:]); err != nil {
return err
}
case ChannelID:
if _, err := w.Write(e[:]); err != nil {
return err
}
case FailCode:
if err := WriteElement(w, uint16(e)); err != nil {
return err
}
case ShortChannelID:
// Check that the field fits in 3 bytes and write the blockHeight.
if e.BlockHeight > ((1 << 24) - 1) {
return errors.New("block height should fit in 3 bytes")
}
var blockHeight [4]byte
binary.BigEndian.PutUint32(blockHeight[:], e.BlockHeight)
if _, err := w.Write(blockHeight[1:]); err != nil {
return err
}
// Check that the field fits in 3 bytes and write the txIndex.
if e.TxIndex > ((1 << 24) - 1) {
return errors.New("tx index should fit in 3 bytes")
}
var txIndex [4]byte
binary.BigEndian.PutUint32(txIndex[:], e.TxIndex)
if _, err := w.Write(txIndex[1:]); err != nil {
return err
}
// Write the txPosition
var txPosition [2]byte
binary.BigEndian.PutUint16(txPosition[:], e.TxPosition)
if _, err := w.Write(txPosition[:]); err != nil {
return err
}
case *net.TCPAddr:
if e == nil {
return fmt.Errorf("cannot write nil TCPAddr")
}
if e.IP.To4() != nil {
var descriptor [1]byte
descriptor[0] = uint8(tcp4Addr)
if _, err := w.Write(descriptor[:]); err != nil {
return err
}
var ip [4]byte
copy(ip[:], e.IP.To4())
if _, err := w.Write(ip[:]); err != nil {
return err
}
} else {
var descriptor [1]byte
descriptor[0] = uint8(tcp6Addr)
if _, err := w.Write(descriptor[:]); err != nil {
return err
}
var ip [16]byte
copy(ip[:], e.IP.To16())
if _, err := w.Write(ip[:]); err != nil {
return err
}
}
var port [2]byte
binary.BigEndian.PutUint16(port[:], uint16(e.Port))
if _, err := w.Write(port[:]); err != nil {
return err
}
case *tor.OnionAddr:
if e == nil {
return errors.New("cannot write nil onion address")
}
var suffixIndex int
switch len(e.OnionService) {
case tor.V2Len:
descriptor := []byte{byte(v2OnionAddr)}
if _, err := w.Write(descriptor); err != nil {
return err
}
suffixIndex = tor.V2Len - tor.OnionSuffixLen
case tor.V3Len:
descriptor := []byte{byte(v3OnionAddr)}
if _, err := w.Write(descriptor); err != nil {
return err
}
suffixIndex = tor.V3Len - tor.OnionSuffixLen
default:
return errors.New("unknown onion service length")
}
host, err := tor.Base32Encoding.DecodeString(
e.OnionService[:suffixIndex],
)
if err != nil {
return err
}
if _, err := w.Write(host); err != nil {
return err
}
var port [2]byte
binary.BigEndian.PutUint16(port[:], uint16(e.Port))
if _, err := w.Write(port[:]); err != nil {
return err
}
case []net.Addr:
// First, we'll encode all the addresses into an intermediate
// buffer. We need to do this in order to compute the total
// length of the addresses.
var addrBuf bytes.Buffer
for _, address := range e {
if err := WriteElement(&addrBuf, address); err != nil {
return err
}
}
// With the addresses fully encoded, we can now write out the
// number of bytes needed to encode them.
addrLen := addrBuf.Len()
if err := WriteElement(w, uint16(addrLen)); err != nil {
return err
}
// Finally, we'll write out the raw addresses themselves, but
// only if we have any bytes to write.
if addrLen > 0 {
if _, err := w.Write(addrBuf.Bytes()); err != nil {
return err
}
}
case color.RGBA:
if err := WriteElements(w, e.R, e.G, e.B); err != nil {
return err
}
case DeliveryAddress:
var length [2]byte
binary.BigEndian.PutUint16(length[:], uint16(len(e)))
if _, err := w.Write(length[:]); err != nil {
return err
}
if _, err := w.Write(e[:]); err != nil {
return err
}
case bool:
var b [1]byte
if e {
b[0] = 1
}
if _, err := w.Write(b[:]); err != nil {
return err
}
case ExtraOpaqueData:
return e.Encode(w)
default:
return fmt.Errorf("unknown type in WriteElement: %T", e)
}
return nil
}
// WriteElements writes each element in the elements slice to the passed
// buffer using WriteElement.
//
// TODO(yy): rm this method once we finish dereferencing it from other
// packages.
func WriteElements(buf *bytes.Buffer, elements ...interface{}) error {
for _, element := range elements {
err := WriteElement(buf, element)
if err != nil {
return err
}
}
return nil
}
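// A minimal round-trip sketch pairing WriteElements with ReadElements below
// (errors elided for brevity; the values are arbitrary):
//
//	var buf bytes.Buffer
//	scid := ShortChannelID{BlockHeight: 700_000, TxIndex: 42, TxPosition: 1}
//	_ = WriteElements(&buf, scid, uint16(9))
//
//	var (
//		gotSCID ShortChannelID
//		gotU16  uint16
//	)
//	_ = ReadElements(bytes.NewReader(buf.Bytes()), &gotSCID, &gotU16)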
// ReadElement is a one-stop utility function to deserialize any data
// structure encoded using the serialization format of lnwire.
func ReadElement(r io.Reader, element interface{}) error {
var err error
switch e := element.(type) {
case *bool:
var b [1]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
if b[0] == 1 {
*e = true
}
case *NodeAlias:
var a [32]byte
if _, err := io.ReadFull(r, a[:]); err != nil {
return err
}
alias, err := NewNodeAlias(string(a[:]))
if err != nil {
return err
}
*e = alias
case *QueryEncoding:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = QueryEncoding(b[0])
case *uint8:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = b[0]
case *FundingFlag:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = FundingFlag(b[0])
case *uint16:
var b [2]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = binary.BigEndian.Uint16(b[:])
case *ChanUpdateMsgFlags:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = ChanUpdateMsgFlags(b[0])
case *ChanUpdateChanFlags:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = ChanUpdateChanFlags(b[0])
case *uint32:
var b [4]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = binary.BigEndian.Uint32(b[:])
case *uint64:
var b [8]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = binary.BigEndian.Uint64(b[:])
case *MilliSatoshi:
var b [8]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = MilliSatoshi(int64(binary.BigEndian.Uint64(b[:])))
case *btcutil.Amount:
var b [8]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = btcutil.Amount(int64(binary.BigEndian.Uint64(b[:])))
case **btcec.PublicKey:
var b [btcec.PubKeyBytesLenCompressed]byte
if _, err = io.ReadFull(r, b[:]); err != nil {
return err
}
pubKey, err := btcec.ParsePubKey(b[:])
if err != nil {
return err
}
*e = pubKey
case *RawFeatureVector:
f := NewRawFeatureVector()
err = f.Decode(r)
if err != nil {
return err
}
*e = *f
case **RawFeatureVector:
f := NewRawFeatureVector()
err = f.Decode(r)
if err != nil {
return err
}
*e = f
case *[]Sig:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
numSigs := binary.BigEndian.Uint16(l[:])
var sigs []Sig
if numSigs > 0 {
sigs = make([]Sig, numSigs)
for i := 0; i < int(numSigs); i++ {
if err := ReadElement(r, &sigs[i]); err != nil {
return err
}
}
}
*e = sigs
case *Sig:
if _, err := io.ReadFull(r, e.bytes[:]); err != nil {
return err
}
case *OpaqueReason:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
reasonLen := binary.BigEndian.Uint16(l[:])
*e = OpaqueReason(make([]byte, reasonLen))
if _, err := io.ReadFull(r, *e); err != nil {
return err
}
case *WarningData:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
errorLen := binary.BigEndian.Uint16(l[:])
*e = WarningData(make([]byte, errorLen))
if _, err := io.ReadFull(r, *e); err != nil {
return err
}
case *ErrorData:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
errorLen := binary.BigEndian.Uint16(l[:])
*e = ErrorData(make([]byte, errorLen))
if _, err := io.ReadFull(r, *e); err != nil {
return err
}
case *PingPayload:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
pingLen := binary.BigEndian.Uint16(l[:])
*e = PingPayload(make([]byte, pingLen))
if _, err := io.ReadFull(r, *e); err != nil {
return err
}
case *PongPayload:
var l [2]byte
if _, err := io.ReadFull(r, l[:]); err != nil {
return err
}
pongLen := binary.BigEndian.Uint16(l[:])
*e = PongPayload(make([]byte, pongLen))
if _, err := io.ReadFull(r, *e); err != nil {
return err
}
case *[33]byte:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case []byte:
if _, err := io.ReadFull(r, e); err != nil {
return err
}
case *PkScript:
pkScript, err := wire.ReadVarBytes(r, 0, 34, "pkscript")
if err != nil {
return err
}
*e = pkScript
case *wire.OutPoint:
var h [32]byte
if _, err = io.ReadFull(r, h[:]); err != nil {
return err
}
hash, err := chainhash.NewHash(h[:])
if err != nil {
return err
}
var idxBytes [2]byte
_, err = io.ReadFull(r, idxBytes[:])
if err != nil {
return err
}
index := binary.BigEndian.Uint16(idxBytes[:])
*e = wire.OutPoint{
Hash: *hash,
Index: uint32(index),
}
case *FailCode:
if err := ReadElement(r, (*uint16)(e)); err != nil {
return err
}
case *ChannelID:
if _, err := io.ReadFull(r, e[:]); err != nil {
return err
}
case *ShortChannelID:
var blockHeight [4]byte
if _, err = io.ReadFull(r, blockHeight[1:]); err != nil {
return err
}
var txIndex [4]byte
if _, err = io.ReadFull(r, txIndex[1:]); err != nil {
return err
}
var txPosition [2]byte
if _, err = io.ReadFull(r, txPosition[:]); err != nil {
return err
}
*e = ShortChannelID{
BlockHeight: binary.BigEndian.Uint32(blockHeight[:]),
TxIndex: binary.BigEndian.Uint32(txIndex[:]),
TxPosition: binary.BigEndian.Uint16(txPosition[:]),
}
case *[]net.Addr:
// First, we'll read the number of total bytes that have been
// used to encode the set of addresses.
var numAddrsBytes [2]byte
if _, err = io.ReadFull(r, numAddrsBytes[:]); err != nil {
return err
}
addrsLen := binary.BigEndian.Uint16(numAddrsBytes[:])
// With the number of address bytes read, we'll now pull the
// buffer of the encoded addresses into memory.
addrs := make([]byte, addrsLen)
if _, err := io.ReadFull(r, addrs[:]); err != nil {
return err
}
addrBuf := bytes.NewReader(addrs)
// Finally, we'll parse the remaining address payload in
// series, using the first byte to denote how to decode the
// address itself.
var (
addresses []net.Addr
addrBytesRead uint16
)
for addrBytesRead < addrsLen {
var descriptor [1]byte
if _, err = io.ReadFull(addrBuf, descriptor[:]); err != nil {
return err
}
addrBytesRead++
var address net.Addr
switch aType := addressType(descriptor[0]); aType {
case noAddr:
addrBytesRead += aType.AddrLen()
continue
case tcp4Addr:
var ip [4]byte
if _, err := io.ReadFull(addrBuf, ip[:]); err != nil {
return err
}
var port [2]byte
if _, err := io.ReadFull(addrBuf, port[:]); err != nil {
return err
}
address = &net.TCPAddr{
IP: net.IP(ip[:]),
Port: int(binary.BigEndian.Uint16(port[:])),
}
addrBytesRead += aType.AddrLen()
case tcp6Addr:
var ip [16]byte
if _, err := io.ReadFull(addrBuf, ip[:]); err != nil {
return err
}
var port [2]byte
if _, err := io.ReadFull(addrBuf, port[:]); err != nil {
return err
}
address = &net.TCPAddr{
IP: net.IP(ip[:]),
Port: int(binary.BigEndian.Uint16(port[:])),
}
addrBytesRead += aType.AddrLen()
case v2OnionAddr:
var h [tor.V2DecodedLen]byte
if _, err := io.ReadFull(addrBuf, h[:]); err != nil {
return err
}
var p [2]byte
if _, err := io.ReadFull(addrBuf, p[:]); err != nil {
return err
}
onionService := tor.Base32Encoding.EncodeToString(h[:])
onionService += tor.OnionSuffix
port := int(binary.BigEndian.Uint16(p[:]))
address = &tor.OnionAddr{
OnionService: onionService,
Port: port,
}
addrBytesRead += aType.AddrLen()
case v3OnionAddr:
var h [tor.V3DecodedLen]byte
if _, err := io.ReadFull(addrBuf, h[:]); err != nil {
return err
}
var p [2]byte
if _, err := io.ReadFull(addrBuf, p[:]); err != nil {
return err
}
onionService := tor.Base32Encoding.EncodeToString(h[:])
onionService += tor.OnionSuffix
port := int(binary.BigEndian.Uint16(p[:]))
address = &tor.OnionAddr{
OnionService: onionService,
Port: port,
}
addrBytesRead += aType.AddrLen()
default:
// If we don't understand this address type,
// we just store it along with the remaining
// address bytes as type OpaqueAddrs. We need
// to hold onto the bytes so that we can still
// write them back to the wire when we
// propagate this message.
payloadLen := 1 + addrsLen - addrBytesRead
payload := make([]byte, payloadLen)
// First write a byte for the address type that
// we already read.
payload[0] = byte(aType)
// Now append the rest of the address bytes.
_, err := io.ReadFull(addrBuf, payload[1:])
if err != nil {
return err
}
address = &OpaqueAddrs{
Payload: payload,
}
addrBytesRead = addrsLen
}
addresses = append(addresses, address)
}
*e = addresses
case *color.RGBA:
err := ReadElements(r,
&e.R,
&e.G,
&e.B,
)
if err != nil {
return err
}
case *DeliveryAddress:
var addrLen [2]byte
if _, err = io.ReadFull(r, addrLen[:]); err != nil {
return err
}
length := binary.BigEndian.Uint16(addrLen[:])
var addrBytes [deliveryAddressMaxSize]byte
if length > deliveryAddressMaxSize {
return fmt.Errorf(
"cannot read %d bytes into addrBytes", length,
)
}
if _, err = io.ReadFull(r, addrBytes[:length]); err != nil {
return err
}
*e = addrBytes[:length]
case *PartialSig:
var sig PartialSig
if err = sig.Decode(r); err != nil {
return err
}
*e = sig
case *ExtraOpaqueData:
return e.Decode(r)
default:
return fmt.Errorf("unknown type in ReadElement: %T", e)
}
return nil
}
// ReadElements deserializes a variable number of elements from the passed
// io.Reader, with each element being deserialized according to the ReadElement
// function.
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
err := ReadElement(r, element)
if err != nil {
return err
}
}
return nil
}
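// The following is an illustrative sketch (not part of the original
// source): decoding several fixed-width fields in sequence with
// ReadElements. The byte layout is invented for demonstration, and it
// assumes the *uint32 and *uint16 cases handled earlier in ReadElement.
func exampleReadElements() error {
	// A 4-byte uint32 (42) followed by a 2-byte uint16 (256), both
	// big-endian, as this package encodes them.
	r := bytes.NewReader([]byte{0x00, 0x00, 0x00, 0x2a, 0x01, 0x00})

	var (
		timestamp uint32
		length    uint16
	)
	if err := ReadElements(r, &timestamp, &length); err != nil {
		return err
	}

	fmt.Println(timestamp, length) // 42 256
	return nil
}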
// Copyright (c) 2013-2017 The btcsuite developers
// Copyright (c) 2015-2016 The Decred developers
// Code derived from https://github.com/btcsuite/btcd/blob/master/wire/message.go
// Copyright (C) 2015-2022 The Lightning Network Developers
package lnwire
import (
"bytes"
"encoding/binary"
"fmt"
"io"
)
// MessageTypeSize is the size in bytes of the message type field in the header
// of all messages.
const MessageTypeSize = 2
// MessageType is the unique 2 byte big-endian integer that indicates the type
// of message on the wire. All messages have a very simple header which
// consists simply of 2-byte message type. We omit a length field, and checksum
// as the Lightning Protocol is intended to be encapsulated within a
// confidential+authenticated cryptographic messaging protocol.
type MessageType uint16
// The currently defined message types within this current version of the
// Lightning protocol.
const (
MsgWarning MessageType = 1
MsgStfu = 2
MsgInit = 16
MsgError = 17
MsgPing = 18
MsgPong = 19
MsgOpenChannel = 32
MsgAcceptChannel = 33
MsgFundingCreated = 34
MsgFundingSigned = 35
MsgChannelReady = 36
MsgShutdown = 38
MsgClosingSigned = 39
MsgClosingComplete = 40
MsgClosingSig = 41
MsgDynPropose = 111
MsgDynAck = 113
MsgDynReject = 115
MsgUpdateAddHTLC = 128
MsgUpdateFulfillHTLC = 130
MsgUpdateFailHTLC = 131
MsgCommitSig = 132
MsgRevokeAndAck = 133
MsgUpdateFee = 134
MsgUpdateFailMalformedHTLC = 135
MsgChannelReestablish = 136
MsgChannelAnnouncement = 256
MsgNodeAnnouncement = 257
MsgChannelUpdate = 258
MsgAnnounceSignatures = 259
MsgAnnounceSignatures2 = 260
MsgQueryShortChanIDs = 261
MsgReplyShortChanIDsEnd = 262
MsgQueryChannelRange = 263
MsgReplyChannelRange = 264
MsgGossipTimestampRange = 265
MsgChannelAnnouncement2 = 267
MsgChannelUpdate2 = 271
MsgKickoffSig = 777
// MsgEnd defines the end of the official message range of the protocol.
// If a new message is added beyond this message, then this should be
// modified.
MsgEnd = 778
)
// IsChannelUpdate is a filter function that discerns channel update messages
// from the other messages in the Lightning Network Protocol.
func (t MessageType) IsChannelUpdate() bool {
switch t {
case MsgUpdateAddHTLC:
return true
case MsgUpdateFulfillHTLC:
return true
case MsgUpdateFailHTLC:
return true
case MsgUpdateFailMalformedHTLC:
return true
case MsgUpdateFee:
return true
default:
return false
}
}
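// Illustrative sketch (not part of the original source): a dispatcher
// can use IsChannelUpdate to decide whether an incoming message must be
// routed to the owning channel link or can be handled generically.
func dispatchTarget(t MessageType) string {
	if t.IsChannelUpdate() {
		return "channel link"
	}
	return "generic handler"
}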
// ErrorEncodeMessage is returned when encoding the message payload fails.
func ErrorEncodeMessage(err error) error {
return fmt.Errorf("failed to encode message to buffer, got %w", err)
}
// ErrorWriteMessageType is returned when writing the message type fails.
func ErrorWriteMessageType(err error) error {
return fmt.Errorf("failed to write message type, got %w", err)
}
// ErrorPayloadTooLarge is used when the payload size exceeds the
// MaxMsgBody.
func ErrorPayloadTooLarge(size int) error {
return fmt.Errorf(
"message payload is too large - encoded %d bytes, "+
"but maximum message payload is %d bytes",
size, MaxMsgBody,
)
}
// String returns the string representation of the message type.
func (t MessageType) String() string {
switch t {
case MsgWarning:
return "Warning"
case MsgStfu:
return "Stfu"
case MsgInit:
return "Init"
case MsgOpenChannel:
return "MsgOpenChannel"
case MsgAcceptChannel:
return "MsgAcceptChannel"
case MsgFundingCreated:
return "MsgFundingCreated"
case MsgFundingSigned:
return "MsgFundingSigned"
case MsgChannelReady:
return "ChannelReady"
case MsgShutdown:
return "Shutdown"
case MsgClosingSigned:
return "ClosingSigned"
case MsgDynPropose:
return "DynPropose"
case MsgDynAck:
return "DynAck"
case MsgDynReject:
return "DynReject"
case MsgKickoffSig:
return "KickoffSig"
case MsgUpdateAddHTLC:
return "UpdateAddHTLC"
case MsgUpdateFailHTLC:
return "UpdateFailHTLC"
case MsgUpdateFulfillHTLC:
return "UpdateFulfillHTLC"
case MsgCommitSig:
return "CommitSig"
case MsgRevokeAndAck:
return "RevokeAndAck"
case MsgUpdateFailMalformedHTLC:
return "UpdateFailMalformedHTLC"
case MsgChannelReestablish:
return "ChannelReestablish"
case MsgError:
return "Error"
case MsgChannelAnnouncement:
return "ChannelAnnouncement"
case MsgChannelUpdate:
return "ChannelUpdate"
case MsgNodeAnnouncement:
return "NodeAnnouncement"
case MsgPing:
return "Ping"
case MsgAnnounceSignatures:
return "AnnounceSignatures"
case MsgPong:
return "Pong"
case MsgUpdateFee:
return "UpdateFee"
case MsgQueryShortChanIDs:
return "QueryShortChanIDs"
case MsgReplyShortChanIDsEnd:
return "ReplyShortChanIDsEnd"
case MsgQueryChannelRange:
return "QueryChannelRange"
case MsgReplyChannelRange:
return "ReplyChannelRange"
case MsgGossipTimestampRange:
return "GossipTimestampRange"
case MsgClosingComplete:
return "ClosingComplete"
case MsgClosingSig:
return "ClosingSig"
case MsgAnnounceSignatures2:
return "MsgAnnounceSignatures2"
case MsgChannelAnnouncement2:
return "ChannelAnnouncement2"
case MsgChannelUpdate2:
return "ChannelUpdate2"
default:
return "<unknown>"
}
}
// UnknownMessage is an implementation of the error interface that allows the
// creation of an error in response to an unknown message.
type UnknownMessage struct {
messageType MessageType
}
// Error returns a human readable string describing the error.
//
// This is part of the error interface.
func (u *UnknownMessage) Error() string {
return fmt.Sprintf("unable to parse message of unknown type: %v",
u.messageType)
}
// Serializable is an interface which defines a lightning wire serializable
// object.
type Serializable interface {
// Decode reads the bytes stream and converts it to the object.
Decode(io.Reader, uint32) error
// Encode converts the object to a bytes stream and writes it into the
// write buffer.
Encode(*bytes.Buffer, uint32) error
}
// Message is an interface that defines a lightning wire protocol message. The
// interface is general in order to allow implementing types full control over
// the representation of its data.
type Message interface {
Serializable
MsgType() MessageType
}
// LinkUpdater is an interface implemented by most messages in BOLT 2 that are
// allowed to update the channel state.
type LinkUpdater interface {
// All LinkUpdater messages are messages and so we embed the interface
// so that we can treat it as a message if all we know about it is that
// it is a LinkUpdater message.
Message
// TargetChanID returns the channel id of the link for which this
// message is intended.
TargetChanID() ChannelID
}
// SizeableMessage is an interface that extends the base Message interface with
// a method to calculate the serialized size of a message.
type SizeableMessage interface {
Message
// SerializedSize returns the serialized size of the message in bytes.
// The returned size includes the message type header bytes.
SerializedSize() (uint32, error)
}
// MessageSerializedSize calculates the serialized size of a message in bytes.
// This is a helper function that can be used by all message types to implement
// the SerializedSize method.
func MessageSerializedSize(msg Message) (uint32, error) {
var buf bytes.Buffer
// Encode the message to the buffer.
if err := msg.Encode(&buf, 0); err != nil {
return 0, err
}
// Add the size of the message type.
return uint32(buf.Len()) + MessageTypeSize, nil
}
// makeEmptyMessage creates a new empty message of the proper concrete type
// based on the passed message type.
func makeEmptyMessage(msgType MessageType) (Message, error) {
var msg Message
switch msgType {
case MsgWarning:
msg = &Warning{}
case MsgStfu:
msg = &Stfu{}
case MsgInit:
msg = &Init{}
case MsgOpenChannel:
msg = &OpenChannel{}
case MsgAcceptChannel:
msg = &AcceptChannel{}
case MsgFundingCreated:
msg = &FundingCreated{}
case MsgFundingSigned:
msg = &FundingSigned{}
case MsgChannelReady:
msg = &ChannelReady{}
case MsgShutdown:
msg = &Shutdown{}
case MsgClosingSigned:
msg = &ClosingSigned{}
case MsgDynPropose:
msg = &DynPropose{}
case MsgDynAck:
msg = &DynAck{}
case MsgDynReject:
msg = &DynReject{}
case MsgKickoffSig:
msg = &KickoffSig{}
case MsgUpdateAddHTLC:
msg = &UpdateAddHTLC{}
case MsgUpdateFailHTLC:
msg = &UpdateFailHTLC{}
case MsgUpdateFulfillHTLC:
msg = &UpdateFulfillHTLC{}
case MsgCommitSig:
msg = &CommitSig{}
case MsgRevokeAndAck:
msg = &RevokeAndAck{}
case MsgUpdateFee:
msg = &UpdateFee{}
case MsgUpdateFailMalformedHTLC:
msg = &UpdateFailMalformedHTLC{}
case MsgChannelReestablish:
msg = &ChannelReestablish{}
case MsgError:
msg = &Error{}
case MsgChannelAnnouncement:
msg = &ChannelAnnouncement1{}
case MsgChannelUpdate:
msg = &ChannelUpdate1{}
case MsgNodeAnnouncement:
msg = &NodeAnnouncement{}
case MsgPing:
msg = &Ping{}
case MsgAnnounceSignatures:
msg = &AnnounceSignatures1{}
case MsgPong:
msg = &Pong{}
case MsgQueryShortChanIDs:
msg = &QueryShortChanIDs{}
case MsgReplyShortChanIDsEnd:
msg = &ReplyShortChanIDsEnd{}
case MsgQueryChannelRange:
msg = &QueryChannelRange{}
case MsgReplyChannelRange:
msg = &ReplyChannelRange{}
case MsgGossipTimestampRange:
msg = &GossipTimestampRange{}
case MsgClosingComplete:
msg = &ClosingComplete{}
case MsgClosingSig:
msg = &ClosingSig{}
case MsgAnnounceSignatures2:
msg = &AnnounceSignatures2{}
case MsgChannelAnnouncement2:
msg = &ChannelAnnouncement2{}
case MsgChannelUpdate2:
msg = &ChannelUpdate2{}
default:
// If the message is not within our custom range and has not
// specifically been overridden, return an unknown message.
//
// Note that we do not allow custom message overrides to replace
// known message types, only protocol messages that are not yet
// known to lnd.
if msgType < CustomTypeStart && !IsCustomOverride(msgType) {
return nil, &UnknownMessage{msgType}
}
msg = &Custom{
Type: msgType,
}
}
return msg, nil
}
// MakeEmptyMessage creates a new empty message of the proper concrete type
// based on the passed message type. This is exported to be used in tests.
func MakeEmptyMessage(msgType MessageType) (Message, error) {
return makeEmptyMessage(msgType)
}
// WriteMessage writes a lightning Message to a buffer including the necessary
// header information and returns the number of bytes written. If any error is
// encountered, the buffer passed will be reset to its original state since we
// don't want any broken bytes left. In other words, no bytes will be written
// if there's an error. Either all or none of the message bytes will be written
// to the buffer.
//
// NOTE: this method is not concurrent safe.
func WriteMessage(buf *bytes.Buffer, msg Message, pver uint32) (int, error) {
// Record the size of the bytes already written in buffer.
oldByteSize := buf.Len()
// cleanBrokenBytes is a helper closure that resets the buffer to its
// original state. It truncates all the bytes written in the current
// scope.
var cleanBrokenBytes = func(b *bytes.Buffer) int {
b.Truncate(oldByteSize)
return 0
}
// Write the message type.
var mType [2]byte
binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType()))
msgTypeBytes, err := buf.Write(mType[:])
if err != nil {
return cleanBrokenBytes(buf), ErrorWriteMessageType(err)
}
// Use the write buffer to encode our message.
if err := msg.Encode(buf, pver); err != nil {
return cleanBrokenBytes(buf), ErrorEncodeMessage(err)
}
// Enforce maximum overall message payload. The write buffer now has
// the size of len(originalBytes) + len(payload) + len(type). We want
// to enforce the payload limit here, so we subtract the length of the
// type and the old bytes.
lenp := buf.Len() - oldByteSize - msgTypeBytes
if lenp > MaxMsgBody {
return cleanBrokenBytes(buf), ErrorPayloadTooLarge(lenp)
}
return buf.Len() - oldByteSize, nil
}
// ReadMessage reads, validates, and parses the next Lightning message from r
// for the provided protocol version.
func ReadMessage(r io.Reader, pver uint32) (Message, error) {
// First, we'll read out the first two bytes of the message so we can
// create the proper empty message.
var mType [2]byte
if _, err := io.ReadFull(r, mType[:]); err != nil {
return nil, err
}
msgType := MessageType(binary.BigEndian.Uint16(mType[:]))
// Now that we know the target message type, we can create the proper
// empty message type and decode the message into it.
msg, err := makeEmptyMessage(msgType)
if err != nil {
return nil, err
}
if err := msg.Decode(r, pver); err != nil {
return nil, err
}
return msg, nil
}
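// Illustrative sketch (not part of the original source): a full wire
// round trip through WriteMessage and ReadMessage. MakeEmptyMessage is
// used so that no message field names need to be assumed; pver is 0,
// matching the other helpers in this package.
func exampleRoundTrip() error {
	msg, err := MakeEmptyMessage(MsgPing)
	if err != nil {
		return err
	}

	var b bytes.Buffer
	if _, err := WriteMessage(&b, msg, 0); err != nil {
		return err
	}

	decoded, err := ReadMessage(&b, 0)
	if err != nil {
		return err
	}

	fmt.Println(decoded.MsgType()) // Ping
	return nil
}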
package lnwire
import (
"fmt"
"io"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// mSatScale is the factor used to convert between satoshis and
// milli-satoshis.
mSatScale uint64 = 1000
// MaxMilliSatoshi is the maximum number of msats that can be expressed
// in this data type.
MaxMilliSatoshi = ^MilliSatoshi(0)
)
// MilliSatoshi are the native unit of the Lightning Network. A milli-satoshi
// is simply 1/1000th of a satoshi. There are 1000 milli-satoshis in a single
// satoshi. Within the network, all HTLC payments are denominated in
// milli-satoshis. As milli-satoshis aren't deliverable on the native
// blockchain, before settling to broadcasting, the values are rounded down to
// the nearest satoshi.
type MilliSatoshi uint64
// NewMSatFromSatoshis creates a new MilliSatoshi instance from a target amount
// of satoshis.
func NewMSatFromSatoshis(sat btcutil.Amount) MilliSatoshi {
return MilliSatoshi(uint64(sat) * mSatScale)
}
// ToBTC converts the target MilliSatoshi amount to its corresponding value
// when expressed in BTC.
func (m MilliSatoshi) ToBTC() float64 {
sat := m.ToSatoshis()
return sat.ToBTC()
}
// ToSatoshis converts the target MilliSatoshi amount to satoshis. Simply, this
// sheds a factor of 1000 from the mSAT amount in order to convert it to SAT.
func (m MilliSatoshi) ToSatoshis() btcutil.Amount {
return btcutil.Amount(uint64(m) / mSatScale)
}
// String returns the string representation of the mSAT amount.
func (m MilliSatoshi) String() string {
return fmt.Sprintf("%v mSAT", uint64(m))
}
// TODO(roasbeef): extend with arithmetic operations?
// Record returns a TLV record that can be used to encode/decode a MilliSatoshi
// to/from a TLV stream.
func (m *MilliSatoshi) Record() tlv.Record {
return tlv.MakeDynamicRecord(
0, m, tlv.SizeBigSize(m), encodeMilliSatoshis,
decodeMilliSatoshis,
)
}
func encodeMilliSatoshis(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*MilliSatoshi); ok {
bigSize := uint64(*v)
return tlv.EBigSize(w, &bigSize, buf)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.MilliSatoshi")
}
func decodeMilliSatoshis(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*MilliSatoshi); ok {
var bigSize uint64
err := tlv.DBigSize(r, &bigSize, buf, l)
if err != nil {
return err
}
*v = MilliSatoshi(bigSize)
return nil
}
return tlv.NewTypeForDecodingErr(val, "lnwire.MilliSatoshi", l, l)
}
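// Illustrative sketch (not part of the original source): conversions to
// satoshis round down, so sub-satoshi remainders are simply dropped, as
// the MilliSatoshi doc comment above describes.
func exampleMsatRounding() {
	amt := MilliSatoshi(1999)

	// 1999 msat floors to 1 sat; the 999 msat remainder is dropped.
	fmt.Println(amt.ToSatoshis()) // 0.00000001 BTC

	// Going the other way is exact: 1 sat is always 1000 msat.
	fmt.Println(NewMSatFromSatoshis(btcutil.Amount(1))) // 1000 mSAT
}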
package lnwire
import (
"io"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/tlv"
)
// NonceRecordTypeT is the TLV type used to encode a local musig2 nonce.
type NonceRecordTypeT = tlv.TlvType4
// nonceRecordType is the TLV (integer) type used to encode a local musig2
// nonce.
var nonceRecordType tlv.Type = (NonceRecordTypeT)(nil).TypeVal()
type (
// Musig2Nonce represents a musig2 public nonce, which is the
// concatenation of two EC points serialized in compressed format.
Musig2Nonce [musig2.PubNonceSize]byte
// Musig2NonceTLV is a TLV type that can be used to encode/decode a
// musig2 nonce. This is an optional TLV.
Musig2NonceTLV = tlv.RecordT[NonceRecordTypeT, Musig2Nonce]
// OptMusig2NonceTLV is a TLV type that can be used to encode/decode a
// musig2 nonce.
OptMusig2NonceTLV = tlv.OptionalRecordT[NonceRecordTypeT, Musig2Nonce]
)
// Record returns a TLV record that can be used to encode/decode the musig2
// nonce from a given TLV stream.
func (m *Musig2Nonce) Record() tlv.Record {
return tlv.MakeStaticRecord(
nonceRecordType, m, musig2.PubNonceSize,
nonceTypeEncoder, nonceTypeDecoder,
)
}
// nonceTypeEncoder is a custom TLV encoder for the Musig2Nonce type.
func nonceTypeEncoder(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*Musig2Nonce); ok {
_, err := w.Write(v[:])
return err
}
return tlv.NewTypeForEncodingErr(val, "lnwire.Musig2Nonce")
}
// nonceTypeDecoder is a custom TLV decoder for the Musig2Nonce record.
func nonceTypeDecoder(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*Musig2Nonce); ok {
_, err := io.ReadFull(r, v[:])
return err
}
return tlv.NewTypeForDecodingErr(
val, "lnwire.Musig2Nonce", l, musig2.PubNonceSize,
)
}
// SomeMusig2Nonce is a helper function that creates a musig2 nonce TLV.
func SomeMusig2Nonce(nonce Musig2Nonce) OptMusig2NonceTLV {
return tlv.SomeRecordT(
tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce),
)
}
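// Illustrative sketch (not part of the original source; bytes and fmt
// are assumed imported here): serializing a Musig2Nonce through its
// Record with a tlv.Stream. The encoded form is a type byte (4), a
// length byte (66), and the 66-byte nonce body.
func exampleEncodeNonce() error {
	var nonce Musig2Nonce // all-zero nonce, purely for demonstration

	stream, err := tlv.NewStream(nonce.Record())
	if err != nil {
		return err
	}

	var b bytes.Buffer
	if err := stream.Encode(&b); err != nil {
		return err
	}

	fmt.Println(b.Len()) // 68
	return nil
}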
package lnwire
import (
"fmt"
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
)
// NetAddress represents information pertaining to the identity and network
// reachability of a peer. Information stored includes the node's identity
// public key for establishing a confidential+authenticated connection, the
// service bits it supports, and a TCP address the node is reachable at.
//
// TODO(roasbeef): merge with LinkNode in some fashion
type NetAddress struct {
// IdentityKey is the long-term static public key for a node. This node is
// used throughout the network as a node's identity key. It is used to
// authenticate any data sent to the network on behalf of the node, and
// additionally to establish a confidential+authenticated connection with
// the node.
IdentityKey *btcec.PublicKey
// Address is the IP address and port of the node. This is left
// general so that multiple implementations can be used.
Address net.Addr
// ChainNet is the Bitcoin network this node is associated with.
// TODO(roasbeef): make a slice in the future for multi-chain
ChainNet wire.BitcoinNet
}
// A compile time assertion to ensure that NetAddress meets the net.Addr
// interface.
var _ net.Addr = (*NetAddress)(nil)
// String returns a human readable string describing the target NetAddress. The
// current string format is: <pubkey>@host.
//
// This is part of the net.Addr interface.
func (n *NetAddress) String() string {
// TODO(roasbeef): use base58?
pubkey := n.IdentityKey.SerializeCompressed()
return fmt.Sprintf("%x@%v", pubkey, n.Address)
}
// Network returns the name of the network this address is bound to.
//
// This is part of the net.Addr interface.
func (n *NetAddress) Network() string {
return n.Address.Network()
}
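// Illustrative sketch (not part of the original source): constructing a
// NetAddress and rendering its pubkey@host form, the same shape used
// when connecting to peers. btcec.NewPrivateKey is only used here to
// obtain a throwaway identity key.
func exampleNetAddress() error {
	priv, err := btcec.NewPrivateKey()
	if err != nil {
		return err
	}

	addr := &NetAddress{
		IdentityKey: priv.PubKey(),
		Address: &net.TCPAddr{
			IP:   net.ParseIP("127.0.0.1"),
			Port: 9735,
		},
		ChainNet: wire.MainNet,
	}

	// Prints the 33-byte compressed key as 66 hex chars, then the host.
	fmt.Println(addr.String()) // <66 hex chars>@127.0.0.1:9735
	return nil
}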
package lnwire
import (
"bytes"
"fmt"
"image/color"
"io"
"net"
"unicode/utf8"
)
// ErrUnknownAddrType is an error returned if we encounter an unknown address type
// when parsing addresses.
type ErrUnknownAddrType struct {
addrType addressType
}
// Error returns a human readable string describing the error.
//
// NOTE: implements the error interface.
func (e ErrUnknownAddrType) Error() string {
return fmt.Sprintf("unknown address type: %v", e.addrType)
}
// ErrInvalidNodeAlias is an error returned if a node alias we parse on the
// wire is invalid, as in it has non UTF-8 characters.
type ErrInvalidNodeAlias struct{}
// Error returns a human readable string describing the error.
//
// NOTE: implements the error interface.
func (e ErrInvalidNodeAlias) Error() string {
return "node alias has non-utf8 characters"
}
// NodeAlias is a UTF-8 string of up to 32 bytes that may be displayed as an
// alternative to the node's ID. Note that aliases are not unique and may be
// freely chosen by the node operators.
type NodeAlias [32]byte
// NewNodeAlias creates a new instance of a NodeAlias. Verification is
// performed on the passed string to ensure it meets the alias requirements.
func NewNodeAlias(s string) (NodeAlias, error) {
var n NodeAlias
if len(s) > 32 {
return n, fmt.Errorf("alias too large: max is %v, got %v", 32,
len(s))
}
if !utf8.ValidString(s) {
return n, &ErrInvalidNodeAlias{}
}
copy(n[:], []byte(s))
return n, nil
}
// String returns a utf8 string representation of the alias bytes.
func (n NodeAlias) String() string {
// Trim trailing zero-bytes for presentation
return string(bytes.Trim(n[:], "\x00"))
}
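// Illustrative sketch (not part of the original source): NewNodeAlias
// enforces both constraints documented above (at most 32 bytes, valid
// UTF-8), while String trims the zero padding back off.
func exampleNodeAlias() error {
	alias, err := NewNodeAlias("satoshi")
	if err != nil {
		return err
	}
	fmt.Println(alias.String()) // satoshi

	// 0xff can never appear in valid UTF-8, so this input is rejected.
	if _, err := NewNodeAlias(string([]byte{0xff, 0xfe})); err != nil {
		fmt.Println(err) // node alias has non-utf8 characters
	}
	return nil
}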
// NodeAnnouncement message is used to announce the presence of a Lightning
// node and also to signal that the node is accepting incoming connections.
// Each NodeAnnouncement authenticates the advertised information within the
// announcement via a signature using the advertised node pubkey.
type NodeAnnouncement struct {
// Signature is used to prove the ownership of node id.
Signature Sig
// Features is the list of protocol features this node supports.
Features *RawFeatureVector
// Timestamp allows ordering in the case of multiple announcements.
Timestamp uint32
// NodeID is a public key which is used as node identification.
NodeID [33]byte
// RGBColor is used to customize the node's appearance in maps and
// graphs.
RGBColor color.RGBA
// Alias is used to customize the node's appearance in maps and
// graphs.
Alias NodeAlias
// Addresses is the list of addresses on which the node is accepting
// incoming connections.
Addresses []net.Addr
// ExtraOpaqueData is the set of data that was appended to this
// message, some of which we may not actually know how to iterate or
// parse. By holding onto this data, we ensure that we're able to
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData ExtraOpaqueData
}
// A compile time check to ensure NodeAnnouncement implements the
// lnwire.Message interface.
var _ Message = (*NodeAnnouncement)(nil)
// A compile time check to ensure NodeAnnouncement implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*NodeAnnouncement)(nil)
// Decode deserializes a serialized NodeAnnouncement stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&a.Signature,
&a.Features,
&a.Timestamp,
&a.NodeID,
&a.RGBColor,
&a.Alias,
&a.Addresses,
&a.ExtraOpaqueData,
)
}
// Encode serializes the target NodeAnnouncement into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (a *NodeAnnouncement) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteSig(w, a.Signature); err != nil {
return err
}
if err := WriteRawFeatureVector(w, a.Features); err != nil {
return err
}
if err := WriteUint32(w, a.Timestamp); err != nil {
return err
}
if err := WriteBytes(w, a.NodeID[:]); err != nil {
return err
}
if err := WriteColorRGBA(w, a.RGBColor); err != nil {
return err
}
if err := WriteNodeAlias(w, a.Alias); err != nil {
return err
}
if err := WriteNetAddrs(w, a.Addresses); err != nil {
return err
}
return WriteBytes(w, a.ExtraOpaqueData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (a *NodeAnnouncement) MsgType() MessageType {
return MsgNodeAnnouncement
}
// DataToSign returns the part of the message that should be signed.
func (a *NodeAnnouncement) DataToSign() ([]byte, error) {
// We should not include the signature itself.
buffer := make([]byte, 0, MaxMsgBody)
buf := bytes.NewBuffer(buffer)
if err := WriteRawFeatureVector(buf, a.Features); err != nil {
return nil, err
}
if err := WriteUint32(buf, a.Timestamp); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.NodeID[:]); err != nil {
return nil, err
}
if err := WriteColorRGBA(buf, a.RGBColor); err != nil {
return nil, err
}
if err := WriteNodeAlias(buf, a.Alias); err != nil {
return nil, err
}
if err := WriteNetAddrs(buf, a.Addresses); err != nil {
return nil, err
}
if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (a *NodeAnnouncement) SerializedSize() (uint32, error) {
return MessageSerializedSize(a)
}
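// Illustrative sketch (not part of the original source): per BOLT #7,
// the node_announcement signature covers the double-SHA256 of every
// field after the signature itself, which is exactly what DataToSign
// serializes. chainhash.DoubleHashB is assumed to be imported from
// github.com/btcsuite/btcd/chaincfg/chainhash.
func exampleAnnouncementDigest(a *NodeAnnouncement) ([]byte, error) {
	data, err := a.DataToSign()
	if err != nil {
		return nil, err
	}

	// The digest the Signature field is expected to commit to.
	return chainhash.DoubleHashB(data), nil
}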
package lnwire
import (
"bufio"
"bytes"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/tlv"
)
// FailureMessage represents the onion failure object identified by its unique
// failure code.
type FailureMessage interface {
// Code returns a failure code describing the exact nature of the
// error.
Code() FailCode
// Error returns a human readable string describing the error. With
// this method, the FailureMessage interface meets the built-in error
// interface.
Error() string
}
// FailureMessageLength is the size of the failure message plus the size of
// padding. The FailureMessage message should always be EXACTLY this size.
const FailureMessageLength = 256
const (
// FlagBadOnion error flag indicates an unparsable onion that was
// encrypted by the sending peer.
FlagBadOnion FailCode = 0x8000
// FlagPerm error flag indicates a permanent failure.
FlagPerm FailCode = 0x4000
// FlagNode error flag indicates a node failure.
FlagNode FailCode = 0x2000
// FlagUpdate error flag indicates a new channel update is enclosed
// within the error.
FlagUpdate FailCode = 0x1000
)
// FailCode specifies the precise reason that an upstream HTLC was canceled.
// Each UpdateFailHTLC message carries a FailCode which is to be passed
// backwards, encrypted at each step back to the source of the HTLC within the
// route.
type FailCode uint16
// The currently defined onion failure types within this current version of the
// Lightning protocol.
const (
CodeNone FailCode = 0
CodeInvalidRealm = FlagBadOnion | 1
CodeTemporaryNodeFailure = FlagNode | 2
CodePermanentNodeFailure = FlagPerm | FlagNode | 2
CodeRequiredNodeFeatureMissing = FlagPerm | FlagNode | 3
CodeInvalidOnionVersion = FlagBadOnion | FlagPerm | 4
CodeInvalidOnionHmac = FlagBadOnion | FlagPerm | 5
CodeInvalidOnionKey = FlagBadOnion | FlagPerm | 6
CodeTemporaryChannelFailure = FlagUpdate | 7
CodePermanentChannelFailure = FlagPerm | 8
CodeRequiredChannelFeatureMissing = FlagPerm | 9
CodeUnknownNextPeer = FlagPerm | 10
CodeAmountBelowMinimum = FlagUpdate | 11
CodeFeeInsufficient = FlagUpdate | 12
CodeIncorrectCltvExpiry = FlagUpdate | 13
CodeExpiryTooSoon = FlagUpdate | 14
CodeChannelDisabled = FlagUpdate | 20
CodeIncorrectOrUnknownPaymentDetails = FlagPerm | 15
CodeIncorrectPaymentAmount = FlagPerm | 16
CodeFinalExpiryTooSoon FailCode = 17
CodeFinalIncorrectCltvExpiry FailCode = 18
CodeFinalIncorrectHtlcAmount FailCode = 19
CodeExpiryTooFar FailCode = 21
CodeInvalidOnionPayload = FlagPerm | 22
CodeMPPTimeout FailCode = 23
CodeInvalidBlinding = FlagBadOnion | FlagPerm | 24 //nolint:ll
)
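// Illustrative sketch (not part of the original source): each failure
// code is a small integer OR'd with the flag bits above, so handlers
// can test the flags directly instead of enumerating codes. For
// example, CodeTemporaryChannelFailure is FlagUpdate|7 = 0x1007: it
// carries an enclosed channel update but is neither permanent nor a
// node-level failure.
func failCodeFlags(c FailCode) (hasUpdate, isPerm, isNode bool) {
	return c&FlagUpdate != 0, c&FlagPerm != 0, c&FlagNode != 0
}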
// String returns the string representation of the failure code.
func (c FailCode) String() string {
switch c {
case CodeInvalidRealm:
return "InvalidRealm"
case CodeTemporaryNodeFailure:
return "TemporaryNodeFailure"
case CodePermanentNodeFailure:
return "PermanentNodeFailure"
case CodeRequiredNodeFeatureMissing:
return "RequiredNodeFeatureMissing"
case CodeInvalidOnionVersion:
return "InvalidOnionVersion"
case CodeInvalidOnionHmac:
return "InvalidOnionHmac"
case CodeInvalidOnionKey:
return "InvalidOnionKey"
case CodeTemporaryChannelFailure:
return "TemporaryChannelFailure"
case CodePermanentChannelFailure:
return "PermanentChannelFailure"
case CodeRequiredChannelFeatureMissing:
return "RequiredChannelFeatureMissing"
case CodeUnknownNextPeer:
return "UnknownNextPeer"
case CodeAmountBelowMinimum:
return "AmountBelowMinimum"
case CodeFeeInsufficient:
return "FeeInsufficient"
case CodeIncorrectCltvExpiry:
return "IncorrectCltvExpiry"
case CodeIncorrectPaymentAmount:
return "IncorrectPaymentAmount"
case CodeExpiryTooSoon:
return "ExpiryTooSoon"
case CodeChannelDisabled:
return "ChannelDisabled"
case CodeIncorrectOrUnknownPaymentDetails:
return "IncorrectOrUnknownPaymentDetails"
case CodeFinalExpiryTooSoon:
return "FinalExpiryTooSoon"
case CodeFinalIncorrectCltvExpiry:
return "FinalIncorrectCltvExpiry"
case CodeFinalIncorrectHtlcAmount:
return "FinalIncorrectHtlcAmount"
case CodeExpiryTooFar:
return "ExpiryTooFar"
case CodeInvalidOnionPayload:
return "InvalidOnionPayload"
case CodeMPPTimeout:
return "MPPTimeout"
case CodeInvalidBlinding:
return "InvalidBlinding"
default:
return "<unknown>"
}
}
// FailInvalidRealm is returned if the realm byte is unknown.
//
// NOTE: May be returned by any node in the payment route.
type FailInvalidRealm struct{}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidRealm) Error() string {
return f.Code().String()
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidRealm) Code() FailCode {
return CodeInvalidRealm
}
// FailTemporaryNodeFailure is returned if an otherwise unspecified transient
// error occurs for the entire node.
//
// NOTE: May be returned by any node in the payment route.
type FailTemporaryNodeFailure struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailTemporaryNodeFailure) Code() FailCode {
return CodeTemporaryNodeFailure
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailTemporaryNodeFailure) Error() string {
return f.Code().String()
}
// FailPermanentNodeFailure is returned if an otherwise unspecified permanent
// error occurs for the entire node.
//
// NOTE: May be returned by any node in the payment route.
type FailPermanentNodeFailure struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailPermanentNodeFailure) Code() FailCode {
return CodePermanentNodeFailure
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailPermanentNodeFailure) Error() string {
return f.Code().String()
}
// FailRequiredNodeFeatureMissing is returned if a node has a requirement
// advertised in its node_announcement features which was not present in the
// onion.
//
// NOTE: May be returned by any node in the payment route.
type FailRequiredNodeFeatureMissing struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailRequiredNodeFeatureMissing) Code() FailCode {
return CodeRequiredNodeFeatureMissing
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailRequiredNodeFeatureMissing) Error() string {
return f.Code().String()
}
// FailPermanentChannelFailure is returned if an otherwise unspecified
// permanent error occurs for the outgoing channel (e.g. channel recently
// closed).
//
// NOTE: May be returned by any node in the payment route.
type FailPermanentChannelFailure struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailPermanentChannelFailure) Code() FailCode {
return CodePermanentChannelFailure
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailPermanentChannelFailure) Error() string {
return f.Code().String()
}
// FailRequiredChannelFeatureMissing is returned if the outgoing channel has a
// requirement advertised in its channel announcement features which was not
// present in the onion.
//
// NOTE: May only be returned by intermediate nodes.
type FailRequiredChannelFeatureMissing struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailRequiredChannelFeatureMissing) Code() FailCode {
return CodeRequiredChannelFeatureMissing
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailRequiredChannelFeatureMissing) Error() string {
return f.Code().String()
}
// FailUnknownNextPeer is returned if the next peer specified by the onion is
// not known.
//
// NOTE: May only be returned by intermediate nodes.
type FailUnknownNextPeer struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailUnknownNextPeer) Code() FailCode {
return CodeUnknownNextPeer
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailUnknownNextPeer) Error() string {
return f.Code().String()
}
// FailIncorrectPaymentAmount is returned if the amount paid is less than the
// amount expected; the final node MUST fail the HTLC. If the amount paid is
// more than twice the amount expected, the final node SHOULD fail the HTLC.
// This allows the sender to reduce information leakage by altering the amount,
// without allowing accidental gross overpayment.
//
// NOTE: May only be returned by the final node in the path.
type FailIncorrectPaymentAmount struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailIncorrectPaymentAmount) Code() FailCode {
return CodeIncorrectPaymentAmount
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailIncorrectPaymentAmount) Error() string {
return f.Code().String()
}
// FailIncorrectDetails is returned for two reasons:
//
// 1) if the payment hash has already been paid, the final node MAY treat the
// payment hash as unknown, or may succeed in accepting the HTLC. If the
// payment hash is unknown, the final node MUST fail the HTLC.
//
// 2) if the amount paid is less than the amount expected, the final node MUST
// fail the HTLC. If the amount paid is more than twice the amount expected,
// the final node SHOULD fail the HTLC. This allows the sender to reduce
// information leakage by altering the amount, without allowing accidental
// gross overpayment.
//
// NOTE: May only be returned by the final node in the path.
type FailIncorrectDetails struct {
// amount is the value of the extended HTLC.
amount MilliSatoshi
// height is the block height when the htlc was received.
height uint32
// extraOpaqueData contains additional failure message tlv data.
extraOpaqueData ExtraOpaqueData
}
// NewFailIncorrectDetails makes a new instance of the FailIncorrectDetails
// error bound to the specified HTLC amount and acceptance height.
func NewFailIncorrectDetails(amt MilliSatoshi,
height uint32) *FailIncorrectDetails {
return &FailIncorrectDetails{
amount: amt,
height: height,
extraOpaqueData: []byte{},
}
}
// Amount is the value of the extended HTLC.
func (f *FailIncorrectDetails) Amount() MilliSatoshi {
return f.amount
}
// Height is the block height when the htlc was received.
func (f *FailIncorrectDetails) Height() uint32 {
return f.height
}
// ExtraOpaqueData returns additional failure message tlv data.
func (f *FailIncorrectDetails) ExtraOpaqueData() ExtraOpaqueData {
return f.extraOpaqueData
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailIncorrectDetails) Code() FailCode {
return CodeIncorrectOrUnknownPaymentDetails
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailIncorrectDetails) Error() string {
return fmt.Sprintf(
"%v(amt=%v, height=%v)", CodeIncorrectOrUnknownPaymentDetails,
f.amount, f.height,
)
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error {
err := ReadElement(r, &f.amount)
switch {
// This is an optional tack on that was added later in the protocol. As
// a result, older nodes may not include this value. We'll account for
// this by checking for io.EOF here which means that no bytes were read
// at all.
case err == io.EOF:
return nil
case err != nil:
return err
}
// At a later stage, the height field was also tacked on. We need to
// check for io.EOF here as well.
err = ReadElement(r, &f.height)
switch {
case err == io.EOF:
return nil
case err != nil:
return err
}
return f.extraOpaqueData.Decode(r)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailIncorrectDetails) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteMilliSatoshi(w, f.amount); err != nil {
return err
}
if err := WriteUint32(w, f.height); err != nil {
return err
}
return f.extraOpaqueData.Encode(w)
}
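// Illustrative sketch (not part of the original source): because Decode
// treats io.EOF after each optional field as "older sender", a legacy
// payload carrying only the 8-byte amount still parses, with the height
// left at its zero value.
func exampleLegacyIncorrectDetails() error {
	// An 8-byte big-endian amount (1000 msat) and nothing else: no
	// height field and no trailing TLV data.
	legacy := []byte{0, 0, 0, 0, 0, 0, 0x03, 0xe8}

	var f FailIncorrectDetails
	if err := f.Decode(bytes.NewReader(legacy), 0); err != nil {
		return err
	}

	fmt.Println(f.Amount(), f.Height()) // 1000 mSAT 0
	return nil
}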
// FailFinalExpiryTooSoon is returned if the cltv_expiry is too low; the final
// node MUST fail the HTLC.
//
// NOTE: May only be returned by the final node in the path.
type FailFinalExpiryTooSoon struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailFinalExpiryTooSoon) Code() FailCode {
return CodeFinalExpiryTooSoon
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailFinalExpiryTooSoon) Error() string {
return f.Code().String()
}
// NewFinalExpiryTooSoon creates new instance of the FailFinalExpiryTooSoon.
func NewFinalExpiryTooSoon() *FailFinalExpiryTooSoon {
return &FailFinalExpiryTooSoon{}
}
// FailInvalidOnionVersion is returned if the onion version byte is unknown.
//
// NOTE: May be returned only by intermediate nodes.
type FailInvalidOnionVersion struct {
// OnionSHA256 is the hash of the onion blob that couldn't be processed.
OnionSHA256 [sha256.Size]byte
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidOnionVersion) Error() string {
return fmt.Sprintf("InvalidOnionVersion(onion_sha=%x)", f.OnionSHA256[:])
}
// NewInvalidOnionVersion creates new instance of the FailInvalidOnionVersion.
func NewInvalidOnionVersion(onion []byte) *FailInvalidOnionVersion {
return &FailInvalidOnionVersion{OnionSHA256: sha256.Sum256(onion)}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidOnionVersion) Code() FailCode {
return CodeInvalidOnionVersion
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionVersion) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionVersion) Encode(w *bytes.Buffer, pver uint32) error {
return WriteBytes(w, f.OnionSHA256[:])
}
// FailInvalidOnionHmac is returned if the onion HMAC is incorrect.
//
// NOTE: May only be returned by intermediate nodes.
type FailInvalidOnionHmac struct {
// OnionSHA256 is the hash of the onion blob that couldn't be processed.
OnionSHA256 [sha256.Size]byte
}
// NewInvalidOnionHmac creates new instance of the FailInvalidOnionHmac.
func NewInvalidOnionHmac(onion []byte) *FailInvalidOnionHmac {
return &FailInvalidOnionHmac{OnionSHA256: sha256.Sum256(onion)}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidOnionHmac) Code() FailCode {
return CodeInvalidOnionHmac
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionHmac) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionHmac) Encode(w *bytes.Buffer, pver uint32) error {
return WriteBytes(w, f.OnionSHA256[:])
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidOnionHmac) Error() string {
return fmt.Sprintf("InvalidOnionHMAC(onion_sha=%x)", f.OnionSHA256[:])
}
// FailInvalidOnionKey is returned if the ephemeral key in the onion is
// unparsable.
//
// NOTE: May only be returned by intermediate nodes.
type FailInvalidOnionKey struct {
// OnionSHA256 is the hash of the onion blob that couldn't be processed.
OnionSHA256 [sha256.Size]byte
}
// NewInvalidOnionKey creates new instance of the FailInvalidOnionKey.
func NewInvalidOnionKey(onion []byte) *FailInvalidOnionKey {
return &FailInvalidOnionKey{OnionSHA256: sha256.Sum256(onion)}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidOnionKey) Code() FailCode {
return CodeInvalidOnionKey
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionKey) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidOnionKey) Encode(w *bytes.Buffer, pver uint32) error {
return WriteBytes(w, f.OnionSHA256[:])
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidOnionKey) Error() string {
return fmt.Sprintf("InvalidOnionKey(onion_sha=%x)", f.OnionSHA256[:])
}
// parseChannelUpdateCompatibilityMode will attempt to parse a channel update
// encoded into an onion error payload in two ways. First, we'll try the
// compatibility oriented version wherein we'll _skip_ the length prefixing on
// the channel update message. Older versions of c-lightning do this so we'll
// attempt to parse these messages in order to retain compatibility. If we're
// unable to pull out a fully valid version, then we'll fall back to the
// regular parsing mechanism which includes the length prefix and NO type byte.
func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16,
chanUpdate *ChannelUpdate1, pver uint32) error {
// Instantiate a LimitReader because there may be additional data
// present after the channel update. Without limiting the stream, the
// additional data would be interpreted as channel update tlv data.
limitReader := io.LimitReader(reader, int64(length))
r := bufio.NewReader(limitReader)
// We'll peek out two bytes from the buffer without advancing the
// buffer so we can decide how to parse the remainder of it.
maybeTypeBytes, err := r.Peek(2)
if err != nil {
return err
}
// Some nodes will prefix an additional set of bytes in front of their
// channel updates. These bytes will _almost_ always be 258, the type
// of the ChannelUpdate message.
typeInt := binary.BigEndian.Uint16(maybeTypeBytes)
if typeInt == MsgChannelUpdate {
// At this point it's likely the case that this is a channel
// update message with its type prefixed, so we'll snip off the
// first two bytes and parse it as normal.
var throwAwayTypeBytes [2]byte
_, err := r.Read(throwAwayTypeBytes[:])
if err != nil {
return err
}
}
// At this point, we've either decided to keep the entire thing, or snip
// off the first two bytes. In either case, we can just read it as
// normal.
return chanUpdate.Decode(r, pver)
}
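// Illustrative sketch (not part of the original source): the two
// layouts the helper above accepts differ only by a leading 2-byte
// message type. Given the same serialized update in upd, both readers
// below decode identically; 0x01 0x02 is 258, the ChannelUpdate type.
func compatReaders(upd []byte) (io.Reader, io.Reader) {
	plain := bytes.NewReader(upd)
	prefixed := bytes.NewReader(append([]byte{0x01, 0x02}, upd...))
	return plain, prefixed
}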
// FailTemporaryChannelFailure is returned if an otherwise unspecified
// transient error occurs for the outgoing channel (e.g. channel capacity
// reached, too many in-flight HTLCs).
//
// NOTE: May only be returned by intermediate nodes.
type FailTemporaryChannelFailure struct {
// Update is used to update information about state of the channel
// which caused the failure.
//
// NOTE: This field is optional.
Update *ChannelUpdate1
}
// NewTemporaryChannelFailure creates new instance of the FailTemporaryChannelFailure.
func NewTemporaryChannelFailure(
update *ChannelUpdate1) *FailTemporaryChannelFailure {
return &FailTemporaryChannelFailure{Update: update}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailTemporaryChannelFailure) Code() FailCode {
return CodeTemporaryChannelFailure
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailTemporaryChannelFailure) Error() string {
if f.Update == nil {
return f.Code().String()
}
return fmt.Sprintf("TemporaryChannelFailure(update=%v)",
spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error {
var length uint16
err := ReadElement(r, &length)
if err != nil {
return err
}
if length != 0 {
f.Update = &ChannelUpdate1{}
return parseChannelUpdateCompatibilityMode(
r, length, f.Update, pver,
)
}
return nil
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailTemporaryChannelFailure) Encode(w *bytes.Buffer,
pver uint32) error {
if f.Update != nil {
return writeOnionErrorChanUpdate(w, f.Update, pver)
}
// Write zero length to indicate no channel_update is present.
return WriteUint16(w, 0)
}
// FailAmountBelowMinimum is returned if the HTLC does not reach the current
// minimum amount; we tell them the amount of the incoming HTLC and the current
// channel setting for the outgoing channel.
//
// NOTE: May only be returned by the intermediate nodes in the path.
type FailAmountBelowMinimum struct {
// HtlcMsat is the wrong amount of the incoming HTLC.
HtlcMsat MilliSatoshi
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate1
}
// NewAmountBelowMinimum creates new instance of the FailAmountBelowMinimum.
func NewAmountBelowMinimum(htlcMsat MilliSatoshi,
update ChannelUpdate1) *FailAmountBelowMinimum {
return &FailAmountBelowMinimum{
HtlcMsat: htlcMsat,
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailAmountBelowMinimum) Code() FailCode {
return CodeAmountBelowMinimum
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailAmountBelowMinimum) Error() string {
return fmt.Sprintf("AmountBelowMinimum(amt=%v, update=%v", f.HtlcMsat,
spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error {
if err := ReadElement(r, &f.HtlcMsat); err != nil {
return err
}
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate1{}
return parseChannelUpdateCompatibilityMode(
r, length, &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailAmountBelowMinimum) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteMilliSatoshi(w, f.HtlcMsat); err != nil {
return err
}
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailFeeInsufficient is returned if the HTLC does not pay a sufficient fee; we
// tell them the amount of the incoming HTLC and the current channel setting
// for the outgoing channel.
//
// NOTE: May only be returned by intermediate nodes.
type FailFeeInsufficient struct {
// HtlcMsat is the wrong amount of the incoming HTLC.
HtlcMsat MilliSatoshi
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate1
}
// NewFeeInsufficient creates new instance of the FailFeeInsufficient.
func NewFeeInsufficient(htlcMsat MilliSatoshi,
update ChannelUpdate1) *FailFeeInsufficient {
return &FailFeeInsufficient{
HtlcMsat: htlcMsat,
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailFeeInsufficient) Code() FailCode {
return CodeFeeInsufficient
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailFeeInsufficient) Error() string {
return fmt.Sprintf("FeeInsufficient(htlc_amt==%v, update=%v", f.HtlcMsat,
spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error {
if err := ReadElement(r, &f.HtlcMsat); err != nil {
return err
}
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate1{}
return parseChannelUpdateCompatibilityMode(
r, length, &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFeeInsufficient) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteMilliSatoshi(w, f.HtlcMsat); err != nil {
return err
}
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailIncorrectCltvExpiry is returned if the outgoing cltv value does not
// match the update_add_htlc's cltv expiry minus the cltv expiry delta for the
// outgoing channel; we tell them the cltv expiry and the current channel
// setting for the outgoing channel.
//
// NOTE: May only be returned by intermediate nodes.
type FailIncorrectCltvExpiry struct {
// CltvExpiry is the wrong absolute timeout in blocks, after which
// outgoing HTLC expires.
CltvExpiry uint32
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate1
}
// NewIncorrectCltvExpiry creates new instance of the FailIncorrectCltvExpiry.
func NewIncorrectCltvExpiry(cltvExpiry uint32,
update ChannelUpdate1) *FailIncorrectCltvExpiry {
return &FailIncorrectCltvExpiry{
CltvExpiry: cltvExpiry,
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailIncorrectCltvExpiry) Code() FailCode {
return CodeIncorrectCltvExpiry
}
func (f *FailIncorrectCltvExpiry) Error() string {
return fmt.Sprintf("IncorrectCltvExpiry(expiry=%v, update=%v",
f.CltvExpiry, spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error {
if err := ReadElement(r, &f.CltvExpiry); err != nil {
return err
}
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate1{}
return parseChannelUpdateCompatibilityMode(
r, length, &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteUint32(w, f.CltvExpiry); err != nil {
return err
}
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailExpiryTooSoon is returned if the cltv_expiry is too near; we tell them
// the current channel setting for the outgoing channel.
//
// NOTE: May only be returned by intermediate nodes.
type FailExpiryTooSoon struct {
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate1
}
// NewExpiryTooSoon creates new instance of the FailExpiryTooSoon.
func NewExpiryTooSoon(update ChannelUpdate1) *FailExpiryTooSoon {
return &FailExpiryTooSoon{
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailExpiryTooSoon) Code() FailCode {
return CodeExpiryTooSoon
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailExpiryTooSoon) Error() string {
return fmt.Sprintf("ExpiryTooSoon(update=%v", spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error {
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate1{}
return parseChannelUpdateCompatibilityMode(
r, length, &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailExpiryTooSoon) Encode(w *bytes.Buffer, pver uint32) error {
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailChannelDisabled is returned if the channel is disabled; we tell them the
// current channel setting for the outgoing channel.
//
// NOTE: May only be returned by intermediate nodes.
type FailChannelDisabled struct {
// Flags least-significant bit must be set to 0 if the creating node
// corresponds to the first node in the previously sent channel
// announcement and 1 otherwise.
Flags uint16
// Update is used to update information about state of the channel
// which caused the failure.
Update ChannelUpdate1
}
// NewChannelDisabled creates new instance of the FailChannelDisabled.
func NewChannelDisabled(flags uint16,
update ChannelUpdate1) *FailChannelDisabled {
return &FailChannelDisabled{
Flags: flags,
Update: update,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailChannelDisabled) Code() FailCode {
return CodeChannelDisabled
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailChannelDisabled) Error() string {
return fmt.Sprintf("ChannelDisabled(flags=%v, update=%v", f.Flags,
spew.Sdump(f.Update))
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error {
if err := ReadElement(r, &f.Flags); err != nil {
return err
}
var length uint16
if err := ReadElement(r, &length); err != nil {
return err
}
f.Update = ChannelUpdate1{}
return parseChannelUpdateCompatibilityMode(
r, length, &f.Update, pver,
)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailChannelDisabled) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteUint16(w, f.Flags); err != nil {
return err
}
return writeOnionErrorChanUpdate(w, &f.Update, pver)
}
// FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not
// match the cltv_expiry of the HTLC at the final hop.
//
// NOTE: May only be returned by the final node.
type FailFinalIncorrectCltvExpiry struct {
// CltvExpiry is the wrong absolute timeout in blocks, after which
// outgoing HTLC expires.
CltvExpiry uint32
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailFinalIncorrectCltvExpiry) Error() string {
return fmt.Sprintf("FinalIncorrectCltvExpiry(expiry=%v)", f.CltvExpiry)
}
// NewFinalIncorrectCltvExpiry creates new instance of the
// FailFinalIncorrectCltvExpiry.
func NewFinalIncorrectCltvExpiry(cltvExpiry uint32) *FailFinalIncorrectCltvExpiry {
return &FailFinalIncorrectCltvExpiry{
CltvExpiry: cltvExpiry,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailFinalIncorrectCltvExpiry) Code() FailCode {
return CodeFinalIncorrectCltvExpiry
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFinalIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, &f.CltvExpiry)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFinalIncorrectCltvExpiry) Encode(w *bytes.Buffer,
pver uint32) error {
return WriteUint32(w, f.CltvExpiry)
}
// FailFinalIncorrectHtlcAmount is returned if the amt_to_forward is higher
// than incoming_htlc_amt of the HTLC at the final hop.
//
// NOTE: May only be returned by the final node.
type FailFinalIncorrectHtlcAmount struct {
// IncomingHTLCAmount is the wrong forwarded htlc amount.
IncomingHTLCAmount MilliSatoshi
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailFinalIncorrectHtlcAmount) Error() string {
return fmt.Sprintf("FinalIncorrectHtlcAmount(amt=%v)",
f.IncomingHTLCAmount)
}
// NewFinalIncorrectHtlcAmount creates new instance of the
// FailFinalIncorrectHtlcAmount.
func NewFinalIncorrectHtlcAmount(amount MilliSatoshi) *FailFinalIncorrectHtlcAmount {
return &FailFinalIncorrectHtlcAmount{
IncomingHTLCAmount: amount,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailFinalIncorrectHtlcAmount) Code() FailCode {
return CodeFinalIncorrectHtlcAmount
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFinalIncorrectHtlcAmount) Decode(r io.Reader, pver uint32) error {
return ReadElement(r, &f.IncomingHTLCAmount)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailFinalIncorrectHtlcAmount) Encode(w *bytes.Buffer,
pver uint32) error {
return WriteMilliSatoshi(w, f.IncomingHTLCAmount)
}
// FailExpiryTooFar is returned if the CLTV expiry in the HTLC is too far in the
// future.
//
// NOTE: May be returned by any node in the payment route.
type FailExpiryTooFar struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailExpiryTooFar) Code() FailCode {
return CodeExpiryTooFar
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailExpiryTooFar) Error() string {
return f.Code().String()
}
// InvalidOnionPayload is returned if the hop could not process the TLV payload
// enclosed in the onion.
type InvalidOnionPayload struct {
// Type is the TLV type that caused the specific failure.
Type uint64
// Offset is the byte offset within the payload where the failure
// occurred.
Offset uint16
}
// NewInvalidOnionPayload initializes a new InvalidOnionPayload failure.
func NewInvalidOnionPayload(typ uint64, offset uint16) *InvalidOnionPayload {
return &InvalidOnionPayload{
Type: typ,
Offset: offset,
}
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *InvalidOnionPayload) Code() FailCode {
return CodeInvalidOnionPayload
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *InvalidOnionPayload) Error() string {
return fmt.Sprintf("%v(type=%v, offset=%d)",
f.Code(), f.Type, f.Offset)
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *InvalidOnionPayload) Decode(r io.Reader, pver uint32) error {
var buf [8]byte
typ, err := tlv.ReadVarInt(r, &buf)
if err != nil {
return err
}
f.Type = typ
return ReadElements(r, &f.Offset)
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *InvalidOnionPayload) Encode(w *bytes.Buffer, pver uint32) error {
var buf [8]byte
if err := tlv.WriteVarInt(w, f.Type, &buf); err != nil {
return err
}
return WriteUint16(w, f.Offset)
}
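// The wire form written above is the BigSize-encoded TLV type followed by
// the two-byte offset. A minimal encoding sketch (illustrative only, not
// part of the package API; the type and offset below are arbitrary
// placeholders):
func exampleInvalidOnionPayloadEncode() error {
	f := NewInvalidOnionPayload(66097, 4)

	var buf bytes.Buffer
	return f.Encode(&buf, 0)
}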
// FailMPPTimeout is returned if the complete amount for a multi part payment
// was not received within a reasonable time.
//
// NOTE: May only be returned by the final node in the path.
type FailMPPTimeout struct{}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailMPPTimeout) Code() FailCode {
return CodeMPPTimeout
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailMPPTimeout) Error() string {
return f.Code().String()
}
// FailInvalidBlinding is returned if there has been a route blinding related
// error.
type FailInvalidBlinding struct {
OnionSHA256 [sha256.Size]byte
}
// Code returns the failure unique code.
//
// NOTE: Part of the FailureMessage interface.
func (f *FailInvalidBlinding) Code() FailCode {
return CodeInvalidBlinding
}
// Returns a human readable string describing the target FailureMessage.
//
// NOTE: Implements the error interface.
func (f *FailInvalidBlinding) Error() string {
return f.Code().String()
}
// Decode decodes the failure from bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidBlinding) Decode(r io.Reader, _ uint32) error {
return ReadElement(r, f.OnionSHA256[:])
}
// Encode writes the failure in bytes stream.
//
// NOTE: Part of the Serializable interface.
func (f *FailInvalidBlinding) Encode(w *bytes.Buffer, _ uint32) error {
return WriteBytes(w, f.OnionSHA256[:])
}
// NewInvalidBlinding creates new instance of FailInvalidBlinding.
func NewInvalidBlinding(
onion fn.Option[[OnionPacketSize]byte]) *FailInvalidBlinding {
// The spec allows empty onion hashes for invalid blinding, so we only
// include our onion hash if it's provided.
if onion.IsNone() {
return &FailInvalidBlinding{}
}
shaSum := fn.MapOptionZ(onion, func(o [OnionPacketSize]byte) [32]byte {
return sha256.Sum256(o[:])
})
return &FailInvalidBlinding{OnionSHA256: shaSum}
}
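// A minimal construction sketch (illustrative only): when no onion packet
// is available, fn.None yields a failure with an all-zero OnionSHA256,
// which the spec permits for invalid blinding errors.
func exampleNewInvalidBlinding() *FailInvalidBlinding {
	return NewInvalidBlinding(fn.None[[OnionPacketSize]byte]())
}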
// DecodeFailure decodes, validates, and parses the lnwire onion failure, for
// the provided protocol version.
func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) {
// First, we'll parse out the encapsulated failure message itself. This
// is a 2 byte length followed by the payload itself.
var failureLength uint16
if err := ReadElement(r, &failureLength); err != nil {
return nil, fmt.Errorf("unable to read failure len: %w", err)
}
failureData := make([]byte, failureLength)
if _, err := io.ReadFull(r, failureData); err != nil {
return nil, fmt.Errorf("unable to full read payload of "+
"%v: %w", failureLength, err)
}
// Read the padding.
var padLength uint16
if err := ReadElement(r, &padLength); err != nil {
return nil, fmt.Errorf("unable to read pad len: %w", err)
}
if _, err := io.CopyN(io.Discard, r, int64(padLength)); err != nil {
return nil, fmt.Errorf("unable to read padding %w", err)
}
// Verify that we are at the end of the stream now.
scratch := make([]byte, 1)
_, err := r.Read(scratch)
if err != io.EOF {
return nil, fmt.Errorf("unexpected failure bytes")
}
// Check the total length. Convert to 32 bits to prevent overflow.
totalLength := uint32(padLength) + uint32(failureLength)
if totalLength < FailureMessageLength {
return nil, fmt.Errorf("failure message too short: "+
"msg=%v, pad=%v, total=%v",
failureLength, padLength, totalLength)
}
// Decode the failure message.
dataReader := bytes.NewReader(failureData)
return DecodeFailureMessage(dataReader, pver)
}
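// A hedged round-trip sketch (illustrative only, not part of the package
// API) of the framing handled by EncodeFailure and DecodeFailure: a 2-byte
// message length, the message itself, a 2-byte pad length, and padding,
// with message plus pad always totalling FailureMessageLength bytes.
func exampleFailureRoundTrip() (FailureMessage, error) {
	var buf bytes.Buffer

	// FailExpiryTooFar carries no payload, so the encoder pads the
	// frame out to the full fixed size.
	if err := EncodeFailure(&buf, &FailExpiryTooFar{}, 0); err != nil {
		return nil, err
	}

	// Decoding validates the framing and returns the concrete type.
	return DecodeFailure(&buf, 0)
}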
// DecodeFailureMessage decodes just the failure message, ignoring any padding
// that may be present at the end.
func DecodeFailureMessage(r io.Reader, pver uint32) (FailureMessage, error) {
// Once we have the failure data, we can obtain the failure code from
// the first two bytes of the buffer.
var codeBytes [2]byte
if _, err := io.ReadFull(r, codeBytes[:]); err != nil {
return nil, fmt.Errorf("unable to read failure code: %w", err)
}
failCode := FailCode(binary.BigEndian.Uint16(codeBytes[:]))
// Create the empty failure by given code and populate the failure with
// additional data if needed.
failure, err := makeEmptyOnionError(failCode)
if err != nil {
return nil, fmt.Errorf("unable to make empty error: %w", err)
}
// Finally, if this failure has a payload, then we'll read that now as
// well.
switch f := failure.(type) {
case Serializable:
if err := f.Decode(r, pver); err != nil {
return nil, fmt.Errorf("unable to decode error "+
"update (type=%T): %v", failure, err)
}
}
return failure, nil
}
// EncodeFailure encodes the failure message, including the necessary onion
// failure header information.
func EncodeFailure(w *bytes.Buffer, failure FailureMessage, pver uint32) error {
var failureMessageBuffer bytes.Buffer
err := EncodeFailureMessage(&failureMessageBuffer, failure, pver)
if err != nil {
return err
}
// The combined size of this message must be below the max allowed
// failure message length.
failureMessage := failureMessageBuffer.Bytes()
if len(failureMessage) > FailureMessageLength {
return fmt.Errorf("failure message exceed max "+
"available size: %v", len(failureMessage))
}
// Finally, we'll add some padding in order to ensure that all failure
// messages are fixed size.
pad := make([]byte, FailureMessageLength-len(failureMessage))
if err := WriteUint16(w, uint16(len(failureMessage))); err != nil {
return err
}
if err := WriteBytes(w, failureMessage); err != nil {
return err
}
if err := WriteUint16(w, uint16(len(pad))); err != nil {
return err
}
return WriteBytes(w, pad)
}
// EncodeFailureMessage encodes just the failure message without adding a length
// and padding the message for the onion protocol.
func EncodeFailureMessage(w *bytes.Buffer,
failure FailureMessage, pver uint32) error {
// First, we'll write out the error code itself into the failure
// buffer.
var codeBytes [2]byte
code := uint16(failure.Code())
binary.BigEndian.PutUint16(codeBytes[:], code)
_, err := w.Write(codeBytes[:])
if err != nil {
return err
}
// Next, some messages have an additional message payload. If this is
// one of those types, then we'll also encode the error payload as
// well.
switch failure := failure.(type) {
case Serializable:
if err := failure.Encode(w, pver); err != nil {
return err
}
}
return nil
}
// makeEmptyOnionError creates a new empty onion error of the proper concrete
// type based on the passed failure code.
func makeEmptyOnionError(code FailCode) (FailureMessage, error) {
switch code {
case CodeInvalidRealm:
return &FailInvalidRealm{}, nil
case CodeTemporaryNodeFailure:
return &FailTemporaryNodeFailure{}, nil
case CodePermanentNodeFailure:
return &FailPermanentNodeFailure{}, nil
case CodeRequiredNodeFeatureMissing:
return &FailRequiredNodeFeatureMissing{}, nil
case CodePermanentChannelFailure:
return &FailPermanentChannelFailure{}, nil
case CodeRequiredChannelFeatureMissing:
return &FailRequiredChannelFeatureMissing{}, nil
case CodeUnknownNextPeer:
return &FailUnknownNextPeer{}, nil
case CodeIncorrectOrUnknownPaymentDetails:
return &FailIncorrectDetails{}, nil
case CodeIncorrectPaymentAmount:
return &FailIncorrectPaymentAmount{}, nil
case CodeFinalExpiryTooSoon:
return &FailFinalExpiryTooSoon{}, nil
case CodeInvalidOnionVersion:
return &FailInvalidOnionVersion{}, nil
case CodeInvalidOnionHmac:
return &FailInvalidOnionHmac{}, nil
case CodeInvalidOnionKey:
return &FailInvalidOnionKey{}, nil
case CodeTemporaryChannelFailure:
return &FailTemporaryChannelFailure{}, nil
case CodeAmountBelowMinimum:
return &FailAmountBelowMinimum{}, nil
case CodeFeeInsufficient:
return &FailFeeInsufficient{}, nil
case CodeIncorrectCltvExpiry:
return &FailIncorrectCltvExpiry{}, nil
case CodeExpiryTooSoon:
return &FailExpiryTooSoon{}, nil
case CodeChannelDisabled:
return &FailChannelDisabled{}, nil
case CodeFinalIncorrectCltvExpiry:
return &FailFinalIncorrectCltvExpiry{}, nil
case CodeFinalIncorrectHtlcAmount:
return &FailFinalIncorrectHtlcAmount{}, nil
case CodeExpiryTooFar:
return &FailExpiryTooFar{}, nil
case CodeInvalidOnionPayload:
return &InvalidOnionPayload{}, nil
case CodeMPPTimeout:
return &FailMPPTimeout{}, nil
case CodeInvalidBlinding:
return &FailInvalidBlinding{}, nil
default:
return nil, errors.Errorf("unknown error code: %v", code)
}
}
// writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error
// format. The format is that we first write out the true serialized length of
// the channel update, followed by the serialized channel update itself.
func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate1,
pver uint32) error {
// First, we encode the channel update in a temporary buffer in order
// to get the exact serialized size.
var b bytes.Buffer
updateLen, err := WriteMessage(&b, chanUpdate, pver)
if err != nil {
return err
}
// Now that we know the size, we can write the length out in the main
// writer.
if err := WriteUint16(w, uint16(updateLen)); err != nil {
return err
}
// With the length written, we'll then write out the serialized channel
// update.
if _, err := w.Write(b.Bytes()); err != nil {
return err
}
return nil
}
package lnwire
import (
"encoding/hex"
"net"
)
// OpaqueAddrs is used to store the address bytes for address types that are
// unknown to us.
type OpaqueAddrs struct {
Payload []byte
}
// A compile-time assertion to ensure that OpaqueAddrs meets the net.Addr
// interface.
var _ net.Addr = (*OpaqueAddrs)(nil)
// String returns a human-readable string describing the target OpaqueAddrs.
// Since this is an unknown address (and could even be multiple addresses), we
// just return the hex string of the payload.
//
// This is part of the net.Addr interface.
func (o *OpaqueAddrs) String() string {
return hex.EncodeToString(o.Payload)
}
// Network returns the name of the network this address is bound to. Since this
// is an unknown address, we don't know the network and so just return a string
// indicating this.
//
// This is part of the net.Addr interface.
func (o *OpaqueAddrs) Network() string {
return "unknown network for unrecognized address type"
}
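// An illustrative sketch (not part of the package API): an OpaqueAddrs
// simply echoes its raw payload as hex, since the address type is unknown.
func exampleOpaqueAddrs() string {
	addr := &OpaqueAddrs{Payload: []byte{0x07, 0x01}}
	return addr.String() // "0701"
}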
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/tlv"
)
// FundingFlag represents the possible bit mask values for the ChannelFlags
// field within the OpenChannel struct.
type FundingFlag uint8
const (
// FFAnnounceChannel is a FundingFlag that when set, indicates the
// initiator of a funding flow wishes to announce the channel to the
// greater network.
FFAnnounceChannel FundingFlag = 1 << iota
)
// OpenChannel is the message Alice sends to Bob if she would like to create a
// channel with Bob where she's the sole provider of funds to the channel.
// Single funder channels simplify the initial funding workflow, are supported
// by nodes backed by SPV Bitcoin clients, and have a simpler security model
// than dual funded channels.
type OpenChannel struct {
// ChainHash is the target chain that the initiator wishes to open a
// channel within.
ChainHash chainhash.Hash
// PendingChannelID serves to uniquely identify the future channel
// created by the initiated single funder workflow.
PendingChannelID [32]byte
// FundingAmount is the amount of satoshis that the initiator of the
// channel wishes to use as the total capacity of the channel. The
// initial balance of the funding will be this value minus the push
// amount (if set).
FundingAmount btcutil.Amount
// PushAmount is the value that the initiating party wishes to "push"
// to the responding party as part of the first commitment state. If the
// responder accepts, then this will be their initial balance.
PushAmount MilliSatoshi
// DustLimit is the specific dust limit the sender of this message
// would like enforced on their version of the commitment transaction.
// Any output below this value will be "trimmed" from the commitment
// transaction, with the amount of the HTLC going to dust.
DustLimit btcutil.Amount
// MaxValueInFlight represents the maximum amount of coins that can be
// pending within the channel at any given time. If the amount of funds
// in limbo exceeds this amount, then the channel will be failed.
MaxValueInFlight MilliSatoshi
// ChannelReserve is the amount of BTC that the receiving party MUST
// maintain a balance above at all times. This is a safety mechanism to
// ensure that both sides always have skin in the game during the
// channel's lifetime.
ChannelReserve btcutil.Amount
// HtlcMinimum is the smallest HTLC that the sender of this message
// will accept.
HtlcMinimum MilliSatoshi
// FeePerKiloWeight is the initial fee rate that the initiator suggests
// for both commitment transactions. This value is expressed in sat per
// kilo-weight.
//
// TODO(halseth): make SatPerKWeight when fee estimation is in own
// package. Currently this will cause an import cycle.
FeePerKiloWeight uint32
// CsvDelay is the number of blocks to use for the relative time lock
// in the pay-to-self output of both commitment transactions.
CsvDelay uint16
// MaxAcceptedHTLCs is the total number of incoming HTLC's that the
// sender of this channel will accept.
MaxAcceptedHTLCs uint16
// FundingKey is the key that should be used on behalf of the sender
// within the 2-of-2 multi-sig output that is contained within the
// funding transaction.
FundingKey *btcec.PublicKey
// RevocationPoint is the base revocation point for the sending party.
// Any commitment transaction belonging to the receiver of this message
// should use this key and their per-commitment point to derive the
// revocation key for the commitment transaction.
RevocationPoint *btcec.PublicKey
// PaymentPoint is the base payment point for the sending party. This
// key should be combined with the per commitment point for a
// particular commitment state in order to create the key that should
// be used in any output that pays directly to the sending party, and
// also within the HTLC covenant transactions.
PaymentPoint *btcec.PublicKey
// DelayedPaymentPoint is the delay point for the sending party. This
// key should be combined with the per commitment point to derive the
// keys that are used in outputs of the sender's commitment transaction
// where they claim funds.
DelayedPaymentPoint *btcec.PublicKey
// HtlcPoint is the base point used to derive the set of keys for this
// party that will be used within the HTLC public key scripts. This
// value is combined with the receiver's revocation base point in order
// to derive the keys that are used within HTLC scripts.
HtlcPoint *btcec.PublicKey
// FirstCommitmentPoint is the first commitment point for the sending
// party. This value should be combined with the receiver's revocation
// base point in order to derive the revocation keys that are placed
// within the commitment transaction of the sender.
FirstCommitmentPoint *btcec.PublicKey
// ChannelFlags is a bit-field which allows the initiator of the
// channel to specify further behavior surrounding the channel.
// Currently, the least significant bit of this bit field indicates the
// initiator of the channel wishes to advertise this channel publicly.
ChannelFlags FundingFlag
// UpfrontShutdownScript is the script to which the channel funds should
// be paid when mutually closing the channel. This field is optional, and
// and has a length prefix, so a zero will be written if it is not set
// and its length followed by the script will be written if it is set.
UpfrontShutdownScript DeliveryAddress
// ChannelType is the explicit channel type the initiator wishes to
// open.
ChannelType *ChannelType
// LeaseExpiry represents the absolute expiration height of a channel
// lease. This is a custom TLV record that will only apply when a leased
// channel is being opened using the script enforced lease commitment
// type.
LeaseExpiry *LeaseExpiry
// LocalNonce is an optional field that transmits the
// local/verification nonce for a party. This nonce will be used to
// verify the very first commitment transaction signature. This will
// only be populated if the simple taproot channels type was
// negotiated.
LocalNonce OptMusig2NonceTLV
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
//
// NOTE: Since the upfront shutdown script MUST be present (though can
// be zero-length) if any TLV data is available, the script will be
// extracted and removed from this blob when decoding. ExtraData will
// contain all TLV records _except_ the DeliveryAddress record in that
// case.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure OpenChannel implements the lnwire.Message
// interface.
var _ Message = (*OpenChannel)(nil)
// A compile time check to ensure OpenChannel implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*OpenChannel)(nil)
// Encode serializes the target OpenChannel into the passed io.Writer
// implementation. Serialization will observe the rules defined by the passed
// protocol version.
func (o *OpenChannel) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := []tlv.RecordProducer{&o.UpfrontShutdownScript}
if o.ChannelType != nil {
recordProducers = append(recordProducers, o.ChannelType)
}
if o.LeaseExpiry != nil {
recordProducers = append(recordProducers, o.LeaseExpiry)
}
o.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, &localNonce)
})
err := EncodeMessageExtraData(&o.ExtraData, recordProducers...)
if err != nil {
return err
}
if err := WriteBytes(w, o.ChainHash[:]); err != nil {
return err
}
if err := WriteBytes(w, o.PendingChannelID[:]); err != nil {
return err
}
if err := WriteSatoshi(w, o.FundingAmount); err != nil {
return err
}
if err := WriteMilliSatoshi(w, o.PushAmount); err != nil {
return err
}
if err := WriteSatoshi(w, o.DustLimit); err != nil {
return err
}
if err := WriteMilliSatoshi(w, o.MaxValueInFlight); err != nil {
return err
}
if err := WriteSatoshi(w, o.ChannelReserve); err != nil {
return err
}
if err := WriteMilliSatoshi(w, o.HtlcMinimum); err != nil {
return err
}
if err := WriteUint32(w, o.FeePerKiloWeight); err != nil {
return err
}
if err := WriteUint16(w, o.CsvDelay); err != nil {
return err
}
if err := WriteUint16(w, o.MaxAcceptedHTLCs); err != nil {
return err
}
if err := WritePublicKey(w, o.FundingKey); err != nil {
return err
}
if err := WritePublicKey(w, o.RevocationPoint); err != nil {
return err
}
if err := WritePublicKey(w, o.PaymentPoint); err != nil {
return err
}
if err := WritePublicKey(w, o.DelayedPaymentPoint); err != nil {
return err
}
if err := WritePublicKey(w, o.HtlcPoint); err != nil {
return err
}
if err := WritePublicKey(w, o.FirstCommitmentPoint); err != nil {
return err
}
if err := WriteFundingFlag(w, o.ChannelFlags); err != nil {
return err
}
return WriteBytes(w, o.ExtraData)
}
// Decode deserializes the serialized OpenChannel stored in the passed
// io.Reader into the target OpenChannel using the deserialization rules
// defined by the passed protocol version.
//
// This is part of the lnwire.Message interface.
func (o *OpenChannel) Decode(r io.Reader, pver uint32) error {
// Read all the mandatory fields in the open message.
err := ReadElements(r,
o.ChainHash[:],
o.PendingChannelID[:],
&o.FundingAmount,
&o.PushAmount,
&o.DustLimit,
&o.MaxValueInFlight,
&o.ChannelReserve,
&o.HtlcMinimum,
&o.FeePerKiloWeight,
&o.CsvDelay,
&o.MaxAcceptedHTLCs,
&o.FundingKey,
&o.RevocationPoint,
&o.PaymentPoint,
&o.DelayedPaymentPoint,
&o.HtlcPoint,
&o.FirstCommitmentPoint,
&o.ChannelFlags,
)
if err != nil {
return err
}
// For backwards compatibility, the optional extra data blob for
// OpenChannel must contain an entry for the upfront shutdown script.
// We'll read it out and attempt to parse it.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
// Next we'll parse out the set of known records, keeping the raw tlv
// bytes untouched to ensure we don't drop any bytes erroneously.
var (
chanType ChannelType
leaseExpiry LeaseExpiry
localNonce = o.LocalNonce.Zero()
)
typeMap, err := tlvRecords.ExtractRecords(
&o.UpfrontShutdownScript, &chanType, &leaseExpiry,
&localNonce,
)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[ChannelTypeRecordType]; ok && val == nil {
o.ChannelType = &chanType
}
if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil {
o.LeaseExpiry = &leaseExpiry
}
if val, ok := typeMap[o.LocalNonce.TlvType()]; ok && val == nil {
o.LocalNonce = tlv.SomeRecordT(localNonce)
}
o.ExtraData = tlvRecords
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as an OpenChannel on the wire.
//
// This is part of the lnwire.Message interface.
func (o *OpenChannel) MsgType() MessageType {
return MsgOpenChannel
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (o *OpenChannel) SerializedSize() (uint32, error) {
return MessageSerializedSize(o)
}
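// A hedged encoding sketch (illustrative only; every value below is an
// arbitrary placeholder, and a single throwaway key stands in for the six
// distinct basepoints a real funding flow would derive):
func exampleOpenChannelEncode() ([]byte, error) {
	priv, err := btcec.NewPrivateKey()
	if err != nil {
		return nil, err
	}
	pub := priv.PubKey()

	open := &OpenChannel{
		FundingAmount:        100_000,
		DustLimit:            354,
		MaxValueInFlight:     99_000_000,
		HtlcMinimum:          1_000,
		FeePerKiloWeight:     2_500,
		CsvDelay:             144,
		MaxAcceptedHTLCs:     483,
		FundingKey:           pub,
		RevocationPoint:      pub,
		PaymentPoint:         pub,
		DelayedPaymentPoint:  pub,
		HtlcPoint:            pub,
		FirstCommitmentPoint: pub,
		ChannelFlags:         FFAnnounceChannel,
	}

	var buf bytes.Buffer
	if err := open.Encode(&buf, 0); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}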
package lnwire
import (
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// PartialSigLen is the length of a musig2 partial signature.
PartialSigLen = 32
)
type (
// PartialSigType is the type of the tlv record for a musig2
// partial signature. This is an _even_ type, which means it's required
// if included.
PartialSigType = tlv.TlvType6
// PartialSigTLV is a tlv record for a musig2 partial signature.
PartialSigTLV = tlv.RecordT[PartialSigType, PartialSig]
// OptPartialSigTLV is a tlv record for a musig2 partial signature.
// This is an optional record type.
OptPartialSigTLV = tlv.OptionalRecordT[PartialSigType, PartialSig]
)
// PartialSig is the base partial sig type. This only encodes the 32-byte
// partial signature. This is used for the co-op close flow, as both sides have
// already exchanged nonces, so they can send just the partial signature.
type PartialSig struct {
// Sig is the 32-byte musig2 partial signature.
Sig btcec.ModNScalar
}
// NewPartialSig creates a new partial sig.
func NewPartialSig(sig btcec.ModNScalar) PartialSig {
return PartialSig{
Sig: sig,
}
}
// Record returns the tlv record for the partial sig.
func (p *PartialSig) Record() tlv.Record {
return tlv.MakeStaticRecord(
(PartialSigType)(nil).TypeVal(), p, PartialSigLen,
partialSigTypeEncoder, partialSigTypeDecoder,
)
}
// partialSigTypeEncoder encodes a 32-byte musig2 partial signature as a TLV
// value.
func partialSigTypeEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*PartialSig); ok {
sigBytes := v.Sig.Bytes()
return tlv.EBytes32(w, &sigBytes, buf)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.PartialSig")
}
// Encode writes the encoded version of this message to the passed io.Writer.
func (p *PartialSig) Encode(w io.Writer) error {
return partialSigTypeEncoder(w, p, nil)
}
// partialSigTypeDecoder decodes a 32-byte musig2 extended partial signature.
func partialSigTypeDecoder(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*PartialSig); ok && l == PartialSigLen {
var sBytes [32]byte
err := tlv.DBytes32(r, &sBytes, buf, PartialSigLen)
if err != nil {
return err
}
var s btcec.ModNScalar
s.SetBytes(&sBytes)
*v = PartialSig{
Sig: s,
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "lnwire.PartialSig", l,
PartialSigLen)
}
// Decode reads the encoded version of this message from the passed io.Reader.
func (p *PartialSig) Decode(r io.Reader) error {
return partialSigTypeDecoder(r, p, nil, PartialSigLen)
}
// SomePartialSig is a helper function that returns an optional PartialSig.
func SomePartialSig(sig PartialSig) OptPartialSigTLV {
return tlv.SomeRecordT(tlv.NewRecordT[PartialSigType, PartialSig](sig))
}
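// A minimal encoding sketch (illustrative only): the TLV value is just the
// 32-byte s value. The scalar below is a placeholder, not a real
// signature, and io.Discard stands in for a wire buffer.
func examplePartialSigEncode() error {
	var s btcec.ModNScalar
	s.SetInt(1)

	sig := NewPartialSig(s)
	return sig.Encode(io.Discard)
}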
const (
// PartialSigWithNonceLen is the length of a serialized
// PartialSigWithNonce. The sig is encoded as the 32-byte S value
// followed by the 66-byte nonce value.
PartialSigWithNonceLen = 98
)
type (
// PartialSigWithNonceType is the type of the tlv record for a musig2
// partial signature with nonce. This is an _even_ type, which means
// it's required if included.
PartialSigWithNonceType = tlv.TlvType2
// PartialSigWithNonceTLV is a tlv record for a musig2 partial
// signature.
PartialSigWithNonceTLV = tlv.RecordT[
PartialSigWithNonceType, PartialSigWithNonce,
]
// OptPartialSigWithNonceTLV is a tlv record for a musig2 partial
// signature. This is an optional record type.
OptPartialSigWithNonceTLV = tlv.OptionalRecordT[
PartialSigWithNonceType, PartialSigWithNonce,
]
)
// PartialSigWithNonce is a partial signature with the nonce that was used to
// generate the signature. This is used for funding as well as the commitment
// transaction update dance. By sending the nonce only with the signature, we
// enable the sender to generate their nonce just before they create their
// signature. Signers can use this trait to mix in additional contextual data
// such as the commitment txn itself into their nonce generation function.
//
// The final signature is 98 bytes: 32 bytes for the S value, and 66 bytes for
// the public nonce (two compressed points).
type PartialSigWithNonce struct {
PartialSig
// Nonce is the 66-byte musig2 nonce.
Nonce Musig2Nonce
}
// NewPartialSigWithNonce creates a new partial sig with nonce.
func NewPartialSigWithNonce(nonce [musig2.PubNonceSize]byte,
sig btcec.ModNScalar) *PartialSigWithNonce {
return &PartialSigWithNonce{
Nonce: nonce,
PartialSig: NewPartialSig(sig),
}
}
// Record returns the tlv record for the partial sig with nonce.
func (p *PartialSigWithNonce) Record() tlv.Record {
return tlv.MakeStaticRecord(
(PartialSigWithNonceType)(nil).TypeVal(), p,
PartialSigWithNonceLen, partialSigWithNonceTypeEncoder,
partialSigWithNonceTypeDecoder,
)
}
// partialSigWithNonceTypeEncoder encodes a 98-byte musig2 extended partial
// signature as: s {32} || nonce {66}.
func partialSigWithNonceTypeEncoder(w io.Writer, val interface{},
_ *[8]byte) error {
if v, ok := val.(*PartialSigWithNonce); ok {
sigBytes := v.Sig.Bytes()
if _, err := w.Write(sigBytes[:]); err != nil {
return err
}
if _, err := w.Write(v.Nonce[:]); err != nil {
return err
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "lnwire.PartialSigWithNonce")
}
// Encode writes the encoded version of this message to the passed io.Writer.
func (p *PartialSigWithNonce) Encode(w io.Writer) error {
return partialSigWithNonceTypeEncoder(w, p, nil)
}
// partialSigWithNonceTypeDecoder decodes a 98-byte musig2 extended partial
// signature.
func partialSigWithNonceTypeDecoder(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*PartialSigWithNonce); ok &&
l == PartialSigWithNonceLen {
var sBytes [32]byte
err := tlv.DBytes32(r, &sBytes, buf, PartialSigLen)
if err != nil {
return err
}
var s btcec.ModNScalar
s.SetBytes(&sBytes)
var nonce [66]byte
if _, err := io.ReadFull(r, nonce[:]); err != nil {
return err
}
*v = PartialSigWithNonce{
PartialSig: NewPartialSig(s),
Nonce: nonce,
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "lnwire.PartialSigWithNonce", l,
PartialSigWithNonceLen)
}
// Decode reads the encoded version of this message from the passed io.Reader.
func (p *PartialSigWithNonce) Decode(r io.Reader) error {
return partialSigWithNonceTypeDecoder(
r, p, nil, PartialSigWithNonceLen,
)
}
// MaybePartialSigWithNonce is a helper function that returns an optional
// PartialSigWithNonceTLV.
func MaybePartialSigWithNonce(sig *PartialSigWithNonce,
) OptPartialSigWithNonceTLV {
if sig == nil {
var none OptPartialSigWithNonceTLV
return none
}
return tlv.SomeRecordT(
tlv.NewRecordT[PartialSigWithNonceType, PartialSigWithNonce](
*sig,
),
)
}
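// A hedged encoding sketch (illustrative only) for the extended form,
// which appends the 66-byte public nonce to the 32-byte s value. The zero
// nonce and unit scalar below would never verify; they only exercise the
// encoder.
func examplePartialSigWithNonceEncode() error {
	var s btcec.ModNScalar
	s.SetInt(1)

	var nonce [musig2.PubNonceSize]byte

	sig := NewPartialSigWithNonce(nonce, s)
	return sig.Encode(io.Discard)
}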
package lnwire
import (
"bytes"
"io"
)
// PingPayload is a set of opaque bytes used to pad out a ping message.
type PingPayload []byte
// Ping defines a message which is sent by peers periodically to determine if
// the connection is still valid. Each ping message carries the number of bytes
// to pad the pong response with, and also a number of bytes to be ignored at
// the end of the ping message (which is padding).
type Ping struct {
// NumPongBytes is the number of bytes the pong response to this
// message should carry.
NumPongBytes uint16
// PaddingBytes is a set of opaque bytes used to pad out this ping
// message. Using this field in conjunction with the one above, it's
// possible for a node to generate fake cover traffic.
PaddingBytes PingPayload
}
// NewPing returns a new Ping message.
func NewPing(numBytes uint16) *Ping {
return &Ping{
NumPongBytes: numBytes,
}
}
// A compile time check to ensure Ping implements the lnwire.Message interface.
var _ Message = (*Ping)(nil)
// A compile time check to ensure Ping implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Ping)(nil)
// Decode deserializes a serialized Ping message stored in the passed io.Reader
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (p *Ping) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r, &p.NumPongBytes, &p.PaddingBytes)
if err != nil {
return err
}
if p.NumPongBytes > MaxPongBytes {
return ErrMaxPongBytesExceeded
}
return nil
}
// Encode serializes the target Ping into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the lnwire.Message interface.
func (p *Ping) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteUint16(w, p.NumPongBytes); err != nil {
return err
}
return WritePingPayload(w, p.PaddingBytes)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (p *Ping) MsgType() MessageType {
return MsgPing
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (p *Ping) SerializedSize() (uint32, error) {
return MessageSerializedSize(p)
}
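// A hedged sketch (illustrative only) of building a ping that doubles as
// cover traffic: ask the peer for a 16-byte pong while carrying 32 bytes
// of our own padding.
func examplePingEncode() ([]byte, error) {
	ping := NewPing(16)
	ping.PaddingBytes = make(PingPayload, 32)

	var buf bytes.Buffer
	if err := ping.Encode(&buf, 0); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}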
package lnwire
import (
"bytes"
"fmt"
"io"
)
// MaxPongBytes is the maximum number of extra bytes a pong can be requested to
// send. The type of the message (19) takes 2 bytes, the length field takes up
// 2 bytes, leaving 65531 bytes.
const MaxPongBytes = 65531
// ErrMaxPongBytesExceeded indicates that the NumPongBytes field from the ping
// message has exceeded MaxPongBytes.
var ErrMaxPongBytesExceeded = fmt.Errorf("pong bytes exceeded")
// PongPayload is a set of opaque bytes sent in response to a ping message.
type PongPayload []byte
// Pong defines a message which is the direct response to a received Ping
// message. A Pong reply indicates that a connection is still active. The Pong
// reply to a Ping message should contain the nonce carried in the original
// Ping message.
type Pong struct {
// PongBytes is a set of opaque bytes that corresponds to the
// NumPongBytes defined in the ping message that this pong is
// replying to.
PongBytes PongPayload
}
// NewPong returns a new Pong message.
func NewPong(pongBytes []byte) *Pong {
return &Pong{
PongBytes: pongBytes,
}
}
// A compile time check to ensure Pong implements the lnwire.Message interface.
var _ Message = (*Pong)(nil)
// A compile time check to ensure Pong implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Pong)(nil)
// Decode deserializes a serialized Pong message stored in the passed io.Reader
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (p *Pong) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&p.PongBytes,
)
}
// Encode serializes the target Pong into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the lnwire.Message interface.
func (p *Pong) Encode(w *bytes.Buffer, pver uint32) error {
return WritePongPayload(w, p.PongBytes)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (p *Pong) MsgType() MessageType {
return MsgPong
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (p *Pong) SerializedSize() (uint32, error) {
return MessageSerializedSize(p)
}
package lnwire
import (
"bytes"
"io"
"math"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/tlv"
)
// QueryChannelRange is a message sent by a node in order to query the
// receiving node for the set of open channels they know of with short channel
// ID's after the specified block height, capped at the number of blocks beyond
// that block height. This will be used by nodes upon initial connect to
// synchronize their views of the network.
type QueryChannelRange struct {
// ChainHash denotes the target chain that we're trying to synchronize
// channel graph state for.
ChainHash chainhash.Hash
// FirstBlockHeight is the first block in the query range. The
// responder should send all new short channel IDs from this block
// until this block plus the specified number of blocks.
FirstBlockHeight uint32
// NumBlocks is the number of blocks beyond the first block that short
// channel ID's should be sent for.
NumBlocks uint32
// QueryOptions is an optional feature bit vector that can be used to
// specify additional query options.
QueryOptions *QueryOptions
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewQueryChannelRange creates a new empty QueryChannelRange message.
func NewQueryChannelRange() *QueryChannelRange {
return &QueryChannelRange{
ExtraData: make([]byte, 0),
}
}
// A compile time check to ensure QueryChannelRange implements the
// lnwire.Message interface.
var _ Message = (*QueryChannelRange)(nil)
// A compile time check to ensure QueryChannelRange implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*QueryChannelRange)(nil)
// Decode deserializes a serialized QueryChannelRange message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (q *QueryChannelRange) Decode(r io.Reader, _ uint32) error {
err := ReadElements(
r, q.ChainHash[:], &q.FirstBlockHeight, &q.NumBlocks,
)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var queryOptions QueryOptions
typeMap, err := tlvRecords.ExtractRecords(&queryOptions)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[QueryOptionsRecordType]; ok && val == nil {
q.QueryOptions = &queryOptions
}
if len(tlvRecords) != 0 {
q.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target QueryChannelRange into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (q *QueryChannelRange) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteBytes(w, q.ChainHash[:]); err != nil {
return err
}
if err := WriteUint32(w, q.FirstBlockHeight); err != nil {
return err
}
if err := WriteUint32(w, q.NumBlocks); err != nil {
return err
}
recordProducers := make([]tlv.RecordProducer, 0, 1)
if q.QueryOptions != nil {
recordProducers = append(recordProducers, q.QueryOptions)
}
err := EncodeMessageExtraData(&q.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, q.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (q *QueryChannelRange) MsgType() MessageType {
return MsgQueryChannelRange
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (q *QueryChannelRange) SerializedSize() (uint32, error) {
msgCpy := *q
return MessageSerializedSize(&msgCpy)
}
// LastBlockHeight returns the last block height covered by the range of a
// QueryChannelRange message.
func (q *QueryChannelRange) LastBlockHeight() uint32 {
// Handle overflows by casting to uint64.
lastBlockHeight := uint64(q.FirstBlockHeight) + uint64(q.NumBlocks) - 1
if lastBlockHeight > math.MaxUint32 {
return math.MaxUint32
}
return uint32(lastBlockHeight)
}
// WithTimestamps returns true if the query has asked for timestamps too.
func (q *QueryChannelRange) WithTimestamps() bool {
if q.QueryOptions == nil {
return false
}
queryOpts := RawFeatureVector(*q.QueryOptions)
return queryOpts.IsSet(QueryOptionTimestampBit)
}
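// A hedged construction sketch (illustrative only; the heights are
// arbitrary placeholders): query 1,000 blocks starting at height 500,000
// and opt in to timestamps in the reply.
func exampleQueryChannelRange() ([]byte, error) {
	q := NewQueryChannelRange()
	q.FirstBlockHeight = 500_000
	q.NumBlocks = 1_000
	q.QueryOptions = NewTimestampQueryOption()

	// q.LastBlockHeight() is now 500_999 and q.WithTimestamps() is
	// true.
	var buf bytes.Buffer
	if err := q.Encode(&buf, 0); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}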
package lnwire
import (
"io"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// QueryOptionsRecordType is the TLV number of the query_options TLV
// record in the query_channel_range message.
QueryOptionsRecordType tlv.Type = 1
// QueryOptionTimestampBit is the bit position in the query_option
// feature bit vector which is used to indicate that timestamps are
// desired in the reply_channel_range response.
QueryOptionTimestampBit = 0
)
// QueryOptions is the type used to represent the query_options feature bit
// vector in the query_channel_range message.
type QueryOptions RawFeatureVector
// NewTimestampQueryOption is a helper constructor used to construct a
// QueryOption with the timestamp bit set.
func NewTimestampQueryOption() *QueryOptions {
opt := QueryOptions(*NewRawFeatureVector(
QueryOptionTimestampBit,
))
return &opt
}
// featureBitLen calculates and returns the size of the resulting feature bit
// vector.
func (c *QueryOptions) featureBitLen() uint64 {
fv := RawFeatureVector(*c)
return uint64(fv.SerializeSize())
}
// Record constructs a tlv.Record from the QueryOptions to be used in the
// query_channel_range message.
func (c *QueryOptions) Record() tlv.Record {
return tlv.MakeDynamicRecord(
QueryOptionsRecordType, c, c.featureBitLen, queryOptionsEncoder,
queryOptionsDecoder,
)
}
// queryOptionsEncoder encodes the QueryOptions and writes it to the provided
// writer.
func queryOptionsEncoder(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*QueryOptions); ok {
// Encode the feature bits as a byte slice without its length
// prepended, as that's already taken care of by the TLV record.
fv := RawFeatureVector(*v)
return fv.encode(w, fv.SerializeSize(), 8)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.QueryOptions")
}
// queryOptionsDecoder attempts to read a QueryOptions from the given reader.
func queryOptionsDecoder(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*QueryOptions); ok {
fv := NewRawFeatureVector()
if err := fv.decode(r, int(l), 8); err != nil {
return err
}
*v = QueryOptions(*fv)
return nil
}
return tlv.NewTypeForEncodingErr(val, "lnwire.QueryOptions")
}
package lnwire
import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sort"
"sync"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
const (
// maxZlibBufSize is the max number of bytes that we'll accept from a
// zlib decoding instance. We do this in order to limit the total
// amount of memory allocated during a decoding instance.
maxZlibBufSize = 67413630
)
// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
// items were not sorted.
type ErrUnsortedSIDs struct {
prevSID ShortChannelID
curSID ShortChannelID
}
// Error returns a human-readable description of the error.
func (e ErrUnsortedSIDs) Error() string {
return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
e.curSID, e.prevSID)
}
// zlibDecodeMtx is a package level mutex that we'll use in order to ensure
// that we'll only attempt a single zlib decoding instance at a time. This
// allows us to also further bound our memory usage.
var zlibDecodeMtx sync.Mutex
// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
// came across an unknown short channel ID encoding, and therefore were unable
// to continue parsing.
func ErrUnknownShortChanIDEncoding(encoding QueryEncoding) error {
return fmt.Errorf("unknown short chan id encoding: %v", encoding)
}
// QueryShortChanIDs is a message that allows the sender to query a set of
// channel announcement and channel update messages that correspond to the set
// of encoded short channel ID's. The encoding of the short channel ID's is
// detailed in the query message ensuring that the receiver knows how to
// properly decode each encoded short channel ID which may be encoded using a
// compression format. The receiver should respond with a series of channel
// announcement and channel updates, finally sending a ReplyShortChanIDsEnd
// message.
type QueryShortChanIDs struct {
// ChainHash denotes the target chain that we're querying for the
// channel ID's of.
ChainHash chainhash.Hash
// EncodingType is a signal to the receiver of the message that
// indicates exactly how the set of short channel ID's that follow have
// been encoded.
EncodingType QueryEncoding
// ShortChanIDs is a slice of decoded short channel ID's.
ShortChanIDs []ShortChannelID
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
// noSort indicates whether or not to sort the short channel ids before
// writing them out.
//
// NOTE: This should only be used during testing.
noSort bool
}
// NewQueryShortChanIDs creates a new QueryShortChanIDs message.
func NewQueryShortChanIDs(h chainhash.Hash, e QueryEncoding,
s []ShortChannelID) *QueryShortChanIDs {
return &QueryShortChanIDs{
ChainHash: h,
EncodingType: e,
ShortChanIDs: s,
}
}
// A compile time check to ensure QueryShortChanIDs implements the
// lnwire.Message interface.
var _ Message = (*QueryShortChanIDs)(nil)
// A compile time check to ensure QueryShortChanIDs implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*QueryShortChanIDs)(nil)
// Decode deserializes a serialized QueryShortChanIDs message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r, q.ChainHash[:])
if err != nil {
return err
}
q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
if err != nil {
return err
}
return q.ExtraData.Decode(r)
}
// decodeShortChanIDs decodes a set of short channel ID's that have been
// encoded. The first byte of the body details how the short chan ID's were
// encoded. We'll use this type to govern exactly how we go about decoding the
// set of short channel ID's.
func decodeShortChanIDs(r io.Reader) (QueryEncoding, []ShortChannelID, error) {
// First, we'll attempt to read the number of bytes in the body of the
// set of encoded short channel ID's.
var numBytesResp uint16
err := ReadElements(r, &numBytesResp)
if err != nil {
return 0, nil, err
}
if numBytesResp == 0 {
return 0, nil, nil
}
queryBody := make([]byte, numBytesResp)
if _, err := io.ReadFull(r, queryBody); err != nil {
return 0, nil, err
}
// The first byte is the encoding type, so we'll extract that so we can
// continue our parsing.
encodingType := QueryEncoding(queryBody[0])
// Before continuing, we'll snip off the first byte of the query body
// as that was just the encoding type.
queryBody = queryBody[1:]
// Otherwise, depending on the encoding type, we'll decode the encoded
// short channel ID's in a different manner.
switch encodingType {
// In this encoding, we'll simply read a sorted array of encoded short
// channel ID's from the buffer.
case EncodingSortedPlain:
// If after extracting the encoding type, the number of
// remaining bytes is not a whole multiple of the size of an
// encoded short channel ID (8 bytes), then we'll return a
// parsing error.
if len(queryBody)%8 != 0 {
return 0, nil, fmt.Errorf("whole number of short "+
"chan ID's cannot be encoded in len=%v",
len(queryBody))
}
// As each short channel ID is encoded as 8 bytes, we can
// compute the number of bytes encoded based on the size of the
// query body.
numShortChanIDs := len(queryBody) / 8
if numShortChanIDs == 0 {
return encodingType, nil, nil
}
// Finally, we'll read out the exact number of short channel
// ID's to conclude our parsing.
shortChanIDs := make([]ShortChannelID, numShortChanIDs)
bodyReader := bytes.NewReader(queryBody)
var lastChanID ShortChannelID
for i := 0; i < numShortChanIDs; i++ {
if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
return 0, nil, fmt.Errorf("unable to parse "+
"short chan ID: %v", err)
}
// We'll ensure that this short chan ID is greater than
// the last one. This is a requirement within the
// encoding, and if violated can aid us in detecting
// malicious payloads. This can only be true starting
// at the second chanID.
cid := shortChanIDs[i]
if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
}
lastChanID = cid
}
return encodingType, shortChanIDs, nil
// In this encoding, we'll use zlib to decode the compressed payload.
// However, we'll pay attention to ensure that we don't open ourselves
// up to a memory exhaustion attack.
case EncodingSortedZlib:
// We'll obtain and ultimately release the zlib decode mutex.
// This guards us against allocating too much memory to decode
// each instance from concurrent peers.
zlibDecodeMtx.Lock()
defer zlibDecodeMtx.Unlock()
// At this point, if there's no body remaining, then only the encoding
// type was specified, meaning that there are no further bytes to be
// parsed.
if len(queryBody) == 0 {
return encodingType, nil, nil
}
// Before we start to decode, we'll create a limit reader over
// the current reader. This will ensure that we can control how
// much memory we're allocating during the decoding process.
limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
R: bytes.NewReader(queryBody),
N: maxZlibBufSize,
})
if err != nil {
return 0, nil, fmt.Errorf("unable to create zlib "+
"reader: %w", err)
}
var (
shortChanIDs []ShortChannelID
lastChanID ShortChannelID
i int
)
for {
// We'll now attempt to read the next short channel ID
// encoded in the payload.
var cid ShortChannelID
err := ReadElements(limitedDecompressor, &cid)
switch {
// If we get an EOF error, then that either means we've
// read all that's contained in the buffer, or have hit
// our limit on the number of bytes we'll read. In
// either case, we'll return what we have so far.
case err == io.ErrUnexpectedEOF || err == io.EOF:
return encodingType, shortChanIDs, nil
// Otherwise, we hit some other sort of error, possibly
// an invalid payload, so we'll exit early with the
// error.
case err != nil:
return 0, nil, fmt.Errorf("unable to "+
"deflate next short chan "+
"ID: %v", err)
}
// We successfully read the next ID, so we'll collect
// that in the set of final ID's to return.
shortChanIDs = append(shortChanIDs, cid)
// Finally, we'll ensure that this short chan ID is
// greater than the last one. This is a requirement
// within the encoding, and if violated can aid us in
// detecting malicious payloads. This can only be true
// starting at the second chanID.
if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
}
lastChanID = cid
i++
}
default:
// If we've been sent an encoding type that we don't know of,
// then we'll return a parsing error as we can't continue if
// we're unable to decode them.
return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
}
}
// Encode serializes the target QueryShortChanIDs into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
// First, we'll write out the chain hash.
if err := WriteBytes(w, q.ChainHash[:]); err != nil {
return err
}
// For both of the current encoding types, the channel ID's are to be
// sorted in place, so we'll do that now. The sorting is applied unless
// we were specifically requested not to for testing purposes.
if !q.noSort {
sort.Slice(q.ShortChanIDs, func(i, j int) bool {
return q.ShortChanIDs[i].ToUint64() <
q.ShortChanIDs[j].ToUint64()
})
}
// Based on our encoding type, we'll write out the set of short channel
// ID's.
err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs)
if err != nil {
return err
}
return WriteBytes(w, q.ExtraData)
}
// encodeShortChanIDs encodes the passed short channel ID's into the passed
// io.Writer, respecting the specified encoding type.
func encodeShortChanIDs(w *bytes.Buffer, encodingType QueryEncoding,
shortChanIDs []ShortChannelID) error {
switch encodingType {
// In this encoding, we'll simply write a sorted array of encoded short
// channel ID's to the buffer.
case EncodingSortedPlain:
// First, we'll write out the number of bytes of the query
// body. We add 1 as the response will have the encoding type
// prepended to it.
numBytesBody := uint16(len(shortChanIDs)*8) + 1
if err := WriteUint16(w, numBytesBody); err != nil {
return err
}
// We'll then write out the encoding type that precedes the actual
// encoded short channel ID's.
err := WriteQueryEncoding(w, encodingType)
if err != nil {
return err
}
// Now that we know they're sorted, we can write out each short
// channel ID to the buffer.
for _, chanID := range shortChanIDs {
if err := WriteShortChannelID(w, chanID); err != nil {
return fmt.Errorf("unable to write short chan "+
"ID: %v", err)
}
}
return nil
// For this encoding we'll first write out a serialized version of all
// the channel ID's into a buffer, then zlib encode that. The final
// payload is what we'll write out to the passed io.Writer.
//
// TODO(roasbeef): assumes the caller knows the proper chunk size to
// pass to avoid bin-packing here
case EncodingSortedZlib:
// If we don't have anything at all to write, then we'll write
// an empty payload so we don't include things like the zlib
// header when the remote party is expecting no actual short
// channel IDs.
var compressedPayload []byte
if len(shortChanIDs) > 0 {
// We'll make a new write buffer to hold the bytes of
// shortChanIDs.
var wb bytes.Buffer
// Next, we'll write out all the channel ID's directly
// into the zlib writer, which will do compressing on
// the fly.
for _, chanID := range shortChanIDs {
err := WriteShortChannelID(&wb, chanID)
if err != nil {
return fmt.Errorf(
"unable to write short chan "+
"ID: %v", err,
)
}
}
// With shortChanIDs written into wb, we'll create a
// zlib writer and write all the compressed bytes.
var zlibBuffer bytes.Buffer
zlibWriter := zlib.NewWriter(&zlibBuffer)
if _, err := zlibWriter.Write(wb.Bytes()); err != nil {
return fmt.Errorf(
"unable to write compressed short chan"+
"ID: %w", err)
}
// Now that we've written all the elements, we'll
// ensure the compressed stream is written to the
// underlying buffer.
if err := zlibWriter.Close(); err != nil {
return fmt.Errorf("unable to finalize "+
"compression: %v", err)
}
compressedPayload = zlibBuffer.Bytes()
}
// Now that we have all the items compressed, we can compute
// what the total payload size will be. We add one to account
// for the byte to encode the type.
//
// If we don't have any actual bytes to write, then we'll end
// up emitting a length of one, followed by the encoding type,
// and nothing more. The spec isn't 100% clear in this area,
// but we do this as it's what most other implementations do.
numBytesBody := len(compressedPayload) + 1
// Finally, we can write out the number of bytes, the
// compression type, and finally the buffer itself.
if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
return err
}
err := WriteQueryEncoding(w, encodingType)
if err != nil {
return err
}
return WriteBytes(w, compressedPayload)
default:
// If we're trying to encode with an encoding type that we
// don't know of, then we'll return a parsing error as we can't
// continue if we're unable to encode them.
return ErrUnknownShortChanIDEncoding(encodingType)
}
}
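// encodePlainSCIDsSketch is an illustrative sketch (not part of the package
// API) of the wire layout produced by encodeShortChanIDs for
// EncodingSortedPlain: a 2-byte length prefix covering the encoding byte plus
// 8 bytes per SCID, then the encoding type byte, then each SCID big-endian.
func encodePlainSCIDsSketch() ([]byte, error) {
var b bytes.Buffer
scids := []ShortChannelID{
NewShortChanIDFromInt(1),
NewShortChanIDFromInt(2),
}
// For two SCIDs, the length prefix is 2*8+1 = 17 bytes.
if err := encodeShortChanIDs(&b, EncodingSortedPlain, scids); err != nil {
return nil, err
}
return b.Bytes(), nil
}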
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (q *QueryShortChanIDs) MsgType() MessageType {
return MsgQueryShortChanIDs
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (q *QueryShortChanIDs) SerializedSize() (uint32, error) {
return MessageSerializedSize(q)
}
package lnwire
import (
"bytes"
"fmt"
"io"
"math"
"sort"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/tlv"
)
// ReplyChannelRange is the response to the QueryChannelRange message. It
// includes the original query, and the next streaming chunk of encoded short
// channel IDs as the response. We'll also include a byte that indicates if
// this is the last response to the original query.
type ReplyChannelRange struct {
// ChainHash denotes the target chain that we're trying to synchronize
// channel graph state for.
ChainHash chainhash.Hash
// FirstBlockHeight is the first block in the query range. The
// responder should send all new short channel IDs from this block
// until this block plus the specified number of blocks.
FirstBlockHeight uint32
// NumBlocks is the number of blocks beyond the first block that short
// channel ID's should be sent for.
NumBlocks uint32
// Complete denotes if this is the conclusion of the set of streaming
// responses to the original query.
Complete uint8
// EncodingType is a signal to the receiver of the message that
// indicates exactly how the set of short channel ID's that follow have
// been encoded.
EncodingType QueryEncoding
// ShortChanIDs is a slice of decoded short channel ID's.
ShortChanIDs []ShortChannelID
// Timestamps is an optional set of timestamps, one for the latest
// channel update message of each SCID referenced in the ShortChanIDs
// list. If this field is used, then the length must match the length
// of ShortChanIDs.
Timestamps Timestamps
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
// noSort indicates whether or not to sort the short channel IDs before
// writing them out.
//
// NOTE: This should only be used for testing.
noSort bool
}
// NewReplyChannelRange creates a new empty ReplyChannelRange message.
func NewReplyChannelRange() *ReplyChannelRange {
return &ReplyChannelRange{
ExtraData: make([]byte, 0),
}
}
// A compile time check to ensure ReplyChannelRange implements the
// lnwire.Message interface.
var _ Message = (*ReplyChannelRange)(nil)
// A compile time check to ensure ReplyChannelRange implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ReplyChannelRange)(nil)
// Decode deserializes a serialized ReplyChannelRange message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
c.ChainHash[:],
&c.FirstBlockHeight,
&c.NumBlocks,
&c.Complete,
)
if err != nil {
return err
}
c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var timeStamps Timestamps
typeMap, err := tlvRecords.ExtractRecords(&timeStamps)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[TimestampsRecordType]; ok && val == nil {
c.Timestamps = timeStamps
// Check that a timestamp was provided for each SCID.
if len(c.Timestamps) != len(c.ShortChanIDs) {
return fmt.Errorf("number of timestamps does not " +
"match number of SCIDs")
}
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target ReplyChannelRange into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteBytes(w, c.ChainHash[:]); err != nil {
return err
}
if err := WriteUint32(w, c.FirstBlockHeight); err != nil {
return err
}
if err := WriteUint32(w, c.NumBlocks); err != nil {
return err
}
if err := WriteUint8(w, c.Complete); err != nil {
return err
}
// For both of the current encoding types, the channel ID's are to be
// sorted in place, so we'll do that now. The sorting is applied unless
// we were specifically requested not to for testing purposes.
if !c.noSort {
var scidPreSortIndex map[uint64]int
if len(c.Timestamps) != 0 {
// Sanity check that a timestamp was provided for each
// SCID.
if len(c.Timestamps) != len(c.ShortChanIDs) {
return fmt.Errorf("must provide a timestamp " +
"pair for each of the given SCIDs")
}
// Create a map from SCID value to the original index of
// the SCID in the unsorted list.
scidPreSortIndex = make(
map[uint64]int, len(c.ShortChanIDs),
)
for i, scid := range c.ShortChanIDs {
scidPreSortIndex[scid.ToUint64()] = i
}
// Sanity check that there were no duplicates in the
// SCID list.
if len(scidPreSortIndex) != len(c.ShortChanIDs) {
return fmt.Errorf("scid list should not " +
"contain duplicates")
}
}
// Now sort the SCIDs.
sort.Slice(c.ShortChanIDs, func(i, j int) bool {
return c.ShortChanIDs[i].ToUint64() <
c.ShortChanIDs[j].ToUint64()
})
if len(c.Timestamps) != 0 {
timestamps := make(Timestamps, len(c.Timestamps))
for i, scid := range c.ShortChanIDs {
timestamps[i] = []ChanUpdateTimestamps(
c.Timestamps,
)[scidPreSortIndex[scid.ToUint64()]]
}
c.Timestamps = timestamps
}
}
err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs)
if err != nil {
return err
}
recordProducers := make([]tlv.RecordProducer, 0, 1)
if len(c.Timestamps) != 0 {
recordProducers = append(recordProducers, &c.Timestamps)
}
err = EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) MsgType() MessageType {
return MsgReplyChannelRange
}
// LastBlockHeight returns the last block height covered by the range of this
// ReplyChannelRange message.
func (c *ReplyChannelRange) LastBlockHeight() uint32 {
// Handle overflows by casting to uint64.
lastBlockHeight := uint64(c.FirstBlockHeight) + uint64(c.NumBlocks) - 1
if lastBlockHeight > math.MaxUint32 {
return math.MaxUint32
}
return uint32(lastBlockHeight)
}
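// lastBlockHeightClampSketch illustrates (as a non-normative sketch) the
// overflow handling in LastBlockHeight: the sum is computed in uint64, so a
// range that would exceed the maximum uint32 is clamped to math.MaxUint32.
func lastBlockHeightClampSketch() uint32 {
rep := ReplyChannelRange{
FirstBlockHeight: math.MaxUint32 - 10,
NumBlocks:        100,
}
// (MaxUint32-10) + 100 - 1 overflows a uint32, so this returns
// math.MaxUint32 rather than a wrapped-around small value.
return rep.LastBlockHeight()
}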
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *ReplyChannelRange) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// ReplyShortChanIDsEnd is a message that marks the end of a streaming message
// response to an initial QueryShortChanIDs message. This marks that the
// receiver of the original QueryShortChanIDs for the target chain has either
// sent all adequate responses it knows of, or doesn't know of any short chan
// IDs for the target chain.
type ReplyShortChanIDsEnd struct {
// ChainHash denotes the target chain that we're responding to a short
// chan ID query for.
ChainHash chainhash.Hash
// Complete will be set to 0 if we don't know of the chain that the
// remote peer sent their query for. Otherwise, we'll set this to 1 in
// order to indicate that we've sent all known responses for the prior
// set of short chan ID's in the corresponding QueryShortChanIDs
// message.
Complete uint8
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewReplyShortChanIDsEnd creates a new empty ReplyShortChanIDsEnd message.
func NewReplyShortChanIDsEnd() *ReplyShortChanIDsEnd {
return &ReplyShortChanIDsEnd{}
}
// A compile time check to ensure ReplyShortChanIDsEnd implements the
// lnwire.Message interface.
var _ Message = (*ReplyShortChanIDsEnd)(nil)
// A compile time check to ensure ReplyShortChanIDsEnd implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ReplyShortChanIDsEnd)(nil)
// Decode deserializes a serialized ReplyShortChanIDsEnd message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
c.ChainHash[:],
&c.Complete,
&c.ExtraData,
)
}
// Encode serializes the target ReplyShortChanIDsEnd into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *ReplyShortChanIDsEnd) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteBytes(w, c.ChainHash[:]); err != nil {
return err
}
if err := WriteUint8(w, c.Complete); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *ReplyShortChanIDsEnd) MsgType() MessageType {
return MsgReplyShortChanIDsEnd
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *ReplyShortChanIDsEnd) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/tlv"
)
// RevokeAndAck is sent by either side once a CommitSig message has been
// received and validated. This message serves to revoke the prior commitment
// transaction, which was the most up to date version until a CommitSig message
// referencing the specified ChannelPoint was received. Additionally, this
// message also piggybacks the next revocation hash that Alice should use when
// constructing Bob's version of the next commitment transaction (which would
// be done before sending a CommitSig message). This piggybacking allows Alice
// to send the next CommitSig message modifying Bob's commitment transaction
// without first asking for a revocation hash.
type RevokeAndAck struct {
// ChanID uniquely identifies the currently active channel to which this
// RevokeAndAck applies.
ChanID ChannelID
// Revocation is the preimage to the revocation hash of the now prior
// commitment transaction.
Revocation [32]byte
// NextRevocationKey is the next commitment point which should be used
// for the next commitment transaction the remote peer creates for us.
// This, in conjunction with the revocation base point, will be used to
// create the proper revocation key used within the commitment
// transaction.
NextRevocationKey *btcec.PublicKey
// LocalNonce is the next _local_ nonce for the sending party. This
// allows the receiving party to propose a new commitment using their
// remote nonce and the sender's local nonce.
LocalNonce OptMusig2NonceTLV
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewRevokeAndAck creates a new RevokeAndAck message.
func NewRevokeAndAck() *RevokeAndAck {
return &RevokeAndAck{
ExtraData: make([]byte, 0),
}
}
// A compile time check to ensure RevokeAndAck implements the lnwire.Message
// interface.
var _ Message = (*RevokeAndAck)(nil)
// A compile time check to ensure RevokeAndAck implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*RevokeAndAck)(nil)
// Decode deserializes a serialized RevokeAndAck message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
&c.ChanID,
c.Revocation[:],
&c.NextRevocationKey,
)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
localNonce := c.LocalNonce.Zero()
typeMap, err := tlvRecords.ExtractRecords(&localNonce)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[c.LocalNonce.TlvType()]; ok && val == nil {
c.LocalNonce = tlv.SomeRecordT(localNonce)
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target RevokeAndAck into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *RevokeAndAck) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1)
c.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, &localNonce)
})
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteBytes(w, c.Revocation[:]); err != nil {
return err
}
if err := WritePublicKey(w, c.NextRevocationKey); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *RevokeAndAck) MsgType() MessageType {
return MsgRevokeAndAck
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *RevokeAndAck) TargetChanID() ChannelID {
return c.ChanID
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *RevokeAndAck) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
package lnwire
import (
"fmt"
"io"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// AliasScidRecordType is the type of the experimental record to denote
// the alias being used in an option_scid_alias channel.
AliasScidRecordType tlv.Type = 1
)
// ShortChannelID represents the set of data needed to uniquely locate the
// channel's funding output within the chain, and thereby validate the
// channel's existence.
type ShortChannelID struct {
// BlockHeight is the height of the block where the funding transaction
// is located.
//
// NOTE: This field is limited to 3 bytes.
BlockHeight uint32
// TxIndex is the position of the funding transaction within the block.
//
// NOTE: This field is limited to 3 bytes.
TxIndex uint32
// TxPosition is the index of the transaction output which pays to the
// channel.
TxPosition uint16
}
// NewShortChanIDFromInt returns a new ShortChannelID which is the decoded
// version of the compact channel ID encoded within the uint64. The format of
// the compact channel ID is as follows: 3 bytes for the block height, 3 bytes
// for the transaction index, and 2 bytes for the output index.
func NewShortChanIDFromInt(chanID uint64) ShortChannelID {
return ShortChannelID{
BlockHeight: uint32(chanID >> 40),
TxIndex: uint32(chanID>>16) & 0xFFFFFF,
TxPosition: uint16(chanID),
}
}
// ToUint64 converts the ShortChannelID into a compact format encoded within a
// uint64 (8 bytes).
func (c ShortChannelID) ToUint64() uint64 {
// TODO(roasbeef): explicit error on overflow?
return ((uint64(c.BlockHeight) << 40) | (uint64(c.TxIndex) << 16) |
(uint64(c.TxPosition)))
}
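// scidPackingSketch demonstrates (illustratively) the 3/3/2-byte packing used
// by ToUint64 and NewShortChanIDFromInt: the block height occupies the top 3
// bytes, the transaction index the middle 3, and the output position the low
// 2. The specific values below are arbitrary.
func scidPackingSketch() bool {
scid := ShortChannelID{
BlockHeight: 500_000,
TxIndex:     1_024,
TxPosition:  3,
}
// packed == (500000 << 40) | (1024 << 16) | 3.
packed := scid.ToUint64()
// Decoding the compact form recovers the original fields.
return NewShortChanIDFromInt(packed) == scid
}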
// String generates a human-readable representation of the channel ID.
func (c ShortChannelID) String() string {
return fmt.Sprintf("%d:%d:%d", c.BlockHeight, c.TxIndex, c.TxPosition)
}
// AltString generates a human-readable representation of the channel ID
// with 'x' as a separator.
func (c ShortChannelID) AltString() string {
return fmt.Sprintf("%dx%dx%d", c.BlockHeight, c.TxIndex, c.TxPosition)
}
// Record returns a TLV record that can be used to encode/decode a
// ShortChannelID to/from a TLV stream.
func (c *ShortChannelID) Record() tlv.Record {
return tlv.MakeStaticRecord(
AliasScidRecordType, c, 8, EShortChannelID, DShortChannelID,
)
}
// IsDefault returns true if the ShortChannelID represents the zero value for
// its type.
func (c ShortChannelID) IsDefault() bool {
return c == ShortChannelID{}
}
// EShortChannelID is an encoder for ShortChannelID. It is exported so other
// packages can use the encoding scheme.
func EShortChannelID(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*ShortChannelID); ok {
return tlv.EUint64T(w, v.ToUint64(), buf)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.ShortChannelID")
}
// DShortChannelID is a decoder for ShortChannelID. It is exported so other
// packages can use the decoding scheme.
func DShortChannelID(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if v, ok := val.(*ShortChannelID); ok {
var scid uint64
err := tlv.DUint64(r, &scid, buf, 8)
if err != nil {
return err
}
*v = NewShortChanIDFromInt(scid)
return nil
}
return tlv.NewTypeForDecodingErr(val, "lnwire.ShortChannelID", l, 8)
}
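// scidTLVStreamSketch is an illustrative sketch (not part of the package API)
// showing how the exported Record method lets a ShortChannelID be carried in
// a TLV stream alongside other records.
func scidTLVStreamSketch(w io.Writer) error {
scid := NewShortChanIDFromInt(123_456_789)
stream, err := tlv.NewStream(scid.Record())
if err != nil {
return err
}
// The SCID is written as a static 8-byte record under
// AliasScidRecordType.
return stream.Encode(w)
}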
package lnwire
import (
"bytes"
"io"
"github.com/lightningnetwork/lnd/tlv"
)
type (
// ShutdownNonceType is the type of the shutdown nonce TLV record.
ShutdownNonceType = tlv.TlvType8
// ShutdownNonceTLV is the TLV record that contains the shutdown nonce.
ShutdownNonceTLV = tlv.OptionalRecordT[ShutdownNonceType, Musig2Nonce]
)
// SomeShutdownNonce returns a ShutdownNonceTLV with the given nonce.
func SomeShutdownNonce(nonce Musig2Nonce) ShutdownNonceTLV {
return tlv.SomeRecordT(
tlv.NewRecordT[ShutdownNonceType, Musig2Nonce](nonce),
)
}
// Shutdown is sent by either side in order to initiate the cooperative closure
// of a channel. This message is sparse as both sides implicitly have the
// information necessary to construct a transaction that will send the settled
// funds of both parties to the final delivery addresses negotiated during the
// funding workflow.
type Shutdown struct {
// ChannelID serves to identify which channel is to be closed.
ChannelID ChannelID
// Address is the script to which the channel funds will be paid.
Address DeliveryAddress
// ShutdownNonce is the nonce the sender will use to sign the first
// co-op sign offer.
ShutdownNonce ShutdownNonceTLV
// CustomRecords maps TLV types to byte slices, storing arbitrary data
// intended for inclusion in the ExtraData field of the Shutdown
// message.
CustomRecords CustomRecords
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewShutdown creates a new Shutdown message.
func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown {
return &Shutdown{
ChannelID: cid,
Address: addr,
}
}
// A compile-time check to ensure Shutdown implements the lnwire.Message
// interface.
var _ Message = (*Shutdown)(nil)
// A compile-time check to ensure Shutdown implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Shutdown)(nil)
// Decode deserializes a serialized Shutdown from the passed io.Reader,
// observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r, &s.ChannelID, &s.Address)
if err != nil {
return err
}
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
// Extract TLV records from the extra data field.
musigNonce := s.ShutdownNonce.Zero()
customRecords, parsed, extraData, err := ParseAndExtractCustomRecords(
tlvRecords, &musigNonce,
)
if err != nil {
return err
}
// Assign the parsed records back to the message.
if _, ok := parsed[musigNonce.TlvType()]; ok {
s.ShutdownNonce = tlv.SomeRecordT(musigNonce)
}
s.CustomRecords = customRecords
s.ExtraData = extraData
return nil
}
// Encode serializes the target Shutdown into the passed io.Writer observing
// the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteChannelID(w, s.ChannelID); err != nil {
return err
}
if err := WriteDeliveryAddress(w, s.Address); err != nil {
return err
}
// Only include nonce in extra data if present.
var records []tlv.RecordProducer
s.ShutdownNonce.WhenSome(
func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) {
records = append(records, &nonce)
},
)
extraData, err := MergeAndEncode(records, s.ExtraData, s.CustomRecords)
if err != nil {
return err
}
return WriteBytes(w, extraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (s *Shutdown) MsgType() MessageType {
return MsgShutdown
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (s *Shutdown) SerializedSize() (uint32, error) {
return MessageSerializedSize(s)
}
package lnwire
import (
"errors"
"fmt"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/tlv"
)
var (
errSigTooShort = errors.New("malformed signature: too short")
errBadLength = errors.New("malformed signature: bad length")
errBadRLength = errors.New("malformed signature: bogus R length")
errBadSLength = errors.New("malformed signature: bogus S length")
errRTooLong = errors.New("R is over 32 bytes long without padding")
errSTooLong = errors.New("S is over 32 bytes long without padding")
)
// sigType represents the type of signature that is carried within the Sig.
// Today this can either be an ECDSA sig or a schnorr sig. Both of these can
// fit cleanly into 64 bytes.
type sigType uint
const (
// sigTypeECDSA represents an ECDSA signature.
sigTypeECDSA sigType = iota
// sigTypeSchnorr represents a schnorr signature.
sigTypeSchnorr
)
// Sig is a fixed-sized ECDSA signature or 64-byte schnorr signature. For the
// ECDSA sig, unlike Bitcoin, we use fixed sized signatures on the wire,
// instead of DER encoded signatures. This type provides several methods to
// convert to/from a regular Bitcoin DER encoded signature (raw bytes and
// *ecdsa.Signature).
type Sig struct {
bytes [64]byte
sigType sigType
}
// ForceSchnorr forces the signature to be interpreted as a schnorr signature.
// This is useful when reading an HTLC sig off the wire for a taproot channel.
// In this case, in order to obtain an input.Signature, we need to know that
// the sig is a schnorr sig.
func (s *Sig) ForceSchnorr() {
s.sigType = sigTypeSchnorr
}
// RawBytes returns the raw bytes of the signature.
func (s *Sig) RawBytes() []byte {
return s.bytes[:]
}
// Copy copies the signature into a new Sig instance.
func (s *Sig) Copy() Sig {
var sCopy Sig
copy(sCopy.bytes[:], s.bytes[:])
sCopy.sigType = s.sigType
return sCopy
}
// Record returns a Record that can be used to encode or decode the backing
// object.
//
// This returns a record that serializes the sig as a 64-byte fixed size
// signature.
func (s *Sig) Record() tlv.Record {
// We set the type to zero here as it isn't needed when used as a
// RecordT.
return tlv.MakePrimitiveRecord(0, &s.bytes)
}
// NewSigFromWireECDSA returns a Sig instance based on an ECDSA signature
// that's already in the 64-byte format we expect.
func NewSigFromWireECDSA(sig []byte) (Sig, error) {
if len(sig) != 64 {
return Sig{}, fmt.Errorf("%w: %v bytes", errSigTooShort,
len(sig))
}
var s Sig
copy(s.bytes[:], sig)
return s, nil
}
// NewSigFromECDSARawSignature returns a Sig from a Bitcoin raw signature
// encoded in the canonical DER encoding.
func NewSigFromECDSARawSignature(sig []byte) (Sig, error) {
var b [64]byte
// Check that the total length is at least the minimal length.
if len(sig) < ecdsa.MinSigLen {
return Sig{}, errSigTooShort
}
// The DER representation is laid out as:
// 0x30 <length> 0x02 <length r> r 0x02 <length s> s
// which means the length of R is the 4th byte and the length of S is
// the second byte after R ends. 0x02 signifies a length-prefixed,
// zero-padded, big-endian bigint. 0x30 signifies a DER signature.
// See the Serialize() method for ecdsa.Signature for details.
// Reading <length>, remaining: [0x02 <length r> r 0x02 <length s> s]
sigLen := int(sig[1])
// sigLen must not exceed the length of the entire message and must be
// at least the minimal message size.
if sigLen+2 > len(sig) || sigLen+2 < ecdsa.MinSigLen {
return Sig{}, errBadLength
}
// Reading <length r>, remaining: [r 0x02 <length s> s]
rLen := int(sig[3])
// rLen must be positive and must leave room for the other elements.
// Assuming s is one byte, then besides r itself we have 0x30,
// <length>, 0x02, <length r>, 0x02, <length s>, s, a total of 7
// bytes, so the signature must be at least rLen+7 bytes long.
if rLen <= 0 || rLen+7 > len(sig) {
return Sig{}, errBadRLength
}
// Reading <length s>, remaining: [s]
sLen := int(sig[5+rLen])
// S should be the rest of the string.
// sLen must be positive and must leave room for the other elements.
// We know r is rLen bytes, and besides s itself we have 0x30,
// <length>, 0x02, <length r>, r, 0x02, <length s>, a total of
// rLen+6 bytes.
if sLen <= 0 || sLen+rLen+6 > len(sig) {
return Sig{}, errBadSLength
}
// Check to make sure R and S can both fit into their intended buffers.
// We check S first because these code blocks decrement sLen and rLen
// in the case of a 33-byte 0-padded integer returned from Serialize()
// and rLen is used in calculating array indices for S. We can track
// this with additional variables, but it's more efficient to just
// check S first.
if sLen > 32 {
if (sLen > 33) || (sig[6+rLen] != 0x00) {
return Sig{}, errSTooLong
}
sLen--
copy(b[64-sLen:], sig[7+rLen:])
} else {
copy(b[64-sLen:], sig[6+rLen:])
}
// Do the same for R as we did for S.
if rLen > 32 {
if (rLen > 33) || (sig[4] != 0x00) {
return Sig{}, errRTooLong
}
rLen--
copy(b[32-rLen:], sig[5:5+rLen])
} else {
copy(b[32-rLen:], sig[4:4+rLen])
}
return Sig{
bytes: b,
sigType: sigTypeECDSA,
}, nil
}
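// derRoundTripSketch illustrates (as a non-normative sketch) the relationship
// between the DER encoding and the fixed 64-byte wire form: a canonical DER
// signature parsed via NewSigFromECDSARawSignature re-serializes to the same
// DER bytes through ToSignatureBytes.
func derRoundTripSketch(der []byte) ([]byte, error) {
sig, err := NewSigFromECDSARawSignature(der)
if err != nil {
return nil, err
}
// For ECDSA sigs, ToSignatureBytes rebuilds the canonical DER
// encoding from the fixed-size R and S values.
return sig.ToSignatureBytes(), nil
}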
// NewSigFromSchnorrRawSignature converts a raw schnorr signature into an
// lnwire.Sig.
func NewSigFromSchnorrRawSignature(sig []byte) (Sig, error) {
var s Sig
copy(s.bytes[:], sig)
s.sigType = sigTypeSchnorr
return s, nil
}
// NewSigFromSignature creates a new signature as used on the wire, from an
// existing ecdsa.Signature or schnorr.Signature.
func NewSigFromSignature(e input.Signature) (Sig, error) {
if e == nil {
return Sig{}, fmt.Errorf("cannot decode empty signature")
}
// A nil pointer wrapped in a non-nil interface value is still a valid
// interface, so we need a more explicit check here.
if ecsig, ok := e.(*ecdsa.Signature); ok && ecsig == nil {
return Sig{}, fmt.Errorf("cannot decode empty signature")
}
switch ecSig := e.(type) {
// If this is a schnorr signature, then we can just pack it as normal,
// since the default encoding is already 64 bytes.
case *schnorr.Signature:
return NewSigFromSchnorrRawSignature(e.Serialize())
// For ECDSA signatures, we'll need to do a bit more work to map the
// signature into a compact 64 byte form.
case *ecdsa.Signature:
// Serialize the signature with all the checks that entails.
return NewSigFromECDSARawSignature(e.Serialize())
default:
return Sig{}, fmt.Errorf("unknown wire sig type: %T", ecSig)
}
}
// ToSignature converts the fixed-sized signature to an input.Signature which
// can be used for signature validation checks.
func (s *Sig) ToSignature() (input.Signature, error) {
switch s.sigType {
case sigTypeSchnorr:
return schnorr.ParseSignature(s.bytes[:])
case sigTypeECDSA:
// Parse the signature with strict checks.
sigBytes := s.ToSignatureBytes()
sig, err := ecdsa.ParseDERSignature(sigBytes)
if err != nil {
return nil, err
}
return sig, nil
default:
return nil, fmt.Errorf("unknown sig type: %v", s.sigType)
}
}
// ToSignatureBytes serializes the target fixed-sized signature into the
// encoding of the primary domain for the signature. For ECDSA signatures, this
// is the raw bytes of a DER encoding.
func (s *Sig) ToSignatureBytes() []byte {
switch s.sigType {
// For ECDSA signatures, we'll convert to DER encoding.
case sigTypeECDSA:
// Extract canonically-padded bigint representations from the buffer.
r := extractCanonicalPadding(s.bytes[0:32])
s := extractCanonicalPadding(s.bytes[32:64])
rLen := uint8(len(r))
sLen := uint8(len(s))
// Create a canonical serialized signature. DER format is:
// 0x30 <length> 0x02 <length r> r 0x02 <length s> s
sigBytes := make([]byte, 6+rLen+sLen)
sigBytes[0] = 0x30 // DER signature magic value
sigBytes[1] = 4 + rLen + sLen // Length of rest of signature
sigBytes[2] = 0x02 // Big integer magic value
sigBytes[3] = rLen // Length of R
sigBytes[rLen+4] = 0x02 // Big integer magic value
sigBytes[rLen+5] = sLen // Length of S
copy(sigBytes[4:], r) // Copy R
copy(sigBytes[rLen+6:], s) // Copy S
return sigBytes
// For schnorr signatures, we can use the same internal 64 bytes.
case sigTypeSchnorr:
// We'll make a copy of the signature so we don't return a
// reference into the raw slice.
var sig [64]byte
copy(sig[:], s.bytes[:])
return sig[:]
default:
// TODO(roasbeef): can only be called via public methods so
// never reachable?
panic("sig type not set")
}
}
// extractCanonicalPadding is a utility function to extract the canonical
// padding of a big-endian integer from the wire encoding (a 0-padded
// big-endian integer) such that it passes btcec.canonicalPadding test.
func extractCanonicalPadding(b []byte) []byte {
for i := 0; i < len(b); i++ {
// Found first non-zero byte.
if b[i] > 0 {
// If the MSB is set, we need zero padding.
if b[i]&0x80 == 0x80 {
return append([]byte{0x00}, b[i:]...)
}
return b[i:]
}
}
return []byte{0x00}
}
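// canonicalPaddingSketch illustrates (as a non-normative sketch) the two
// cases extractCanonicalPadding handles: leading zero bytes are stripped,
// and a single zero byte is retained only when the first significant byte
// has its MSB set, so the integer isn't interpreted as negative.
func canonicalPaddingSketch() {
// Padding stripped entirely: the MSB of 0x7f is clear.
_ = extractCanonicalPadding([]byte{0x00, 0x00, 0x7f}) // [0x7f]
// One zero byte kept: the MSB of 0x80 is set.
_ = extractCanonicalPadding([]byte{0x00, 0x00, 0x80}) // [0x00, 0x80]
}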
package lnwire
import (
"bytes"
"io"
)
// Stfu is a message that is sent to lock the channel state prior to some other
// interactive protocol where channel updates need to be paused.
type Stfu struct {
// ChanID identifies which channel needs to be frozen.
ChanID ChannelID
// Initiator indicates whether we are the initiator of
// this process.
Initiator bool
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure Stfu implements the lnwire.Message interface.
var _ Message = (*Stfu)(nil)
// A compile time check to ensure Stfu implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Stfu)(nil)
// Encode serializes the target Stfu into the passed io.Writer.
// Serialization will observe the rules defined by the passed protocol version.
//
// This is a part of the lnwire.Message interface.
func (s *Stfu) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, s.ChanID); err != nil {
return err
}
if err := WriteBool(w, s.Initiator); err != nil {
return err
}
return WriteBytes(w, s.ExtraData)
}
// Decode deserializes the serialized Stfu stored in the passed io.Reader
// into the target Stfu using the deserialization rules defined by the
// passed protocol version.
//
// This is a part of the lnwire.Message interface.
func (s *Stfu) Decode(r io.Reader, _ uint32) error {
if err := ReadElements(
r, &s.ChanID, &s.Initiator, &s.ExtraData,
); err != nil {
return err
}
// This is required to pass the fuzz test round trip equality check.
if len(s.ExtraData) == 0 {
s.ExtraData = nil
}
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as a Stfu on the wire.
//
// This is part of the lnwire.Message interface.
func (s *Stfu) MsgType() MessageType {
return MsgStfu
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (s *Stfu) SerializedSize() (uint32, error) {
return MessageSerializedSize(s)
}
// A compile time check to ensure Stfu implements the
// lnwire.LinkUpdater interface.
var _ LinkUpdater = (*Stfu)(nil)
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (s *Stfu) TargetChanID() ChannelID {
return s.ChanID
}
package lnwire
import (
"bytes"
"fmt"
"image/color"
"math"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/tlv"
"pgregory.net/rapid"
)
// TestMessage is an interface that extends the base Message interface with a
// method to populate the message with random testing data.
type TestMessage interface {
Message
// RandTestMessage populates the message with random data suitable for
// testing. It uses the rapid testing framework to generate random
// values.
RandTestMessage(t *rapid.T) Message
}
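// randMessageEncodeSketch is an illustrative sketch (not part of the package
// API) of how RandTestMessage is meant to be driven: inside a rapid property,
// draw a random message and exercise its wire encoding.
func randMessageEncodeSketch(t *rapid.T, m TestMessage) ([]byte, error) {
msg := m.RandTestMessage(t)
var b bytes.Buffer
if err := msg.Encode(&b, 0); err != nil {
return nil, err
}
return b.Bytes(), nil
}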
// A compile time check to ensure AcceptChannel implements the TestMessage
// interface.
var _ TestMessage = (*AcceptChannel)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (a *AcceptChannel) RandTestMessage(t *rapid.T) Message {
var pendingChanID [32]byte
pendingChanIDBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(
t, "pendingChanID",
)
copy(pendingChanID[:], pendingChanIDBytes)
var channelType *ChannelType
includeChannelType := rapid.Bool().Draw(t, "includeChannelType")
includeLeaseExpiry := rapid.Bool().Draw(t, "includeLeaseExpiry")
includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce")
if includeChannelType {
channelType = RandChannelType(t)
}
var leaseExpiry *LeaseExpiry
if includeLeaseExpiry {
leaseExpiry = RandLeaseExpiry(t)
}
var localNonce OptMusig2NonceTLV
if includeLocalNonce {
nonce := RandMusig2Nonce(t)
localNonce = tlv.SomeRecordT(
tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce),
)
}
return &AcceptChannel{
PendingChannelID: pendingChanID,
DustLimit: btcutil.Amount(
rapid.IntRange(100, 1000).Draw(t, "dustLimit"),
),
MaxValueInFlight: MilliSatoshi(
rapid.IntRange(10000, 1000000).Draw(
t, "maxValueInFlight",
),
),
ChannelReserve: btcutil.Amount(
rapid.IntRange(1000, 10000).Draw(t, "channelReserve"),
),
HtlcMinimum: MilliSatoshi(
rapid.IntRange(1, 1000).Draw(t, "htlcMinimum"),
),
MinAcceptDepth: uint32(
rapid.IntRange(1, 10).Draw(t, "minAcceptDepth"),
),
CsvDelay: uint16(
rapid.IntRange(144, 1000).Draw(t, "csvDelay"),
),
MaxAcceptedHTLCs: uint16(
rapid.IntRange(10, 500).Draw(t, "maxAcceptedHTLCs"),
),
FundingKey: RandPubKey(t),
RevocationPoint: RandPubKey(t),
PaymentPoint: RandPubKey(t),
DelayedPaymentPoint: RandPubKey(t),
HtlcPoint: RandPubKey(t),
FirstCommitmentPoint: RandPubKey(t),
UpfrontShutdownScript: RandDeliveryAddress(t),
ChannelType: channelType,
LeaseExpiry: leaseExpiry,
LocalNonce: localNonce,
ExtraData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure AnnounceSignatures1 implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*AnnounceSignatures1)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (a *AnnounceSignatures1) RandTestMessage(t *rapid.T) Message {
return &AnnounceSignatures1{
ChannelID: RandChannelID(t),
ShortChannelID: RandShortChannelID(t),
NodeSignature: RandSignature(t),
BitcoinSignature: RandSignature(t),
ExtraOpaqueData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure AnnounceSignatures2 implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*AnnounceSignatures2)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (a *AnnounceSignatures2) RandTestMessage(t *rapid.T) Message {
return &AnnounceSignatures2{
ChannelID: RandChannelID(t),
ShortChannelID: RandShortChannelID(t),
PartialSignature: *RandPartialSig(t),
ExtraOpaqueData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure ChannelAnnouncement1 implements the
// TestMessage interface.
var _ TestMessage = (*ChannelAnnouncement1)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (a *ChannelAnnouncement1) RandTestMessage(t *rapid.T) Message {
// Generate Node IDs and Bitcoin keys (compressed public keys)
node1PubKey := RandPubKey(t)
node2PubKey := RandPubKey(t)
bitcoin1PubKey := RandPubKey(t)
bitcoin2PubKey := RandPubKey(t)
// Convert to byte arrays
var nodeID1, nodeID2, bitcoinKey1, bitcoinKey2 [33]byte
copy(nodeID1[:], node1PubKey.SerializeCompressed())
copy(nodeID2[:], node2PubKey.SerializeCompressed())
copy(bitcoinKey1[:], bitcoin1PubKey.SerializeCompressed())
copy(bitcoinKey2[:], bitcoin2PubKey.SerializeCompressed())
// Ensure nodeID1 is numerically less than nodeID2. This is a
// requirement stated in the field description.
if bytes.Compare(nodeID1[:], nodeID2[:]) > 0 {
nodeID1, nodeID2 = nodeID2, nodeID1
}
// Generate chain hash
chainHash := RandChainHash(t)
var hash chainhash.Hash
copy(hash[:], chainHash[:])
return &ChannelAnnouncement1{
NodeSig1: RandSignature(t),
NodeSig2: RandSignature(t),
BitcoinSig1: RandSignature(t),
BitcoinSig2: RandSignature(t),
Features: RandFeatureVector(t),
ChainHash: hash,
ShortChannelID: RandShortChannelID(t),
NodeID1: nodeID1,
NodeID2: nodeID2,
BitcoinKey1: bitcoinKey1,
BitcoinKey2: bitcoinKey2,
ExtraOpaqueData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure ChannelAnnouncement2 implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*ChannelAnnouncement2)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message {
features := RandFeatureVector(t)
shortChanID := RandShortChannelID(t)
capacity := uint64(rapid.IntRange(1, 16777215).Draw(t, "capacity"))
var nodeID1, nodeID2 [33]byte
copy(nodeID1[:], RandPubKey(t).SerializeCompressed())
copy(nodeID2[:], RandPubKey(t).SerializeCompressed())
// Make sure nodeID1 is numerically less than nodeID2 (as per spec).
if bytes.Compare(nodeID1[:], nodeID2[:]) > 0 {
nodeID1, nodeID2 = nodeID2, nodeID1
}
chainHash := RandChainHash(t)
var chainHashObj chainhash.Hash
copy(chainHashObj[:], chainHash[:])
msg := &ChannelAnnouncement2{
Signature: RandSignature(t),
ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash](
chainHashObj,
),
Features: tlv.NewRecordT[tlv.TlvType2, RawFeatureVector](
*features,
),
ShortChannelID: tlv.NewRecordT[tlv.TlvType4, ShortChannelID](
shortChanID,
),
Capacity: tlv.NewPrimitiveRecord[tlv.TlvType6, uint64](
capacity,
),
NodeID1: tlv.NewPrimitiveRecord[tlv.TlvType8, [33]byte](
nodeID1,
),
NodeID2: tlv.NewPrimitiveRecord[tlv.TlvType10, [33]byte](
nodeID2,
),
ExtraOpaqueData: RandExtraOpaqueData(t, nil),
}
msg.Signature.ForceSchnorr()
// Randomly include optional fields
if rapid.Bool().Draw(t, "includeBitcoinKey1") {
var bitcoinKey1 [33]byte
copy(bitcoinKey1[:], RandPubKey(t).SerializeCompressed())
msg.BitcoinKey1 = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType12, [33]byte](
bitcoinKey1,
),
)
}
if rapid.Bool().Draw(t, "includeBitcoinKey2") {
var bitcoinKey2 [33]byte
copy(bitcoinKey2[:], RandPubKey(t).SerializeCompressed())
msg.BitcoinKey2 = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType14, [33]byte](
bitcoinKey2,
),
)
}
if rapid.Bool().Draw(t, "includeMerkleRootHash") {
hash := RandSHA256Hash(t)
var merkleRootHash [32]byte
copy(merkleRootHash[:], hash[:])
msg.MerkleRootHash = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType16, [32]byte](
merkleRootHash,
),
)
}
return msg
}
// A compile time check to ensure ChannelReady implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*ChannelReady)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *ChannelReady) RandTestMessage(t *rapid.T) Message {
msg := &ChannelReady{
ChanID: RandChannelID(t),
NextPerCommitmentPoint: RandPubKey(t),
ExtraData: RandExtraOpaqueData(t, nil),
}
includeAliasScid := rapid.Bool().Draw(t, "includeAliasScid")
includeNextLocalNonce := rapid.Bool().Draw(t, "includeNextLocalNonce")
includeAnnouncementNodeNonce := rapid.Bool().Draw(
t, "includeAnnouncementNodeNonce",
)
includeAnnouncementBitcoinNonce := rapid.Bool().Draw(
t, "includeAnnouncementBitcoinNonce",
)
if includeAliasScid {
scid := RandShortChannelID(t)
msg.AliasScid = &scid
}
if includeNextLocalNonce {
nonce := RandMusig2Nonce(t)
msg.NextLocalNonce = SomeMusig2Nonce(nonce)
}
if includeAnnouncementNodeNonce {
nonce := RandMusig2Nonce(t)
msg.AnnouncementNodeNonce = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType0, Musig2Nonce](nonce),
)
}
if includeAnnouncementBitcoinNonce {
nonce := RandMusig2Nonce(t)
msg.AnnouncementBitcoinNonce = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType2, Musig2Nonce](nonce),
)
}
return msg
}
// A compile time check to ensure ChannelReestablish implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*ChannelReestablish)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (a *ChannelReestablish) RandTestMessage(t *rapid.T) Message {
msg := &ChannelReestablish{
ChanID: RandChannelID(t),
NextLocalCommitHeight: rapid.Uint64().Draw(
t, "nextLocalCommitHeight",
),
RemoteCommitTailHeight: rapid.Uint64().Draw(
t, "remoteCommitTailHeight",
),
LastRemoteCommitSecret: RandPaymentPreimage(t),
LocalUnrevokedCommitPoint: RandPubKey(t),
ExtraData: RandExtraOpaqueData(t, nil),
}
// Randomly decide whether to include optional fields
includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce")
includeDynHeight := rapid.Bool().Draw(t, "includeDynHeight")
if includeLocalNonce {
nonce := RandMusig2Nonce(t)
msg.LocalNonce = SomeMusig2Nonce(nonce)
}
if includeDynHeight {
height := DynHeight(rapid.Uint64().Draw(t, "dynHeight"))
msg.DynHeight = fn.Some(height)
}
return msg
}
// A compile time check to ensure ChannelUpdate1 implements the TestMessage
// interface.
var _ TestMessage = (*ChannelUpdate1)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (a *ChannelUpdate1) RandTestMessage(t *rapid.T) Message {
// Generate random message flags. Randomly decide whether to include
// the max HTLC field.
includeMaxHtlc := rapid.Bool().Draw(t, "includeMaxHtlc")
var msgFlags ChanUpdateMsgFlags
if includeMaxHtlc {
msgFlags |= ChanUpdateRequiredMaxHtlc
}
// Generate random channel flags. Randomly decide the direction (node1
// or node2).
isNode2 := rapid.Bool().Draw(t, "isNode2")
var chanFlags ChanUpdateChanFlags
if isNode2 {
chanFlags |= ChanUpdateDirection
}
// Randomly decide if the channel is disabled.
isDisabled := rapid.Bool().Draw(t, "isDisabled")
if isDisabled {
chanFlags |= ChanUpdateDisabled
}
// Generate chain hash
chainHash := RandChainHash(t)
var hash chainhash.Hash
copy(hash[:], chainHash[:])
// Generate other random fields
maxHtlc := MilliSatoshi(rapid.Uint64().Draw(t, "maxHtlc"))
// If max HTLC flag is not set, we need to zero the value
if !includeMaxHtlc {
maxHtlc = 0
}
return &ChannelUpdate1{
Signature: RandSignature(t),
ChainHash: hash,
ShortChannelID: RandShortChannelID(t),
Timestamp: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw(
t, "timestamp"),
),
MessageFlags: msgFlags,
ChannelFlags: chanFlags,
TimeLockDelta: uint16(rapid.IntRange(0, 65535).Draw(
t, "timelockDelta"),
),
HtlcMinimumMsat: MilliSatoshi(rapid.Uint64().Draw(
t, "htlcMinimum"),
),
BaseFee: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw(
t, "baseFee"),
),
FeeRate: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw(
t, "feeRate"),
),
HtlcMaximumMsat: maxHtlc,
ExtraOpaqueData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure ChannelUpdate2 implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*ChannelUpdate2)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *ChannelUpdate2) RandTestMessage(t *rapid.T) Message {
shortChanID := RandShortChannelID(t)
blockHeight := uint32(rapid.IntRange(0, 1000000).Draw(t, "blockHeight"))
var disabledFlags ChanUpdateDisableFlags
if rapid.Bool().Draw(t, "disableIncoming") {
disabledFlags |= ChanUpdateDisableIncoming
}
if rapid.Bool().Draw(t, "disableOutgoing") {
disabledFlags |= ChanUpdateDisableOutgoing
}
cltvExpiryDelta := uint16(rapid.IntRange(10, 200).Draw(
t, "cltvExpiryDelta"),
)
htlcMinMsat := MilliSatoshi(rapid.IntRange(1, 10000).Draw(
t, "htlcMinMsat"),
)
htlcMaxMsat := MilliSatoshi(rapid.IntRange(10000, 100000000).Draw(
t, "htlcMaxMsat"),
)
feeBaseMsat := uint32(rapid.IntRange(0, 10000).Draw(t, "feeBaseMsat"))
feeProportionalMillionths := uint32(rapid.IntRange(0, 10000).Draw(
t, "feeProportionalMillionths"),
)
chainHash := RandChainHash(t)
var chainHashObj chainhash.Hash
copy(chainHashObj[:], chainHash[:])
//nolint:ll
msg := &ChannelUpdate2{
Signature: RandSignature(t),
ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash](
chainHashObj,
),
ShortChannelID: tlv.NewRecordT[tlv.TlvType2, ShortChannelID](
shortChanID,
),
BlockHeight: tlv.NewPrimitiveRecord[tlv.TlvType4, uint32](
blockHeight,
),
DisabledFlags: tlv.NewPrimitiveRecord[tlv.TlvType6, ChanUpdateDisableFlags]( //nolint:ll
disabledFlags,
),
CLTVExpiryDelta: tlv.NewPrimitiveRecord[tlv.TlvType10, uint16](
cltvExpiryDelta,
),
HTLCMinimumMsat: tlv.NewPrimitiveRecord[tlv.TlvType12, MilliSatoshi](
htlcMinMsat,
),
HTLCMaximumMsat: tlv.NewPrimitiveRecord[tlv.TlvType14, MilliSatoshi](
htlcMaxMsat,
),
FeeBaseMsat: tlv.NewPrimitiveRecord[tlv.TlvType16, uint32](
feeBaseMsat,
),
FeeProportionalMillionths: tlv.NewPrimitiveRecord[tlv.TlvType18, uint32](
feeProportionalMillionths,
),
ExtraOpaqueData: RandExtraOpaqueData(t, nil),
}
msg.Signature.ForceSchnorr()
if rapid.Bool().Draw(t, "isSecondPeer") {
msg.SecondPeer = tlv.SomeRecordT(
tlv.RecordT[tlv.TlvType8, TrueBoolean]{},
)
}
return msg
}
// A compile time check to ensure ClosingComplete implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*ClosingComplete)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *ClosingComplete) RandTestMessage(t *rapid.T) Message {
msg := &ClosingComplete{
ChannelID: RandChannelID(t),
FeeSatoshis: btcutil.Amount(rapid.Int64Range(0, 1000000).Draw(
t, "feeSatoshis"),
),
LockTime: rapid.Uint32Range(0, 0xffffffff).Draw(
t, "lockTime",
),
CloseeScript: RandDeliveryAddress(t),
CloserScript: RandDeliveryAddress(t),
ExtraData: RandExtraOpaqueData(t, nil),
}
includeCloserNoClosee := rapid.Bool().Draw(t, "includeCloserNoClosee")
includeNoCloserClosee := rapid.Bool().Draw(t, "includeNoCloserClosee")
includeCloserAndClosee := rapid.Bool().Draw(t, "includeCloserAndClosee")
// Ensure at least one signature is present.
if !includeCloserNoClosee && !includeNoCloserClosee &&
!includeCloserAndClosee {
// If all are false, enable at least one randomly.
choice := rapid.IntRange(0, 2).Draw(t, "sigChoice")
switch choice {
case 0:
includeCloserNoClosee = true
case 1:
includeNoCloserClosee = true
case 2:
includeCloserAndClosee = true
}
}
if includeCloserNoClosee {
sig := RandSignature(t)
msg.CloserNoClosee = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType1, Sig](sig),
)
}
if includeNoCloserClosee {
sig := RandSignature(t)
msg.NoCloserClosee = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType2, Sig](sig),
)
}
if includeCloserAndClosee {
sig := RandSignature(t)
msg.CloserAndClosee = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3, Sig](sig),
)
}
return msg
}
// A compile time check to ensure ClosingSig implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*ClosingSig)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *ClosingSig) RandTestMessage(t *rapid.T) Message {
msg := &ClosingSig{
ChannelID: RandChannelID(t),
CloseeScript: RandDeliveryAddress(t),
CloserScript: RandDeliveryAddress(t),
ExtraData: RandExtraOpaqueData(t, nil),
}
includeCloserNoClosee := rapid.Bool().Draw(t, "includeCloserNoClosee")
includeNoCloserClosee := rapid.Bool().Draw(t, "includeNoCloserClosee")
includeCloserAndClosee := rapid.Bool().Draw(t, "includeCloserAndClosee")
// Ensure at least one signature is present.
if !includeCloserNoClosee && !includeNoCloserClosee &&
!includeCloserAndClosee {
// If all are false, enable at least one randomly.
choice := rapid.IntRange(0, 2).Draw(t, "sigChoice")
switch choice {
case 0:
includeCloserNoClosee = true
case 1:
includeNoCloserClosee = true
case 2:
includeCloserAndClosee = true
}
}
if includeCloserNoClosee {
sig := RandSignature(t)
msg.CloserNoClosee = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType1, Sig](sig),
)
}
if includeNoCloserClosee {
sig := RandSignature(t)
msg.NoCloserClosee = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType2, Sig](sig),
)
}
if includeCloserAndClosee {
sig := RandSignature(t)
msg.CloserAndClosee = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3, Sig](sig),
)
}
return msg
}
// A compile time check to ensure ClosingSigned implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*ClosingSigned)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *ClosingSigned) RandTestMessage(t *rapid.T) Message {
// Generate a random boolean to decide whether to include CommitSig or
// PartialSig. Since they're mutually exclusive, when one is populated,
// the other must be blank.
usePartialSig := rapid.Bool().Draw(t, "usePartialSig")
msg := &ClosingSigned{
ChannelID: RandChannelID(t),
FeeSatoshis: btcutil.Amount(
rapid.Int64Range(0, 1000000).Draw(t, "feeSatoshis"),
),
ExtraData: RandExtraOpaqueData(t, nil),
}
if usePartialSig {
sigBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(
t, "sigScalar",
)
var s btcec.ModNScalar
_ = s.SetByteSlice(sigBytes)
msg.PartialSig = SomePartialSig(NewPartialSig(s))
msg.Signature = Sig{}
} else {
msg.Signature = RandSignature(t)
}
return msg
}
// A compile time check to ensure CommitSig implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*CommitSig)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *CommitSig) RandTestMessage(t *rapid.T) Message {
cr, _ := RandCustomRecords(t, nil, true)
sig := &CommitSig{
ChanID: RandChannelID(t),
CommitSig: RandSignature(t),
CustomRecords: cr,
}
numHtlcSigs := rapid.IntRange(0, 20).Draw(t, "numHtlcSigs")
htlcSigs := make([]Sig, numHtlcSigs)
for i := 0; i < numHtlcSigs; i++ {
htlcSigs[i] = RandSignature(t)
}
if len(htlcSigs) > 0 {
sig.HtlcSigs = htlcSigs
}
includePartialSig := rapid.Bool().Draw(t, "includePartialSig")
if includePartialSig {
sigWithNonce := RandPartialSigWithNonce(t)
sig.PartialSig = MaybePartialSigWithNonce(sigWithNonce)
}
return sig
}
// A compile time check to ensure Custom implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*Custom)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *Custom) RandTestMessage(t *rapid.T) Message {
msgType := MessageType(
rapid.IntRange(int(CustomTypeStart), 65535).Draw(
t, "customMsgType",
),
)
dataLen := rapid.IntRange(0, 1000).Draw(t, "customDataLength")
data := rapid.SliceOfN(rapid.Byte(), dataLen, dataLen).Draw(
t, "customData",
)
msg, err := NewCustom(msgType, data)
if err != nil {
panic(fmt.Sprintf("Error creating custom message: %v", err))
}
return msg
}
// A compile time check to ensure DynAck implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*DynAck)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (da *DynAck) RandTestMessage(t *rapid.T) Message {
msg := &DynAck{
ChanID: RandChannelID(t),
ExtraData: RandExtraOpaqueData(t, nil),
}
includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce")
if includeLocalNonce {
msg.LocalNonce = fn.Some(RandMusig2Nonce(t))
}
return msg
}
// A compile time check to ensure DynPropose implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*DynPropose)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (dp *DynPropose) RandTestMessage(t *rapid.T) Message {
msg := &DynPropose{
ChanID: RandChannelID(t),
Initiator: rapid.Bool().Draw(t, "initiator"),
ExtraData: RandExtraOpaqueData(t, nil),
}
// Randomly decide which optional fields to include
includeDustLimit := rapid.Bool().Draw(t, "includeDustLimit")
includeMaxValueInFlight := rapid.Bool().Draw(
t, "includeMaxValueInFlight",
)
includeChannelReserve := rapid.Bool().Draw(t, "includeChannelReserve")
includeCsvDelay := rapid.Bool().Draw(t, "includeCsvDelay")
includeMaxAcceptedHTLCs := rapid.Bool().Draw(
t, "includeMaxAcceptedHTLCs",
)
includeFundingKey := rapid.Bool().Draw(t, "includeFundingKey")
includeChannelType := rapid.Bool().Draw(t, "includeChannelType")
includeKickoffFeerate := rapid.Bool().Draw(t, "includeKickoffFeerate")
// Generate random values for each included field
if includeDustLimit {
dl := btcutil.Amount(rapid.Uint32().Draw(t, "dustLimit"))
msg.DustLimit = fn.Some(dl)
}
if includeMaxValueInFlight {
mvif := MilliSatoshi(rapid.Uint64().Draw(t, "maxValueInFlight"))
msg.MaxValueInFlight = fn.Some(mvif)
}
if includeChannelReserve {
cr := btcutil.Amount(rapid.Uint32().Draw(t, "channelReserve"))
msg.ChannelReserve = fn.Some(cr)
}
if includeCsvDelay {
cd := rapid.Uint16().Draw(t, "csvDelay")
msg.CsvDelay = fn.Some(cd)
}
if includeMaxAcceptedHTLCs {
mah := rapid.Uint16().Draw(t, "maxAcceptedHTLCs")
msg.MaxAcceptedHTLCs = fn.Some(mah)
}
if includeFundingKey {
msg.FundingKey = fn.Some(*RandPubKey(t))
}
if includeChannelType {
msg.ChannelType = fn.Some(*RandChannelType(t))
}
if includeKickoffFeerate {
kf := chainfee.SatPerKWeight(rapid.Uint32().Draw(
t, "kickoffFeerate"),
)
msg.KickoffFeerate = fn.Some(kf)
}
return msg
}
// A compile time check to ensure DynReject implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*DynReject)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (dr *DynReject) RandTestMessage(t *rapid.T) Message {
featureVec := NewRawFeatureVector()
numFeatures := rapid.IntRange(0, 8).Draw(t, "numRejections")
for i := 0; i < numFeatures; i++ {
bit := FeatureBit(
rapid.IntRange(0, 31).Draw(
t, fmt.Sprintf("rejectionBit-%d", i),
),
)
featureVec.Set(bit)
}
var extraData ExtraOpaqueData
randData := RandExtraOpaqueData(t, nil)
if len(randData) > 0 {
extraData = randData
}
return &DynReject{
ChanID: RandChannelID(t),
UpdateRejections: *featureVec,
ExtraData: extraData,
}
}
// A compile time check to ensure FundingCreated implements the TestMessage
// interface.
var _ TestMessage = (*FundingCreated)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (f *FundingCreated) RandTestMessage(t *rapid.T) Message {
var pendingChanID [32]byte
pendingChanIDBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(
t, "pendingChanID",
)
copy(pendingChanID[:], pendingChanIDBytes)
includePartialSig := rapid.Bool().Draw(t, "includePartialSig")
var partialSig OptPartialSigWithNonceTLV
var commitSig Sig
if includePartialSig {
sigWithNonce := RandPartialSigWithNonce(t)
partialSig = MaybePartialSigWithNonce(sigWithNonce)
// When using partial sig, CommitSig should be empty/blank.
commitSig = Sig{}
} else {
commitSig = RandSignature(t)
}
return &FundingCreated{
PendingChannelID: pendingChanID,
FundingPoint: RandOutPoint(t),
CommitSig: commitSig,
PartialSig: partialSig,
ExtraData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure FundingSigned implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*FundingSigned)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (f *FundingSigned) RandTestMessage(t *rapid.T) Message {
usePartialSig := rapid.Bool().Draw(t, "usePartialSig")
msg := &FundingSigned{
ChanID: RandChannelID(t),
ExtraData: RandExtraOpaqueData(t, nil),
}
if usePartialSig {
sigWithNonce := RandPartialSigWithNonce(t)
msg.PartialSig = MaybePartialSigWithNonce(sigWithNonce)
msg.CommitSig = Sig{}
} else {
msg.CommitSig = RandSignature(t)
}
return msg
}
// A compile time check to ensure GossipTimestampRange implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*GossipTimestampRange)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (g *GossipTimestampRange) RandTestMessage(t *rapid.T) Message {
var chainHash chainhash.Hash
hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash")
copy(chainHash[:], hashBytes)
msg := &GossipTimestampRange{
ChainHash: chainHash,
FirstTimestamp: rapid.Uint32().Draw(t, "firstTimestamp"),
TimestampRange: rapid.Uint32().Draw(t, "timestampRange"),
ExtraData: RandExtraOpaqueData(t, nil),
}
includeFirstBlockHeight := rapid.Bool().Draw(
t, "includeFirstBlockHeight",
)
includeBlockRange := rapid.Bool().Draw(t, "includeBlockRange")
if includeFirstBlockHeight {
height := rapid.Uint32().Draw(t, "firstBlockHeight")
msg.FirstBlockHeight = tlv.SomeRecordT(
tlv.RecordT[tlv.TlvType2, uint32]{Val: height},
)
}
if includeBlockRange {
blockRange := rapid.Uint32().Draw(t, "blockRange")
msg.BlockRange = tlv.SomeRecordT(
tlv.RecordT[tlv.TlvType4, uint32]{Val: blockRange},
)
}
return msg
}
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (msg *Init) RandTestMessage(t *rapid.T) Message {
global := NewRawFeatureVector()
local := NewRawFeatureVector()
numGlobalFeatures := rapid.IntRange(0, 20).Draw(t, "numGlobalFeatures")
for i := 0; i < numGlobalFeatures; i++ {
bit := FeatureBit(
rapid.IntRange(0, 100).Draw(
t, fmt.Sprintf("globalFeatureBit%d", i),
),
)
global.Set(bit)
}
numLocalFeatures := rapid.IntRange(0, 20).Draw(t, "numLocalFeatures")
for i := 0; i < numLocalFeatures; i++ {
bit := FeatureBit(
rapid.IntRange(0, 100).Draw(
t, fmt.Sprintf("localFeatureBit%d", i),
),
)
local.Set(bit)
}
return NewInitMessage(global, local)
}
// A compile time check to ensure KickoffSig implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*KickoffSig)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (ks *KickoffSig) RandTestMessage(t *rapid.T) Message {
return &KickoffSig{
ChanID: RandChannelID(t),
Signature: RandSignature(t),
ExtraData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure NodeAnnouncement implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*NodeAnnouncement)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (a *NodeAnnouncement) RandTestMessage(t *rapid.T) Message {
// Generate random compressed public key for node ID
pubKey := RandPubKey(t)
var nodeID [33]byte
copy(nodeID[:], pubKey.SerializeCompressed())
// Generate random RGB color
rgbColor := color.RGBA{
R: uint8(rapid.IntRange(0, 255).Draw(t, "rgbR")),
G: uint8(rapid.IntRange(0, 255).Draw(t, "rgbG")),
B: uint8(rapid.IntRange(0, 255).Draw(t, "rgbB")),
}
return &NodeAnnouncement{
Signature: RandSignature(t),
Features: RandFeatureVector(t),
Timestamp: uint32(rapid.IntRange(0, 0x7FFFFFFF).Draw(
t, "timestamp"),
),
NodeID: nodeID,
RGBColor: rgbColor,
Alias: RandNodeAlias(t),
Addresses: RandNetAddrs(t),
ExtraOpaqueData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure OpenChannel implements the TestMessage
// interface.
var _ TestMessage = (*OpenChannel)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (o *OpenChannel) RandTestMessage(t *rapid.T) Message {
chainHash := RandChainHash(t)
var hash chainhash.Hash
copy(hash[:], chainHash[:])
var pendingChanID [32]byte
pendingChanIDBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(
t, "pendingChanID",
)
copy(pendingChanID[:], pendingChanIDBytes)
includeChannelType := rapid.Bool().Draw(t, "includeChannelType")
includeLeaseExpiry := rapid.Bool().Draw(t, "includeLeaseExpiry")
includeLocalNonce := rapid.Bool().Draw(t, "includeLocalNonce")
var channelFlags FundingFlag
if rapid.Bool().Draw(t, "announceChannel") {
channelFlags |= FFAnnounceChannel
}
var localNonce OptMusig2NonceTLV
if includeLocalNonce {
nonce := RandMusig2Nonce(t)
localNonce = tlv.SomeRecordT(
tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce),
)
}
var channelType *ChannelType
if includeChannelType {
channelType = RandChannelType(t)
}
var leaseExpiry *LeaseExpiry
if includeLeaseExpiry {
leaseExpiry = RandLeaseExpiry(t)
}
return &OpenChannel{
ChainHash: hash,
PendingChannelID: pendingChanID,
FundingAmount: btcutil.Amount(
rapid.IntRange(5000, 10000000).Draw(t, "fundingAmount"),
),
PushAmount: MilliSatoshi(
rapid.IntRange(0, 1000000).Draw(t, "pushAmount"),
),
DustLimit: btcutil.Amount(
rapid.IntRange(100, 1000).Draw(t, "dustLimit"),
),
MaxValueInFlight: MilliSatoshi(
rapid.IntRange(10000, 1000000).Draw(
t, "maxValueInFlight",
),
),
ChannelReserve: btcutil.Amount(
rapid.IntRange(1000, 10000).Draw(t, "channelReserve"),
),
HtlcMinimum: MilliSatoshi(
rapid.IntRange(1, 1000).Draw(t, "htlcMinimum"),
),
FeePerKiloWeight: uint32(
rapid.IntRange(250, 10000).Draw(t, "feePerKw"),
),
CsvDelay: uint16(
rapid.IntRange(144, 1000).Draw(t, "csvDelay"),
),
MaxAcceptedHTLCs: uint16(
rapid.IntRange(10, 500).Draw(t, "maxAcceptedHTLCs"),
),
FundingKey: RandPubKey(t),
RevocationPoint: RandPubKey(t),
PaymentPoint: RandPubKey(t),
DelayedPaymentPoint: RandPubKey(t),
HtlcPoint: RandPubKey(t),
FirstCommitmentPoint: RandPubKey(t),
ChannelFlags: channelFlags,
UpfrontShutdownScript: RandDeliveryAddress(t),
ChannelType: channelType,
LeaseExpiry: leaseExpiry,
LocalNonce: localNonce,
ExtraData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure Ping implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*Ping)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (p *Ping) RandTestMessage(t *rapid.T) Message {
numPongBytes := uint16(rapid.IntRange(0, int(MaxPongBytes)).Draw(
t, "numPongBytes"),
)
	// Generate padding bytes, keeping within the allowed message size:
	// MaxMsgBody - 2 (for NumPongBytes) - 2 (for the padding length).
maxPaddingLen := MaxMsgBody - 4
paddingLen := rapid.IntRange(0, maxPaddingLen).Draw(
t, "paddingLen",
)
padding := make(PingPayload, paddingLen)
// Fill padding with random bytes
for i := 0; i < paddingLen; i++ {
padding[i] = byte(rapid.IntRange(0, 255).Draw(
t, fmt.Sprintf("paddingByte%d", i)),
)
}
return &Ping{
NumPongBytes: numPongBytes,
PaddingBytes: padding,
}
}
// A compile time check to ensure Pong implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*Pong)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (p *Pong) RandTestMessage(t *rapid.T) Message {
payloadLen := rapid.IntRange(0, 1000).Draw(t, "pongPayloadLength")
payload := rapid.SliceOfN(rapid.Byte(), payloadLen, payloadLen).Draw(
t, "pongPayload",
)
return &Pong{
PongBytes: payload,
}
}
// A compile time check to ensure QueryChannelRange implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*QueryChannelRange)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (q *QueryChannelRange) RandTestMessage(t *rapid.T) Message {
msg := &QueryChannelRange{
FirstBlockHeight: uint32(rapid.IntRange(0, 1000000).Draw(
t, "firstBlockHeight"),
),
NumBlocks: uint32(rapid.IntRange(1, 10000).Draw(
t, "numBlocks"),
),
ExtraData: RandExtraOpaqueData(t, nil),
}
// Generate chain hash
chainHash := RandChainHash(t)
var chainHashObj chainhash.Hash
copy(chainHashObj[:], chainHash[:])
msg.ChainHash = chainHashObj
// Randomly include QueryOptions
if rapid.Bool().Draw(t, "includeQueryOptions") {
queryOptions := &QueryOptions{}
*queryOptions = QueryOptions(*RandFeatureVector(t))
msg.QueryOptions = queryOptions
}
return msg
}
// A compile time check to ensure QueryShortChanIDs implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*QueryShortChanIDs)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (q *QueryShortChanIDs) RandTestMessage(t *rapid.T) Message {
var chainHash chainhash.Hash
hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash")
copy(chainHash[:], hashBytes)
encodingType := EncodingSortedPlain
if rapid.Bool().Draw(t, "useZlibEncoding") {
encodingType = EncodingSortedZlib
}
msg := &QueryShortChanIDs{
ChainHash: chainHash,
EncodingType: encodingType,
ExtraData: RandExtraOpaqueData(t, nil),
noSort: false,
}
numIDs := rapid.IntRange(2, 20).Draw(t, "numShortChanIDs")
// Generate sorted short channel IDs.
shortChanIDs := make([]ShortChannelID, numIDs)
for i := 0; i < numIDs; i++ {
shortChanIDs[i] = RandShortChannelID(t)
// Ensure they're properly sorted.
if i > 0 && shortChanIDs[i].ToUint64() <=
shortChanIDs[i-1].ToUint64() {
// Ensure this ID is larger than the previous one.
shortChanIDs[i] = NewShortChanIDFromInt(
shortChanIDs[i-1].ToUint64() + 1,
)
}
}
msg.ShortChanIDs = shortChanIDs
return msg
}
// A compile time check to ensure ReplyChannelRange implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*ReplyChannelRange)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *ReplyChannelRange) RandTestMessage(t *rapid.T) Message {
msg := &ReplyChannelRange{
FirstBlockHeight: uint32(rapid.IntRange(0, 1000000).Draw(
t, "firstBlockHeight"),
),
NumBlocks: uint32(rapid.IntRange(1, 10000).Draw(
t, "numBlocks"),
),
Complete: uint8(rapid.IntRange(0, 1).Draw(t, "complete")),
EncodingType: QueryEncoding(
rapid.IntRange(0, 1).Draw(t, "encodingType"),
),
ExtraData: RandExtraOpaqueData(t, nil),
}
msg.ChainHash = RandChainHash(t)
numShortChanIDs := rapid.IntRange(0, 20).Draw(t, "numShortChanIDs")
if numShortChanIDs == 0 {
return msg
}
scidSet := fn.NewSet[ShortChannelID]()
scids := make([]ShortChannelID, numShortChanIDs)
for i := 0; i < numShortChanIDs; i++ {
scid := RandShortChannelID(t)
for scidSet.Contains(scid) {
scid = RandShortChannelID(t)
}
scids[i] = scid
scidSet.Add(scid)
}
	// The set-based loop above guarantees there are no duplicates.
msg.ShortChanIDs = scids
if rapid.Bool().Draw(t, "includeTimestamps") && numShortChanIDs > 0 {
msg.Timestamps = make(Timestamps, numShortChanIDs)
for i := 0; i < numShortChanIDs; i++ {
msg.Timestamps[i] = ChanUpdateTimestamps{
Timestamp1: uint32(rapid.IntRange(0, math.MaxInt32).Draw(t, fmt.Sprintf("timestamp-1-%d", i))), //nolint:ll
Timestamp2: uint32(rapid.IntRange(0, math.MaxInt32).Draw(t, fmt.Sprintf("timestamp-2-%d", i))), //nolint:ll
}
}
}
return msg
}
// A compile time check to ensure ReplyShortChanIDsEnd implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*ReplyShortChanIDsEnd)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *ReplyShortChanIDsEnd) RandTestMessage(t *rapid.T) Message {
var chainHash chainhash.Hash
hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash")
copy(chainHash[:], hashBytes)
complete := uint8(rapid.IntRange(0, 1).Draw(t, "complete"))
return &ReplyShortChanIDsEnd{
ChainHash: chainHash,
Complete: complete,
ExtraData: RandExtraOpaqueData(t, nil),
}
}
// RandTestMessage returns a RevokeAndAck message populated with random data.
//
// This is part of the TestMessage interface.
func (c *RevokeAndAck) RandTestMessage(t *rapid.T) Message {
msg := NewRevokeAndAck()
var chanID ChannelID
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "channelID")
copy(chanID[:], bytes)
msg.ChanID = chanID
revBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "revocation")
copy(msg.Revocation[:], revBytes)
msg.NextRevocationKey = RandPubKey(t)
if rapid.Bool().Draw(t, "includeLocalNonce") {
var nonce Musig2Nonce
nonceBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(
t, "nonce",
)
copy(nonce[:], nonceBytes)
msg.LocalNonce = tlv.SomeRecordT(
tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce),
)
}
return msg
}
// A compile-time check to ensure Shutdown implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*Shutdown)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (s *Shutdown) RandTestMessage(t *rapid.T) Message {
	// Generate a random delivery address. First decide the address type
	// (P2PKH, P2SH, P2WPKH, P2WSH, P2TR).
	addrType := rapid.IntRange(0, 4).Draw(t, "addrType")
	// Set the address length based on the chosen type.
var addrLen int
switch addrType {
// P2PKH
case 0:
addrLen = 25
// P2SH
case 1:
addrLen = 23
// P2WPKH
case 2:
addrLen = 22
// P2WSH
case 3:
addrLen = 34
// P2TR
case 4:
addrLen = 34
}
addr := rapid.SliceOfN(rapid.Byte(), addrLen, addrLen).Draw(
t, "address",
)
// Randomly decide whether to include a shutdown nonce
includeNonce := rapid.Bool().Draw(t, "includeNonce")
var shutdownNonce ShutdownNonceTLV
if includeNonce {
shutdownNonce = SomeShutdownNonce(RandMusig2Nonce(t))
}
cr, _ := RandCustomRecords(t, nil, true)
return &Shutdown{
ChannelID: RandChannelID(t),
Address: addr,
ShutdownNonce: shutdownNonce,
CustomRecords: cr,
}
}
// A compile time check to ensure Stfu implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*Stfu)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (s *Stfu) RandTestMessage(t *rapid.T) Message {
m := &Stfu{
ChanID: RandChannelID(t),
Initiator: rapid.Bool().Draw(t, "initiator"),
}
extraData := RandExtraOpaqueData(t, nil)
if len(extraData) > 0 {
m.ExtraData = extraData
}
return m
}
// A compile time check to ensure UpdateAddHTLC implements the
// lnwire.TestMessage interface.
var _ TestMessage = (*UpdateAddHTLC)(nil)
// RandTestMessage returns an UpdateAddHTLC message populated with random data.
//
// This is part of the TestMessage interface.
func (c *UpdateAddHTLC) RandTestMessage(t *rapid.T) Message {
msg := &UpdateAddHTLC{
ChanID: RandChannelID(t),
ID: rapid.Uint64().Draw(t, "id"),
Amount: MilliSatoshi(rapid.Uint64().Draw(t, "amount")),
Expiry: rapid.Uint32().Draw(t, "expiry"),
}
hashBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "paymentHash")
copy(msg.PaymentHash[:], hashBytes)
onionBytes := rapid.SliceOfN(
rapid.Byte(), OnionPacketSize, OnionPacketSize,
).Draw(t, "onionBlob")
copy(msg.OnionBlob[:], onionBytes)
numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords")
if numRecords > 0 {
msg.CustomRecords, _ = RandCustomRecords(t, nil, true)
}
// 50/50 chance to add a blinding point
if rapid.Bool().Draw(t, "includeBlindingPoint") {
pubKey := RandPubKey(t)
msg.BlindingPoint = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[BlindingPointTlvType](pubKey),
)
}
return msg
}
// A compile time check to ensure UpdateFailHTLC implements the TestMessage
// interface.
var _ TestMessage = (*UpdateFailHTLC)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *UpdateFailHTLC) RandTestMessage(t *rapid.T) Message {
return &UpdateFailHTLC{
ChanID: RandChannelID(t),
ID: rapid.Uint64().Draw(t, "id"),
Reason: RandOpaqueReason(t),
ExtraData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure UpdateFailMalformedHTLC implements the
// TestMessage interface.
var _ TestMessage = (*UpdateFailMalformedHTLC)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *UpdateFailMalformedHTLC) RandTestMessage(t *rapid.T) Message {
return &UpdateFailMalformedHTLC{
ChanID: RandChannelID(t),
ID: rapid.Uint64().Draw(t, "id"),
ShaOnionBlob: RandSHA256Hash(t),
FailureCode: RandFailCode(t),
ExtraData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure UpdateFee implements the TestMessage
// interface.
var _ TestMessage = (*UpdateFee)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *UpdateFee) RandTestMessage(t *rapid.T) Message {
return &UpdateFee{
ChanID: RandChannelID(t),
FeePerKw: uint32(rapid.IntRange(1, 10000).Draw(t, "feePerKw")),
ExtraData: RandExtraOpaqueData(t, nil),
}
}
// A compile time check to ensure UpdateFulfillHTLC implements the TestMessage
// interface.
var _ TestMessage = (*UpdateFulfillHTLC)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *UpdateFulfillHTLC) RandTestMessage(t *rapid.T) Message {
msg := &UpdateFulfillHTLC{
ChanID: RandChannelID(t),
ID: rapid.Uint64().Draw(t, "id"),
PaymentPreimage: RandPaymentPreimage(t),
}
cr, ignoreRecords := RandCustomRecords(t, nil, true)
msg.CustomRecords = cr
randData := RandExtraOpaqueData(t, ignoreRecords)
if len(randData) > 0 {
msg.ExtraData = randData
}
return msg
}
// A compile time check to ensure Warning implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*Warning)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *Warning) RandTestMessage(t *rapid.T) Message {
msg := &Warning{
ChanID: RandChannelID(t),
}
useASCII := rapid.Bool().Draw(t, "useASCII")
if useASCII {
length := rapid.IntRange(1, 100).Draw(t, "warningDataLength")
data := make([]byte, length)
for i := 0; i < length; i++ {
data[i] = byte(
rapid.IntRange(32, 126).Draw(
t, fmt.Sprintf("warningDataByte-%d", i),
),
)
}
msg.Data = data
} else {
length := rapid.IntRange(1, 100).Draw(t, "warningDataLength")
msg.Data = rapid.SliceOfN(rapid.Byte(), length, length).Draw(
t, "warningData",
)
}
return msg
}
// A compile time check to ensure Error implements the lnwire.TestMessage
// interface.
var _ TestMessage = (*Error)(nil)
// RandTestMessage populates the message with random data suitable for testing.
// It uses the rapid testing framework to generate random values.
//
// This is part of the TestMessage interface.
func (c *Error) RandTestMessage(t *rapid.T) Message {
msg := &Error{
ChanID: RandChannelID(t),
}
useASCII := rapid.Bool().Draw(t, "useASCII")
if useASCII {
length := rapid.IntRange(1, 100).Draw(t, "errorDataLength")
data := make([]byte, length)
for i := 0; i < length; i++ {
data[i] = byte(
rapid.IntRange(32, 126).Draw(
t, fmt.Sprintf("errorDataByte-%d", i),
),
)
}
msg.Data = data
} else {
// Generate random binary data
length := rapid.IntRange(1, 100).Draw(t, "errorDataLength")
msg.Data = rapid.SliceOfN(
rapid.Byte(), length, length,
).Draw(t, "errorData")
}
return msg
}
package lnwire
import (
"crypto/sha256"
"fmt"
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/stretchr/testify/require"
"pgregory.net/rapid"
)
// RandPartialSig generates a random PartialSig using rapid's generators.
func RandPartialSig(t *rapid.T) *PartialSig {
	// Generate random scalar bytes for the partial signature.
	sigBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "sigBytes")
var s btcec.ModNScalar
s.SetByteSlice(sigBytes)
return &PartialSig{
Sig: s,
}
}
// RandPartialSigWithNonce generates a random PartialSigWithNonce using rapid
// generators.
func RandPartialSigWithNonce(t *rapid.T) *PartialSigWithNonce {
sigLen := rapid.IntRange(1, 65).Draw(t, "partialSigLen")
sigBytes := rapid.SliceOfN(
rapid.Byte(), sigLen, sigLen,
).Draw(t, "partialSig")
sigScalar := new(btcec.ModNScalar)
sigScalar.SetByteSlice(sigBytes)
return NewPartialSigWithNonce(
RandMusig2Nonce(t), *sigScalar,
)
}
// RandPubKey generates a random public key using rapid's generators.
func RandPubKey(t *rapid.T) *btcec.PublicKey {
privKeyBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(
t, "privKeyBytes",
)
_, pub := btcec.PrivKeyFromBytes(privKeyBytes)
return pub
}
// RandChannelID generates a random channel ID.
func RandChannelID(t *rapid.T) ChannelID {
var c ChannelID
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "channelID")
copy(c[:], bytes)
return c
}
// RandShortChannelID generates a random short channel ID.
func RandShortChannelID(t *rapid.T) ShortChannelID {
return NewShortChanIDFromInt(
uint64(rapid.IntRange(1, 100000).Draw(t, "shortChanID")),
)
}
// RandFeatureVector generates a random feature vector.
func RandFeatureVector(t *rapid.T) *RawFeatureVector {
featureVec := NewRawFeatureVector()
// Add a random number of random feature bits
numFeatures := rapid.IntRange(0, 20).Draw(t, "numFeatures")
for i := 0; i < numFeatures; i++ {
bit := FeatureBit(rapid.IntRange(0, 100).Draw(
t, fmt.Sprintf("featureBit-%d", i)),
)
featureVec.Set(bit)
}
return featureVec
}
// RandSignature generates a signature for testing.
func RandSignature(t *rapid.T) Sig {
testRScalar := new(btcec.ModNScalar)
testSScalar := new(btcec.ModNScalar)
// Generate random bytes for R and S
rBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "rBytes")
sBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "sBytes")
_ = testRScalar.SetByteSlice(rBytes)
_ = testSScalar.SetByteSlice(sBytes)
testSig := ecdsa.NewSignature(testRScalar, testSScalar)
sig, err := NewSigFromSignature(testSig)
if err != nil {
panic(fmt.Sprintf("unable to create signature: %v", err))
}
return sig
}
// RandPaymentHash generates a random payment hash.
func RandPaymentHash(t *rapid.T) [32]byte {
var hash [32]byte
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "paymentHash")
copy(hash[:], bytes)
return hash
}
// RandPaymentPreimage generates a random payment preimage.
func RandPaymentPreimage(t *rapid.T) [32]byte {
var preimage [32]byte
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "preimage")
copy(preimage[:], bytes)
return preimage
}
// RandChainHash generates a random chain hash.
func RandChainHash(t *rapid.T) chainhash.Hash {
var hash [32]byte
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "chainHash")
copy(hash[:], bytes)
return hash
}
// RandNodeAlias generates a random node alias.
func RandNodeAlias(t *rapid.T) NodeAlias {
var alias NodeAlias
aliasLength := rapid.IntRange(0, 32).Draw(t, "aliasLength")
aliasBytes := rapid.StringN(
0, aliasLength, aliasLength,
).Draw(t, "alias")
copy(alias[:], aliasBytes)
return alias
}
// RandNetAddrs generates random network addresses.
func RandNetAddrs(t *rapid.T) []net.Addr {
numAddresses := rapid.IntRange(0, 5).Draw(t, "numAddresses")
if numAddresses == 0 {
return nil
}
addresses := make([]net.Addr, numAddresses)
for i := 0; i < numAddresses; i++ {
addressType := rapid.IntRange(0, 1).Draw(
t, fmt.Sprintf("addressType-%d", i),
)
switch addressType {
// IPv4.
case 0:
ipBytes := rapid.SliceOfN(rapid.Byte(), 4, 4).Draw(
t, fmt.Sprintf("ipv4-%d", i),
)
port := rapid.IntRange(1, 65535).Draw(
t, fmt.Sprintf("port-%d", i),
)
addresses[i] = &net.TCPAddr{
IP: ipBytes,
Port: port,
}
// IPv6.
case 1:
ipBytes := rapid.SliceOfN(rapid.Byte(), 16, 16).Draw(
t, fmt.Sprintf("ipv6-%d", i),
)
port := rapid.IntRange(1, 65535).Draw(
t, fmt.Sprintf("port-%d", i),
)
addresses[i] = &net.TCPAddr{
IP: ipBytes,
Port: port,
}
}
}
return addresses
}
// RandCustomRecords generates random custom TLV records.
func RandCustomRecords(t *rapid.T,
ignoreRecords fn.Set[uint64],
custom bool) (CustomRecords, fn.Set[uint64]) {
numRecords := rapid.IntRange(0, 5).Draw(t, "numCustomRecords")
customRecords := make(CustomRecords)
if numRecords == 0 {
return nil, nil
}
rangeStart := 0
rangeStop := int(CustomTypeStart)
if custom {
rangeStart = 70_000
rangeStop = 100_000
}
ignoreSet := fn.NewSet[uint64]()
for i := 0; i < numRecords; i++ {
recordType := uint64(
rapid.IntRange(rangeStart, rangeStop).
Filter(func(i int) bool {
return !ignoreRecords.Contains(
uint64(i),
)
}).
Draw(
t, fmt.Sprintf("recordType-%d", i),
),
)
recordLen := rapid.IntRange(4, 64).Draw(
t, fmt.Sprintf("recordLen-%d", i),
)
record := rapid.SliceOfN(
rapid.Byte(), recordLen, recordLen,
).Draw(t, fmt.Sprintf("record-%d", i))
customRecords[recordType] = record
ignoreSet.Add(recordType)
}
return customRecords, ignoreSet
}
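// The returned ignore set lets callers generate a second, independent set of
// records without colliding on record types. A sketch of the typical pattern
// (mirroring UpdateFulfillHTLC.RandTestMessage above):
//
//	cr, ignoreRecords := RandCustomRecords(t, nil, true)
//	extra := RandExtraOpaqueData(t, ignoreRecords)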
// RandMusig2Nonce generates a random musig2 nonce.
func RandMusig2Nonce(t *rapid.T) Musig2Nonce {
var nonce Musig2Nonce
bytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "nonce")
copy(nonce[:], bytes)
return nonce
}
// RandExtraOpaqueData generates random extra opaque data.
func RandExtraOpaqueData(t *rapid.T,
ignoreRecords fn.Set[uint64]) ExtraOpaqueData {
// Make some random records.
cRecords, _ := RandCustomRecords(t, ignoreRecords, false)
if cRecords == nil {
return ExtraOpaqueData{}
}
// Encode those records as opaque data.
recordBytes, err := cRecords.Serialize()
require.NoError(t, err)
return ExtraOpaqueData(recordBytes)
}
// RandOpaqueReason generates a random opaque reason for HTLC failures.
func RandOpaqueReason(t *rapid.T) OpaqueReason {
reasonLen := rapid.IntRange(32, 300).Draw(t, "reasonLen")
return rapid.SliceOfN(rapid.Byte(), reasonLen, reasonLen).Draw(
t, "opaqueReason",
)
}
// RandFailCode generates a random HTLC failure code.
func RandFailCode(t *rapid.T) FailCode {
	// List of known failure codes to choose from, using only the
	// documented codes.
validCodes := []FailCode{
CodeInvalidRealm,
CodeTemporaryNodeFailure,
CodePermanentNodeFailure,
CodeRequiredNodeFeatureMissing,
CodePermanentChannelFailure,
CodeRequiredChannelFeatureMissing,
CodeUnknownNextPeer,
CodeIncorrectOrUnknownPaymentDetails,
CodeIncorrectPaymentAmount,
CodeFinalExpiryTooSoon,
CodeInvalidOnionVersion,
CodeInvalidOnionHmac,
CodeInvalidOnionKey,
CodeTemporaryChannelFailure,
CodeChannelDisabled,
CodeExpiryTooSoon,
CodeMPPTimeout,
CodeInvalidOnionPayload,
CodeFeeInsufficient,
}
// Choose a random code from the list.
idx := rapid.IntRange(0, len(validCodes)-1).Draw(t, "failCodeIndex")
return validCodes[idx]
}
// RandSHA256Hash generates a random SHA256 hash.
func RandSHA256Hash(t *rapid.T) [sha256.Size]byte {
var hash [sha256.Size]byte
bytes := rapid.SliceOfN(rapid.Byte(), sha256.Size, sha256.Size).Draw(
t, "sha256Hash",
)
copy(hash[:], bytes)
return hash
}
// RandDeliveryAddress generates a random delivery address (script).
func RandDeliveryAddress(t *rapid.T) DeliveryAddress {
addrLen := rapid.IntRange(1, 34).Draw(t, "addrLen")
return rapid.SliceOfN(rapid.Byte(), addrLen, addrLen).Draw(
t, "deliveryAddress",
)
}
// RandChannelType generates a random channel type.
func RandChannelType(t *rapid.T) *ChannelType {
vec := RandFeatureVector(t)
chanType := ChannelType(*vec)
return &chanType
}
// RandLeaseExpiry generates a random lease expiry.
func RandLeaseExpiry(t *rapid.T) *LeaseExpiry {
exp := LeaseExpiry(
uint32(rapid.IntRange(1000, 1000000).Draw(t, "leaseExpiry")),
)
return &exp
}
// RandOutPoint generates a random transaction outpoint.
func RandOutPoint(t *rapid.T) wire.OutPoint {
// Generate a random transaction ID
var txid chainhash.Hash
txidBytes := rapid.SliceOfN(rapid.Byte(), 32, 32).Draw(t, "txid")
copy(txid[:], txidBytes)
// Generate a random output index
vout := uint32(rapid.IntRange(0, 10).Draw(t, "vout"))
return wire.OutPoint{
Hash: txid,
Index: vout,
}
}
package lnwire
import (
"bytes"
"fmt"
"io"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// TimestampsRecordType is the TLV number of the timestamps TLV record
// in the reply_channel_range message.
TimestampsRecordType tlv.Type = 1
// timestampPairSize is the number of bytes required to encode two
// timestamps. Each timestamp is four bytes.
timestampPairSize = 8
)
// Timestamps is a type representing the timestamps TLV field used in the
// reply_channel_range message to communicate the timestamp info of the
// channel updates for the SCIDs being communicated.
type Timestamps []ChanUpdateTimestamps
// ChanUpdateTimestamps holds the timestamp info of the latest known channel
// updates corresponding to the two sides of a channel.
type ChanUpdateTimestamps struct {
Timestamp1 uint32
Timestamp2 uint32
}
// Record constructs the tlv.Record from the Timestamps.
func (t *Timestamps) Record() tlv.Record {
return tlv.MakeDynamicRecord(
TimestampsRecordType, t, t.encodedLen, timeStampsEncoder,
timeStampsDecoder,
)
}
// encodedLen calculates the length of the encoded Timestamps.
func (t *Timestamps) encodedLen() uint64 {
return uint64(1 + timestampPairSize*(len(*t)))
}
// timeStampsEncoder encodes the Timestamps and writes the encoded bytes to the
// given writer.
func timeStampsEncoder(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*Timestamps); ok {
var buf bytes.Buffer
// Add the encoding byte.
err := WriteQueryEncoding(&buf, EncodingSortedPlain)
if err != nil {
return err
}
// For each timestamp, write 4 byte timestamp of node 1 and the
// 4 byte timestamp of node 2.
for _, timestamps := range *v {
err = WriteUint32(&buf, timestamps.Timestamp1)
if err != nil {
return err
}
err = WriteUint32(&buf, timestamps.Timestamp2)
if err != nil {
return err
}
}
_, err = w.Write(buf.Bytes())
return err
}
return tlv.NewTypeForEncodingErr(val, "lnwire.Timestamps")
}
// timeStampsDecoder attempts to read and reconstruct a Timestamps object from
// the given reader.
func timeStampsDecoder(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*Timestamps); ok {
var encodingByte [1]byte
if _, err := r.Read(encodingByte[:]); err != nil {
return err
}
encoding := QueryEncoding(encodingByte[0])
if encoding != EncodingSortedPlain {
return fmt.Errorf("unsupported encoding: %x", encoding)
}
		// The number of timestamp bytes is equal to the passed length
// minus one since the first byte is used for the encoding type.
numTimestampBytes := l - 1
if numTimestampBytes%timestampPairSize != 0 {
return fmt.Errorf("whole number of timestamps not " +
"encoded")
}
numTimestamps := int(numTimestampBytes) / timestampPairSize
timestamps := make(Timestamps, numTimestamps)
for i := 0; i < numTimestamps; i++ {
err := ReadElements(
				r, &timestamps[i].Timestamp1,
				&timestamps[i].Timestamp2,
)
if err != nil {
return err
}
}
*v = timestamps
return nil
}
	return tlv.NewTypeForDecodingErr(val, "lnwire.Timestamps", l, l)
}
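// As a rough usage sketch (error handling elided for brevity), a Timestamps
// value round-trips through a TLV stream as follows:
//
//	ts := Timestamps{{Timestamp1: 1, Timestamp2: 2}}
//	stream, _ := tlv.NewStream(ts.Record())
//	var b bytes.Buffer
//	_ = stream.Encode(&b)
//
//	var decoded Timestamps
//	stream, _ = tlv.NewStream(decoded.Record())
//	_ = stream.Decode(&b)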
package lnwire
import (
"github.com/lightningnetwork/lnd/tlv"
)
const (
	// DeliveryAddrType is the TLV record type for delivery addresses within
// the name space of the OpenChannel and AcceptChannel messages.
DeliveryAddrType = 0
// deliveryAddressMaxSize is the maximum expected size in bytes of a
// DeliveryAddress based on the types of scripts we know.
// Following are the known scripts and their sizes in bytes.
// - pay to witness script hash: 34
// - pay to pubkey hash: 25
	// - pay to script hash: 23
// - pay to witness pubkey hash: 22.
deliveryAddressMaxSize = 34
)
// DeliveryAddress is used to communicate the address to which funds from a
// closed channel should be sent. The address can be a p2wsh, p2pkh, p2sh or
// p2wpkh.
type DeliveryAddress []byte
// Record returns a TLV record that can be used to encode the delivery
// address within the ExtraData TLV stream. This was introduced in order to
// allow the OpenChannel/AcceptChannel messages to properly be extended with
// TLV types.
func (d *DeliveryAddress) Record() tlv.Record {
addrBytes := (*[]byte)(d)
return tlv.MakeDynamicRecord(
DeliveryAddrType, addrBytes,
func() uint64 {
return uint64(len(*addrBytes))
},
tlv.EVarBytes, tlv.DVarBytes,
)
}
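// As a rough sketch of how this record is consumed, a delivery address can be
// packed into a message's extra data via ExtraOpaqueData.PackRecords (the
// script bytes below are placeholder values):
//
//	addr := DeliveryAddress{0x00, 0x14}
//	var extra ExtraOpaqueData
//	_ = extra.PackRecords(&addr)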
package lnwire
import (
"io"
"github.com/lightningnetwork/lnd/tlv"
)
const (
	// FeeRecordType is the TLV record type used to encode/decode the Fee
	// record within a TLV stream.
	FeeRecordType tlv.Type = 55555
)
// Fee represents a fee schedule.
type Fee struct {
BaseFee int32
FeeRate int32
}
// Record returns a TLV record that can be used to encode/decode the fee
// type from a given TLV stream.
func (l *Fee) Record() tlv.Record {
return tlv.MakeStaticRecord(
FeeRecordType, l, 8, feeEncoder, feeDecoder, //nolint:gomnd
)
}
// feeEncoder is a custom TLV encoder for the fee record.
func feeEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
v, ok := val.(*Fee)
if !ok {
return tlv.NewTypeForEncodingErr(val, "lnwire.Fee")
}
if err := tlv.EUint32T(w, uint32(v.BaseFee), buf); err != nil {
return err
}
return tlv.EUint32T(w, uint32(v.FeeRate), buf)
}
// feeDecoder is a custom TLV decoder for the fee record.
func feeDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
v, ok := val.(*Fee)
if !ok {
return tlv.NewTypeForDecodingErr(val, "lnwire.Fee", l, 8)
}
var baseFee, feeRate uint32
if err := tlv.DUint32(r, &baseFee, buf, 4); err != nil {
return err
}
if err := tlv.DUint32(r, &feeRate, buf, 4); err != nil {
return err
}
v.FeeRate = int32(feeRate)
v.BaseFee = int32(baseFee)
return nil
}
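// The static record length of 8 bytes above follows from the wire layout:
// 4 bytes for the uint32-encoded BaseFee followed by 4 bytes for the
// uint32-encoded FeeRate. A minimal round-trip sketch (error handling
// elided):
//
//	fee := Fee{BaseFee: 1_000, FeeRate: 500}
//	stream, _ := tlv.NewStream(fee.Record())
//	var b bytes.Buffer
//	_ = stream.Encode(&b)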
package lnwire
import (
"io"
"github.com/lightningnetwork/lnd/tlv"
)
const (
	// LeaseExpiryRecordType is the type of the experimental record used to
// communicate the expiration of a channel lease throughout the channel
// funding process.
//
// TODO: Decide on actual TLV type. Custom records start at 2^16.
LeaseExpiryRecordType tlv.Type = 1 << 16
)
// LeaseExpiry represents the absolute expiration height of a channel lease. All
// outputs that pay directly to the channel initiator are locked until this
// height is reached.
type LeaseExpiry uint32
// Record returns a TLV record that can be used to encode/decode the LeaseExpiry
// type from a given TLV stream.
func (l *LeaseExpiry) Record() tlv.Record {
return tlv.MakeStaticRecord(
		LeaseExpiryRecordType, l, 4, leaseExpiryEncoder,
		leaseExpiryDecoder,
)
}
// leaseExpiryEncoder is a custom TLV encoder for the LeaseExpiry record.
func leaseExpiryEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*LeaseExpiry); ok {
return tlv.EUint32T(w, uint32(*v), buf)
}
return tlv.NewTypeForEncodingErr(val, "lnwire.LeaseExpiry")
}
// leaseExpiryDecoder is a custom TLV decoder for the LeaseExpiry record.
func leaseExpiryDecoder(r io.Reader, val interface{}, buf *[8]byte,
	l uint64) error {
if v, ok := val.(*LeaseExpiry); ok {
var leaseExpiry uint32
if err := tlv.DUint32(r, &leaseExpiry, buf, l); err != nil {
return err
}
*v = LeaseExpiry(leaseExpiry)
return nil
}
	return tlv.NewTypeForDecodingErr(val, "lnwire.LeaseExpiry", l, 4)
}
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// OnionPacketSize is the size of the serialized Sphinx onion packet
// included in each UpdateAddHTLC message. The breakdown of the onion
// packet is as follows: 1-byte version, 33-byte ephemeral public key
// (for ECDH), 1300-bytes of per-hop data, and a 32-byte HMAC over the
// entire packet.
OnionPacketSize = 1366
// ExperimentalEndorsementType is the TLV type used for a custom
// record that sets an experimental endorsement value.
ExperimentalEndorsementType tlv.Type = 106823
	// ExperimentalUnendorsed is the value that the experimental endorsement
	// field contains when an HTLC is not endorsed.
	ExperimentalUnendorsed = 0
	// ExperimentalEndorsed is the value that the experimental endorsement
	// field contains when an HTLC is endorsed. We're using a single byte
	// to represent our endorsement value, but limit the value to using
	// the first three bits (max value = 00000111). Interpreted as a uint8
	// (an alias for byte in Go), we can just define this constant as 7.
	ExperimentalEndorsed = 7
)
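// The OnionPacketSize above follows directly from the breakdown in its
// comment: 1 (version) + 33 (ephemeral pubkey) + 1300 (per-hop data) +
// 32 (HMAC) = 1366 bytes.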
type (
// BlindingPointTlvType is the type for ephemeral pubkeys used in
// route blinding.
BlindingPointTlvType = tlv.TlvType0
// BlindingPointRecord holds an optional blinding point on update add
// htlc.
//nolint:ll
BlindingPointRecord = tlv.OptionalRecordT[BlindingPointTlvType, *btcec.PublicKey]
)
// UpdateAddHTLC is the message sent by Alice to Bob when she wishes to add an
// HTLC to his remote commitment transaction. In addition to information
// detailing the value, the ID, and the expiry, the onion blob is also
// included, which allows Bob to derive the next hop in the route. The HTLC
// added by this message is to be added to the remote node's "pending" HTLCs.
// A subsequent CommitSig message will move the pending HTLC to the newly
// created commitment transaction, marking it as "staged".
type UpdateAddHTLC struct {
// ChanID is the particular active channel that this UpdateAddHTLC is
// bound to.
ChanID ChannelID
	// ID is the identification number for this HTLC. This value is
	// explicitly included as it allows nodes to survive single-sided
	// restarts. The ID value for this side starts at zero, and increases
// with each offered HTLC.
ID uint64
// Amount is the amount of millisatoshis this HTLC is worth.
Amount MilliSatoshi
// PaymentHash is the payment hash to be included in the HTLC this
// request creates. The pre-image to this HTLC must be revealed by the
// upstream peer in order to fully settle the HTLC.
PaymentHash [32]byte
// Expiry is the number of blocks after which this HTLC should expire.
// It is the receiver's duty to ensure that the outgoing HTLC has a
// sufficient expiry value to allow her to redeem the incoming HTLC.
Expiry uint32
// OnionBlob is the raw serialized mix header used to route an HTLC in
// a privacy-preserving manner. The mix header is defined currently to
// be parsed as a 4-tuple: (groupElement, routingInfo, headerMAC,
// body). First the receiving node should use the groupElement, and
// its current onion key to derive a shared secret with the source.
// Once the shared secret has been derived, the headerMAC should be
// checked FIRST. Note that the MAC only covers the routingInfo field.
// If the MAC matches, and the shared secret is fresh, then the node
// should strip off a layer of encryption, exposing the next hop to be
// used in the subsequent UpdateAddHTLC message.
OnionBlob [OnionPacketSize]byte
// BlindingPoint is the ephemeral pubkey used to optionally blind the
// next hop for this htlc.
BlindingPoint BlindingPointRecord
// CustomRecords maps TLV types to byte slices, storing arbitrary data
// intended for inclusion in the ExtraData field of the UpdateAddHTLC
// message.
CustomRecords CustomRecords
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewUpdateAddHTLC returns a new empty UpdateAddHTLC message.
func NewUpdateAddHTLC() *UpdateAddHTLC {
return &UpdateAddHTLC{}
}
// A compile time check to ensure UpdateAddHTLC implements the lnwire.Message
// interface.
var _ Message = (*UpdateAddHTLC)(nil)
// Decode deserializes a serialized UpdateAddHTLC message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
// msgExtraData is a temporary variable used to read the message extra
// data field from the reader.
var msgExtraData ExtraOpaqueData
if err := ReadElements(r,
&c.ChanID,
&c.ID,
&c.Amount,
c.PaymentHash[:],
&c.Expiry,
c.OnionBlob[:],
&msgExtraData,
); err != nil {
return err
}
// Extract TLV records from the extra data field.
blindingRecord := c.BlindingPoint.Zero()
customRecords, parsed, extraData, err := ParseAndExtractCustomRecords(
msgExtraData, &blindingRecord,
)
if err != nil {
return err
}
// Assign the parsed records back to the message.
if parsed.Contains(blindingRecord.TlvType()) {
c.BlindingPoint = tlv.SomeRecordT(blindingRecord)
}
c.CustomRecords = customRecords
c.ExtraData = extraData
return nil
}
// Encode serializes the target UpdateAddHTLC into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint64(w, c.ID); err != nil {
return err
}
if err := WriteMilliSatoshi(w, c.Amount); err != nil {
return err
}
if err := WriteBytes(w, c.PaymentHash[:]); err != nil {
return err
}
if err := WriteUint32(w, c.Expiry); err != nil {
return err
}
if err := WriteBytes(w, c.OnionBlob[:]); err != nil {
return err
}
// Only include blinding point in extra data if present.
var records []tlv.RecordProducer
c.BlindingPoint.WhenSome(
func(b tlv.RecordT[BlindingPointTlvType, *btcec.PublicKey]) {
records = append(records, &b)
},
)
extraData, err := MergeAndEncode(records, c.ExtraData, c.CustomRecords)
if err != nil {
return err
}
return WriteBytes(w, extraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateAddHTLC) MsgType() MessageType {
return MsgUpdateAddHTLC
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateAddHTLC) TargetChanID() ChannelID {
return c.ChanID
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *UpdateAddHTLC) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// A compile time check to ensure UpdateAddHTLC implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*UpdateAddHTLC)(nil)
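// A minimal wire round-trip sketch for UpdateAddHTLC (values and error
// handling are illustrative only):
//
//	msg := NewUpdateAddHTLC()
//	msg.Amount = 1_000
//	var b bytes.Buffer
//	_ = msg.Encode(&b, 0)
//
//	var decoded UpdateAddHTLC
//	_ = decoded.Decode(&b, 0)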
package lnwire
import (
"bytes"
"io"
)
// OpaqueReason is an opaque encrypted byte slice that encodes the exact
// failure reason along with some supplemental data. The contents of this
// slice can only be decrypted by the sender of the original HTLC.
type OpaqueReason []byte
// UpdateFailHTLC is sent by Alice to Bob in order to remove a previously added
// HTLC. Upon receipt of an UpdateFailHTLC the HTLC should be removed from the
// next commitment transaction, with the UpdateFailHTLC propagated backwards in
// the route to fully undo the HTLC.
type UpdateFailHTLC struct {
// ChanID is the particular active channel that this
// UpdateFailHTLC is bound to.
ChanID ChannelID
// ID references which HTLC on the remote node's commitment transaction
// has timed out.
ID uint64
// Reason is an onion-encrypted blob that details why the HTLC was
// failed. This blob is only fully decryptable by the initiator of the
// HTLC message.
Reason OpaqueReason
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure UpdateFailHTLC implements the lnwire.Message
// interface.
var _ Message = (*UpdateFailHTLC)(nil)
// A compile time check to ensure UpdateFailHTLC implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*UpdateFailHTLC)(nil)
// Decode deserializes a serialized UpdateFailHTLC message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.ID,
&c.Reason,
&c.ExtraData,
)
}
// Encode serializes the target UpdateFailHTLC into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint64(w, c.ID); err != nil {
return err
}
if err := WriteOpaqueReason(w, c.Reason); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailHTLC) MsgType() MessageType {
return MsgUpdateFailHTLC
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *UpdateFailHTLC) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateFailHTLC) TargetChanID() ChannelID {
return c.ChanID
}
package lnwire
import (
"bytes"
"crypto/sha256"
"io"
)
// UpdateFailMalformedHTLC is sent by either the payment forwarder or the
// payment receiver to the payment sender in order to notify it that the
// onion blob can't be parsed. For that reason this message is sent instead
// of an obfuscated onion failure.
type UpdateFailMalformedHTLC struct {
// ChanID is the particular active channel that this
// UpdateFailMalformedHTLC is bound to.
ChanID ChannelID
// ID references which HTLC on the remote node's commitment transaction
// has timed out.
ID uint64
	// ShaOnionBlob is the SHA256 hash of the onion blob that can't be
	// parsed by the node in the payment path.
	ShaOnionBlob [sha256.Size]byte
	// FailureCode is the exact reason why the onion blob couldn't be
	// parsed.
	FailureCode FailCode
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure UpdateFailMalformedHTLC implements the
// lnwire.Message interface.
var _ Message = (*UpdateFailMalformedHTLC)(nil)
// A compile time check to ensure UpdateFailMalformedHTLC implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*UpdateFailMalformedHTLC)(nil)
// Decode deserializes a serialized UpdateFailMalformedHTLC message stored in
// the passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.ID,
c.ShaOnionBlob[:],
&c.FailureCode,
&c.ExtraData,
)
}
// Encode serializes the target UpdateFailMalformedHTLC into the passed
// io.Writer observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer,
pver uint32) error {
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint64(w, c.ID); err != nil {
return err
}
if err := WriteBytes(w, c.ShaOnionBlob[:]); err != nil {
return err
}
if err := WriteFailCode(w, c.FailureCode); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFailMalformedHTLC) MsgType() MessageType {
return MsgUpdateFailMalformedHTLC
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *UpdateFailMalformedHTLC) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateFailMalformedHTLC) TargetChanID() ChannelID {
return c.ChanID
}
package lnwire
import (
"bytes"
"io"
)
// UpdateFee is the message the channel initiator sends to the other peer if
// the channel commitment fee needs to be updated.
type UpdateFee struct {
// ChanID is the channel that this UpdateFee is meant for.
ChanID ChannelID
// FeePerKw is the fee-per-kw on commit transactions that the sender of
// this message wants to use for this channel.
//
// TODO(halseth): make SatPerKWeight when fee estimation is moved to
// own package. Currently this will cause an import cycle.
FeePerKw uint32
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewUpdateFee creates a new UpdateFee message.
func NewUpdateFee(chanID ChannelID, feePerKw uint32) *UpdateFee {
return &UpdateFee{
ChanID: chanID,
FeePerKw: feePerKw,
}
}
// A compile time check to ensure UpdateFee implements the lnwire.Message
// interface.
var _ Message = (*UpdateFee)(nil)
// A compile time check to ensure UpdateFee implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*UpdateFee)(nil)
// Decode deserializes a serialized UpdateFee message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFee) Decode(r io.Reader, pver uint32) error {
return ReadElements(r,
&c.ChanID,
&c.FeePerKw,
&c.ExtraData,
)
}
// Encode serializes the target UpdateFee into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFee) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint32(w, c.FeePerKw); err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFee) MsgType() MessageType {
return MsgUpdateFee
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *UpdateFee) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateFee) TargetChanID() ChannelID {
return c.ChanID
}
package lnwire
import (
"bytes"
"io"
)
// UpdateFulfillHTLC is sent by Alice to Bob when she wishes to settle a
// particular HTLC referenced by its HTLCKey within a specific active channel
// referenced by ChannelPoint. A subsequent CommitSig message will be sent by
// Alice to "lock-in" the removal of the specified HTLC, possible containing a
// batch signature covering several settled HTLC's.
type UpdateFulfillHTLC struct {
// ChanID references an active channel which holds the HTLC to be
// settled.
ChanID ChannelID
// ID denotes the exact HTLC stage within the receiving node's
// commitment transaction to be removed.
ID uint64
// PaymentPreimage is the R-value preimage required to fully settle an
// HTLC.
PaymentPreimage [32]byte
// CustomRecords maps TLV types to byte slices, storing arbitrary data
// intended for inclusion in the ExtraData field.
CustomRecords CustomRecords
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// NewUpdateFulfillHTLC returns a new empty UpdateFulfillHTLC.
func NewUpdateFulfillHTLC(chanID ChannelID, id uint64,
preimage [32]byte) *UpdateFulfillHTLC {
return &UpdateFulfillHTLC{
ChanID: chanID,
ID: id,
PaymentPreimage: preimage,
}
}
// A compile time check to ensure UpdateFulfillHTLC implements the
// lnwire.Message interface.
var _ Message = (*UpdateFulfillHTLC)(nil)
// A compile time check to ensure UpdateFulfillHTLC implements the
// lnwire.SizeableMessage interface.
var _ SizeableMessage = (*UpdateFulfillHTLC)(nil)
// Decode deserializes a serialized UpdateFulfillHTLC message stored in the
// passed io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error {
// msgExtraData is a temporary variable used to read the message extra
// data field from the reader.
var msgExtraData ExtraOpaqueData
if err := ReadElements(r,
&c.ChanID,
&c.ID,
c.PaymentPreimage[:],
&msgExtraData,
); err != nil {
return err
}
// Extract custom records from the extra data field.
customRecords, _, extraData, err := ParseAndExtractCustomRecords(
msgExtraData,
)
if err != nil {
return err
}
c.CustomRecords = customRecords
c.ExtraData = extraData
return nil
}
// Encode serializes the target UpdateFulfillHTLC into the passed io.Writer
// observing the protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) Encode(w *bytes.Buffer, pver uint32) error {
if err := WriteChannelID(w, c.ChanID); err != nil {
return err
}
if err := WriteUint64(w, c.ID); err != nil {
return err
}
if err := WriteBytes(w, c.PaymentPreimage[:]); err != nil {
return err
}
// Combine the custom records and the extra data, then encode the
// result as a byte slice.
extraData, err := MergeAndEncode(nil, c.ExtraData, c.CustomRecords)
if err != nil {
return err
}
return WriteBytes(w, extraData)
}
// MsgType returns the integer uniquely identifying this message type on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *UpdateFulfillHTLC) MsgType() MessageType {
return MsgUpdateFulfillHTLC
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *UpdateFulfillHTLC) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
// TargetChanID returns the channel id of the link for which this message is
// intended.
//
// NOTE: Part of peer.LinkUpdater interface.
func (c *UpdateFulfillHTLC) TargetChanID() ChannelID {
return c.ChanID
}
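// updateFulfillRoundTrip is a minimal illustrative sketch (not part of the
// original file) showing how an UpdateFulfillHTLC round-trips through
// Encode and Decode. The channel ID, HTLC ID, and preimage below are
// made-up values; only this file's existing imports are needed.
func updateFulfillRoundTrip() error {
	var preimage [32]byte
	preimage[0] = 0x01

	msg := NewUpdateFulfillHTLC(ChannelID{}, 7, preimage)

	// Serialize the message as it would appear on the wire. The pver
	// argument is unused by Encode/Decode, so 0 suffices here.
	var buf bytes.Buffer
	if err := msg.Encode(&buf, 0); err != nil {
		return err
	}

	// Deserialize into a fresh message; the decoded HTLC ID and preimage
	// should match the originals.
	decoded := &UpdateFulfillHTLC{}
	return decoded.Decode(&buf, 0)
}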
package lnwire
import (
"bytes"
"fmt"
"io"
)
// WarningData is a set of bytes associated with a particular sent warning. A
// receiving node SHOULD only print out data verbatim if the string is composed
// solely of printable ASCII characters. For reference, the printable character
// set includes byte values 32 through 126, inclusive.
type WarningData []byte
// Warning is used to express non-critical errors in the protocol, providing
// a "soft" way for nodes to communicate failures.
type Warning struct {
	// ChanID references the active channel in which the warning occurred.
	// If the ChanID is all zeros, then this warning applies to the entire
	// established connection.
ChanID ChannelID
// Data is the attached warning data that describes the exact failure
// which caused the warning message to be sent.
Data WarningData
}
// A compile time check to ensure Warning implements the lnwire.Message
// interface.
var _ Message = (*Warning)(nil)
// A compile time check to ensure Warning implements the lnwire.SizeableMessage
// interface.
var _ SizeableMessage = (*Warning)(nil)
// NewWarning creates a new Warning message.
func NewWarning() *Warning {
return &Warning{}
}
// Warning returns the string representation of the Warning.
func (c *Warning) Warning() string {
errMsg := "non-ascii data"
if isASCII(c.Data) {
errMsg = string(c.Data)
}
return fmt.Sprintf("chan_id=%v, err=%v", c.ChanID, errMsg)
}
// Decode deserializes a serialized Warning message stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *Warning) Decode(r io.Reader, _ uint32) error {
return ReadElements(r,
&c.ChanID,
&c.Data,
)
}
// Encode serializes the target Warning into the passed io.Writer observing the
// protocol version specified.
//
// This is part of the lnwire.Message interface.
func (c *Warning) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteBytes(w, c.ChanID[:]); err != nil {
return err
}
return WriteWarningData(w, c.Data)
}
// MsgType returns the integer uniquely identifying a Warning message on the
// wire.
//
// This is part of the lnwire.Message interface.
func (c *Warning) MsgType() MessageType {
return MsgWarning
}
// SerializedSize returns the serialized size of the message in bytes.
//
// This is part of the lnwire.SizeableMessage interface.
func (c *Warning) SerializedSize() (uint32, error) {
return MessageSerializedSize(c)
}
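// warningSketch is a minimal illustrative sketch (not part of the original
// file): it builds a Warning for the whole connection and renders it. The
// data string is a made-up example; Warning() prints it verbatim only
// because it is printable ASCII.
func warningSketch() string {
	w := NewWarning()

	// A zero ChanID means the warning applies to the entire connection.
	w.Data = WarningData("channel fee policy outdated")

	return w.Warning()
}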
package lnwire
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"image/color"
"math"
"net"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/tor"
)
var (
// ErrNilFeatureVector is returned when the supplied feature is nil.
ErrNilFeatureVector = errors.New("cannot write nil feature vector")
// ErrPkScriptTooLong is returned when the length of the provided
	// script exceeds 34 bytes.
ErrPkScriptTooLong = errors.New("'PkScript' too long")
// ErrNilTCPAddress is returned when the supplied address is nil.
ErrNilTCPAddress = errors.New("cannot write nil TCPAddr")
// ErrNilOnionAddress is returned when the supplied address is nil.
ErrNilOnionAddress = errors.New("cannot write nil onion address")
// ErrNilNetAddress is returned when a nil value is used in []net.Addr.
ErrNilNetAddress = errors.New("cannot write nil address")
// ErrNilOpaqueAddrs is returned when the supplied address is nil.
ErrNilOpaqueAddrs = errors.New("cannot write nil OpaqueAddrs")
// ErrNilPublicKey is returned when a nil pubkey is used.
ErrNilPublicKey = errors.New("cannot write nil pubkey")
// ErrUnknownServiceLength is returned when the onion service length is
// unknown.
ErrUnknownServiceLength = errors.New("unknown onion service length")
)
// ErrOutpointIndexTooBig is used when the outpoint index exceeds the max value
// of uint16.
func ErrOutpointIndexTooBig(index uint32) error {
return fmt.Errorf(
"index for outpoint (%v) is greater than "+
"max index of %v", index, math.MaxUint16,
)
}
// WriteBytes appends the given bytes to the provided buffer.
func WriteBytes(buf *bytes.Buffer, b []byte) error {
_, err := buf.Write(b)
return err
}
// WriteUint8 appends the uint8 to the provided buffer.
func WriteUint8(buf *bytes.Buffer, n uint8) error {
_, err := buf.Write([]byte{n})
return err
}
// WriteUint16 appends the uint16 to the provided buffer. It encodes the
// integer using big endian byte order.
func WriteUint16(buf *bytes.Buffer, n uint16) error {
var b [2]byte
binary.BigEndian.PutUint16(b[:], n)
_, err := buf.Write(b[:])
return err
}
// WriteUint32 appends the uint32 to the provided buffer. It encodes the
// integer using big endian byte order.
func WriteUint32(buf *bytes.Buffer, n uint32) error {
var b [4]byte
binary.BigEndian.PutUint32(b[:], n)
_, err := buf.Write(b[:])
return err
}
// WriteUint64 appends the uint64 to the provided buffer. It encodes the
// integer using big endian byte order.
func WriteUint64(buf *bytes.Buffer, n uint64) error {
var b [8]byte
binary.BigEndian.PutUint64(b[:], n)
_, err := buf.Write(b[:])
return err
}
// WriteSatoshi appends the Satoshi value to the provided buffer.
func WriteSatoshi(buf *bytes.Buffer, amount btcutil.Amount) error {
return WriteUint64(buf, uint64(amount))
}
// WriteMilliSatoshi appends the MilliSatoshi value to the provided buffer.
func WriteMilliSatoshi(buf *bytes.Buffer, amount MilliSatoshi) error {
return WriteUint64(buf, uint64(amount))
}
// WritePublicKey appends the compressed public key to the provided buffer.
func WritePublicKey(buf *bytes.Buffer, pub *btcec.PublicKey) error {
if pub == nil {
return ErrNilPublicKey
}
serializedPubkey := pub.SerializeCompressed()
return WriteBytes(buf, serializedPubkey)
}
// WriteChannelID appends the ChannelID to the provided buffer.
func WriteChannelID(buf *bytes.Buffer, channelID ChannelID) error {
return WriteBytes(buf, channelID[:])
}
// WriteNodeAlias appends the alias to the provided buffer.
func WriteNodeAlias(buf *bytes.Buffer, alias NodeAlias) error {
return WriteBytes(buf, alias[:])
}
// WriteShortChannelID appends the ShortChannelID to the provided buffer. It
// encodes the BlockHeight and TxIndex each using 3 bytes with big endian byte
// order, and encodes txPosition using 2 bytes with big endian byte order.
func WriteShortChannelID(buf *bytes.Buffer, shortChanID ShortChannelID) error {
	// Check that the field fits in 3 bytes and write the BlockHeight.
if shortChanID.BlockHeight > ((1 << 24) - 1) {
return errors.New("block height should fit in 3 bytes")
}
var blockHeight [4]byte
binary.BigEndian.PutUint32(blockHeight[:], shortChanID.BlockHeight)
if _, err := buf.Write(blockHeight[1:]); err != nil {
return err
}
	// Check that the field fits in 3 bytes and write the TxIndex.
if shortChanID.TxIndex > ((1 << 24) - 1) {
return errors.New("tx index should fit in 3 bytes")
}
var txIndex [4]byte
binary.BigEndian.PutUint32(txIndex[:], shortChanID.TxIndex)
if _, err := buf.Write(txIndex[1:]); err != nil {
return err
}
// Write the TxPosition
return WriteUint16(buf, shortChanID.TxPosition)
}
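// shortChannelIDLayoutSketch is an illustrative sketch (not part of the
// original file) of the resulting 8-byte layout: 3 bytes of block height,
// 3 bytes of tx index, and 2 bytes of tx position, all big endian. The
// SCID values are made up.
func shortChannelIDLayoutSketch() ([]byte, error) {
	var buf bytes.Buffer
	err := WriteShortChannelID(&buf, ShortChannelID{
		BlockHeight: 700_000, // Fits in 3 bytes (max 16,777,215).
		TxIndex:     1_500,
		TxPosition:  2,
	})
	if err != nil {
		return nil, err
	}

	// buf now holds exactly 8 bytes: 0x0a 0xae 0x60 | 0x00 0x05 0xdc |
	// 0x00 0x02.
	return buf.Bytes(), nil
}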
// WriteSig appends the signature to the provided buffer.
func WriteSig(buf *bytes.Buffer, sig Sig) error {
return WriteBytes(buf, sig.bytes[:])
}
// WriteSigs appends the slice of signatures to the provided buffer with its
// length.
func WriteSigs(buf *bytes.Buffer, sigs []Sig) error {
// Write the length of the sigs.
if err := WriteUint16(buf, uint16(len(sigs))); err != nil {
return err
}
for _, sig := range sigs {
if err := WriteSig(buf, sig); err != nil {
return err
}
}
return nil
}
// WriteFailCode appends the FailCode to the provided buffer.
func WriteFailCode(buf *bytes.Buffer, e FailCode) error {
return WriteUint16(buf, uint16(e))
}
// WriteRawFeatureVector encodes the feature using the feature's Encode method
// and appends the data to the provided buffer. An error is returned if the
// passed feature is nil.
func WriteRawFeatureVector(buf *bytes.Buffer, feature *RawFeatureVector) error {
if feature == nil {
return ErrNilFeatureVector
}
return feature.Encode(buf)
}
// WriteColorRGBA appends the RGBA color using three bytes.
func WriteColorRGBA(buf *bytes.Buffer, e color.RGBA) error {
// Write R
if err := WriteUint8(buf, e.R); err != nil {
return err
}
// Write G
if err := WriteUint8(buf, e.G); err != nil {
return err
}
// Write B
return WriteUint8(buf, e.B)
}
// WriteQueryEncoding appends the QueryEncoding to the provided buffer.
func WriteQueryEncoding(buf *bytes.Buffer, e QueryEncoding) error {
return WriteUint8(buf, uint8(e))
}
// WriteFundingFlag appends the FundingFlag to the provided buffer.
func WriteFundingFlag(buf *bytes.Buffer, flag FundingFlag) error {
return WriteUint8(buf, uint8(flag))
}
// WriteChanUpdateMsgFlags appends the update flag to the provided buffer.
func WriteChanUpdateMsgFlags(buf *bytes.Buffer, f ChanUpdateMsgFlags) error {
return WriteUint8(buf, uint8(f))
}
// WriteChanUpdateChanFlags appends the update flag to the provided buffer.
func WriteChanUpdateChanFlags(buf *bytes.Buffer, f ChanUpdateChanFlags) error {
return WriteUint8(buf, uint8(f))
}
// WriteDeliveryAddress appends the address to the provided buffer.
func WriteDeliveryAddress(buf *bytes.Buffer, addr DeliveryAddress) error {
return writeDataWithLength(buf, addr)
}
// WritePingPayload appends the payload to the provided buffer.
func WritePingPayload(buf *bytes.Buffer, payload PingPayload) error {
return writeDataWithLength(buf, payload)
}
// WritePongPayload appends the payload to the provided buffer.
func WritePongPayload(buf *bytes.Buffer, payload PongPayload) error {
return writeDataWithLength(buf, payload)
}
// WriteWarningData appends the data to the provided buffer.
func WriteWarningData(buf *bytes.Buffer, data WarningData) error {
return writeDataWithLength(buf, data)
}
// WriteErrorData appends the data to the provided buffer.
func WriteErrorData(buf *bytes.Buffer, data ErrorData) error {
return writeDataWithLength(buf, data)
}
// WriteOpaqueReason appends the reason to the provided buffer.
func WriteOpaqueReason(buf *bytes.Buffer, reason OpaqueReason) error {
return writeDataWithLength(buf, reason)
}
// WriteBool appends the boolean to the provided buffer.
func WriteBool(buf *bytes.Buffer, b bool) error {
if b {
return WriteBytes(buf, []byte{1})
}
return WriteBytes(buf, []byte{0})
}
// WritePkScript appends the script to the provided buffer. Returns an error if
// the provided script exceeds 34 bytes.
func WritePkScript(buf *bytes.Buffer, s PkScript) error {
// The largest script we'll accept is a p2wsh which is exactly
// 34 bytes long.
scriptLength := len(s)
if scriptLength > 34 {
return ErrPkScriptTooLong
}
return wire.WriteVarBytes(buf, 0, s)
}
// WriteOutPoint appends the outpoint to the provided buffer.
func WriteOutPoint(buf *bytes.Buffer, p wire.OutPoint) error {
// Before we write anything to the buffer, check the Index is sane.
if p.Index > math.MaxUint16 {
return ErrOutpointIndexTooBig(p.Index)
}
var h [32]byte
copy(h[:], p.Hash[:])
if _, err := buf.Write(h[:]); err != nil {
return err
}
// Write the index using two bytes.
return WriteUint16(buf, uint16(p.Index))
}
// WriteTCPAddr appends the TCP address to the provided buffer, either an IPv4
// or an IPv6 address.
func WriteTCPAddr(buf *bytes.Buffer, addr *net.TCPAddr) error {
if addr == nil {
return ErrNilTCPAddress
}
	// Make a slice of bytes to hold the descriptor and IP data. At most,
	// we need 17 bytes: 1 byte for the descriptor and 16 bytes for an
	// IPv6 address.
data := make([]byte, 0, 17)
if addr.IP.To4() != nil {
data = append(data, uint8(tcp4Addr))
data = append(data, addr.IP.To4()...)
} else {
data = append(data, uint8(tcp6Addr))
data = append(data, addr.IP.To16()...)
}
if _, err := buf.Write(data); err != nil {
return err
}
return WriteUint16(buf, uint16(addr.Port))
}
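// tcpAddrEncodingSketch is an illustrative sketch (not part of the
// original file): it encodes a hypothetical IPv4 listener. The result is
// 7 bytes: a 1-byte address descriptor, the 4-byte IP, and the 2-byte
// big-endian port.
func tcpAddrEncodingSketch() ([]byte, error) {
	var buf bytes.Buffer
	addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 9735}
	if err := WriteTCPAddr(&buf, addr); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}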
// WriteOnionAddr appends the onion address to the provided buffer.
func WriteOnionAddr(buf *bytes.Buffer, addr *tor.OnionAddr) error {
if addr == nil {
return ErrNilOnionAddress
}
var (
suffixIndex int
descriptor []byte
)
// Decide the suffixIndex and descriptor.
switch len(addr.OnionService) {
case tor.V2Len:
descriptor = []byte{byte(v2OnionAddr)}
suffixIndex = tor.V2Len - tor.OnionSuffixLen
case tor.V3Len:
descriptor = []byte{byte(v3OnionAddr)}
suffixIndex = tor.V3Len - tor.OnionSuffixLen
default:
return ErrUnknownServiceLength
}
	// Decode the base32-encoded onion service host.
host, err := tor.Base32Encoding.DecodeString(
addr.OnionService[:suffixIndex],
)
if err != nil {
return err
}
	// Perform the actual write once the above checks have passed.
if _, err := buf.Write(descriptor); err != nil {
return err
}
if _, err := buf.Write(host); err != nil {
return err
}
return WriteUint16(buf, uint16(addr.Port))
}
// WriteOpaqueAddrs appends the payload of the given OpaqueAddrs to buffer.
func WriteOpaqueAddrs(buf *bytes.Buffer, addr *OpaqueAddrs) error {
if addr == nil {
return ErrNilOpaqueAddrs
}
_, err := buf.Write(addr.Payload)
return err
}
// WriteNetAddrs appends a slice of addresses to the provided buffer with the
// length info.
func WriteNetAddrs(buf *bytes.Buffer, addresses []net.Addr) error {
// First, we'll encode all the addresses into an intermediate
// buffer. We need to do this in order to compute the total
// length of the addresses.
buffer := make([]byte, 0, MaxMsgBody)
addrBuf := bytes.NewBuffer(buffer)
for _, address := range addresses {
switch a := address.(type) {
case *net.TCPAddr:
if err := WriteTCPAddr(addrBuf, a); err != nil {
return err
}
case *tor.OnionAddr:
if err := WriteOnionAddr(addrBuf, a); err != nil {
return err
}
case *OpaqueAddrs:
if err := WriteOpaqueAddrs(addrBuf, a); err != nil {
return err
}
default:
return ErrNilNetAddress
}
}
// With the addresses fully encoded, we can now write out data.
return writeDataWithLength(buf, addrBuf.Bytes())
}
// writeDataWithLength writes the data and its length to the buffer.
func writeDataWithLength(buf *bytes.Buffer, data []byte) error {
var l [2]byte
binary.BigEndian.PutUint16(l[:], uint16(len(data)))
if _, err := buf.Write(l[:]); err != nil {
return err
}
_, err := buf.Write(data)
return err
}
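// lengthPrefixSketch is an illustrative sketch (not part of the original
// file) of the encoding shared by the wrappers above: a 2-byte big-endian
// length followed by the raw payload. The payload here is a made-up value.
func lengthPrefixSketch() ([]byte, error) {
	var buf bytes.Buffer
	if err := WriteWarningData(&buf, WarningData("hi")); err != nil {
		return nil, err
	}

	// buf now holds 0x00 0x02 'h' 'i'.
	return buf.Bytes(), nil
}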
package multimutex
import (
"fmt"
"sync"
)
// cntMutex is a struct that wraps a counter and a mutex, and is used to keep
// track of the number of goroutines waiting for access to the
// mutex, such that we can forget about it when the counter is zero.
type cntMutex struct {
cnt int
sync.Mutex
}
// Mutex is a struct that keeps track of a set of mutexes with a given ID. It
// can be used to make sure only one goroutine is given the mutex per ID.
type Mutex[T comparable] struct {
// mutexes is a map of IDs to a cntMutex. The cntMutex for a given ID
// will hold the mutex to be used by all callers requesting access for
// the ID, in addition to the count of callers.
mutexes map[T]*cntMutex
	// mapMtx is used to synchronize concurrent access to the mutexes
// map.
mapMtx sync.Mutex
}
// NewMutex creates a new Mutex.
func NewMutex[T comparable]() *Mutex[T] {
return &Mutex[T]{
mutexes: make(map[T]*cntMutex),
}
}
// Lock locks the mutex by the given ID. If the mutex is already locked by this
// ID, Lock blocks until the mutex is available.
func (c *Mutex[T]) Lock(id T) {
c.mapMtx.Lock()
mtx, ok := c.mutexes[id]
if ok {
// If the mutex already existed in the map, we increment its
// counter, to indicate that there now is one more goroutine
// waiting for it.
mtx.cnt++
} else {
// If it was not in the map, it means no other goroutine has
// locked the mutex for this ID, and we can create a new mutex
// with count 1 and add it to the map.
mtx = &cntMutex{
cnt: 1,
}
c.mutexes[id] = mtx
}
c.mapMtx.Unlock()
// Acquire the mutex for this ID.
mtx.Lock()
}
// Unlock unlocks the mutex by the given ID. It is a run-time error if the
// mutex is not locked by the ID on entry to Unlock.
func (c *Mutex[T]) Unlock(id T) {
// Since we are done with all the work for this update, we update the
// map to reflect that.
c.mapMtx.Lock()
mtx, ok := c.mutexes[id]
if !ok {
// The mutex not existing in the map means an unlock for an ID
// not currently locked was attempted.
panic(fmt.Sprintf("double unlock for id %v",
id))
}
// Decrement the counter. If the count goes to zero, it means this
// caller was the last one to wait for the mutex, and we can delete it
// from the map. We can do this safely since we are under the mapMtx,
// meaning that all other goroutines waiting for the mutex already have
// incremented it, or will create a new mutex when they get the mapMtx.
mtx.cnt--
if mtx.cnt == 0 {
delete(c.mutexes, id)
}
c.mapMtx.Unlock()
// Unlock the mutex for this ID.
mtx.Unlock()
}
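// exampleUsage is an illustrative sketch (not part of the original file):
// goroutines contending on the same ID serialize with each other, while
// goroutines using a different ID proceed independently. The IDs are
// made-up values.
func exampleUsage() {
	mtx := NewMutex[string]()

	var wg sync.WaitGroup
	for _, id := range []string{"chan-a", "chan-a", "chan-b"} {
		wg.Add(1)
		go func(id string) {
			defer wg.Done()

			// Only one goroutine holds the lock for "chan-a" at
			// a time; the "chan-b" holder is unaffected.
			mtx.Lock(id)
			defer mtx.Unlock(id)
		}(id)
	}
	wg.Wait()
}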
package netann
import (
"errors"
"sync"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrChanStatusManagerExiting signals that a shutdown of the
// ChanStatusManager has already been requested.
ErrChanStatusManagerExiting = errors.New("chan status manager exiting")
// ErrInvalidTimeoutConstraints signals that the ChanStatusManager could
// not be initialized because the timeouts and sample intervals were
// malformed.
ErrInvalidTimeoutConstraints = errors.New("chan-enable-timeout + " +
"chan-status-sample-interval must <= chan-disable-timeout " +
"and all three chan configs must be positive integers")
// ErrEnableInactiveChan signals that a request to enable a channel
// could not be completed because the channel isn't actually active at
// the time of the request.
ErrEnableInactiveChan = errors.New("unable to enable channel which " +
"is not currently active")
// ErrEnableManuallyDisabledChan signals that an automatic / background
// request to enable a channel could not be completed because the channel
// was manually disabled.
ErrEnableManuallyDisabledChan = errors.New("unable to enable channel " +
"which was manually disabled")
)
// ChanStatusConfig holds parameters and resources required by the
// ChanStatusManager to perform its duty.
type ChanStatusConfig struct {
// OurPubKey is the public key identifying this node on the network.
OurPubKey *btcec.PublicKey
// OurKeyLoc is the locator for the public key identifying this node on
// the network.
OurKeyLoc keychain.KeyLocator
// MessageSigner signs messages that validate under OurPubKey.
MessageSigner lnwallet.MessageSigner
// IsChannelActive checks whether the channel identified by the provided
// ChannelID is considered active. This should only return true if the
// channel has been sufficiently confirmed, the channel has received
// ChannelReady, and the remote peer is online.
IsChannelActive func(lnwire.ChannelID) bool
// ApplyChannelUpdate processes new ChannelUpdates signed by our node by
// updating our local routing table and broadcasting the update to our
// peers.
ApplyChannelUpdate func(*lnwire.ChannelUpdate1, *wire.OutPoint,
bool) error
// DB stores the set of channels that are to be monitored.
DB DB
// Graph stores the channel info and policies for channels in DB.
Graph ChannelGraph
	// ChanEnableTimeout is the duration a peer's connection must remain stable
// before attempting to re-enable the channel.
//
// NOTE: This value is only used to verify that the relation between
// itself, ChanDisableTimeout, and ChanStatusSampleInterval is correct.
// The user is still responsible for ensuring that the same duration
// elapses before attempting to re-enable a channel.
ChanEnableTimeout time.Duration
// ChanDisableTimeout is the duration the manager will wait after
// detecting that a channel has become inactive before broadcasting an
// update to disable the channel.
ChanDisableTimeout time.Duration
// ChanStatusSampleInterval is the long-polling interval used by the
// manager to check if the channels being monitored have become
// inactive.
ChanStatusSampleInterval time.Duration
}
// ChanStatusManager facilitates requests to enable or disable a channel via a
// network announcement that sets the disable bit on the ChannelUpdate
// accordingly. The manager will periodically sample to detect cases where a
// link has become inactive, and facilitate the process of disabling the channel
// passively. The ChanStatusManager state machine is designed to reduce the
// likelihood of spamming the network with updates for flapping peers.
type ChanStatusManager struct {
started sync.Once
stopped sync.Once
cfg *ChanStatusConfig
// ourPubKeyBytes is the serialized compressed pubkey of our node.
ourPubKeyBytes []byte
// chanStates contains the set of channels being monitored for status
// updates. Access to the map is serialized by the statusManager's event
// loop.
chanStates channelStates
// enableRequests pipes external requests to enable a channel into the
// primary event loop.
enableRequests chan statusRequest
// disableRequests pipes external requests to disable a channel into the
// primary event loop.
disableRequests chan statusRequest
// autoRequests pipes external requests to restore automatic channel
// state management into the primary event loop.
autoRequests chan statusRequest
// statusSampleTicker fires at the interval prescribed by
// ChanStatusSampleInterval to check if channels in chanStates have
// become inactive.
statusSampleTicker *time.Ticker
wg sync.WaitGroup
quit chan struct{}
}
// NewChanStatusManager initializes a new ChanStatusManager using the given
// configuration. An error is returned if the timeouts and sample interval do
// not satisfy the inequality:
//
//	ChanEnableTimeout + ChanStatusSampleInterval <= ChanDisableTimeout.
func NewChanStatusManager(cfg *ChanStatusConfig) (*ChanStatusManager, error) {
// Assert that the config timeouts are properly formed. We require the
// enable_timeout + sample_interval to be less than or equal to the
// disable_timeout and that all are positive values. A peer that
// disconnects and reconnects quickly may cause a disable update to be
// sent, shortly followed by a re-enable. Ensuring a healthy separation
// helps dampen the possibility of spamming updates that toggle the
// disable bit for such events.
if cfg.ChanStatusSampleInterval <= 0 {
return nil, ErrInvalidTimeoutConstraints
}
if cfg.ChanEnableTimeout <= 0 {
return nil, ErrInvalidTimeoutConstraints
}
if cfg.ChanDisableTimeout <= 0 {
return nil, ErrInvalidTimeoutConstraints
}
if cfg.ChanEnableTimeout+cfg.ChanStatusSampleInterval >
cfg.ChanDisableTimeout {
return nil, ErrInvalidTimeoutConstraints
}
return &ChanStatusManager{
cfg: cfg,
ourPubKeyBytes: cfg.OurPubKey.SerializeCompressed(),
chanStates: make(channelStates),
statusSampleTicker: time.NewTicker(cfg.ChanStatusSampleInterval),
enableRequests: make(chan statusRequest),
disableRequests: make(chan statusRequest),
autoRequests: make(chan statusRequest),
quit: make(chan struct{}),
}, nil
}
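// validateTimeoutsSketch is an illustrative sketch (not part of the
// original file): the durations below satisfy ChanEnableTimeout +
// ChanStatusSampleInterval <= ChanDisableTimeout, so construction
// succeeds. The throwaway private key stands in for the node key, and the
// remaining config fields (signer, DB, graph, callbacks) are left unset
// here but are required before calling Start.
func validateTimeoutsSketch() (*ChanStatusManager, error) {
	privKey, err := btcec.NewPrivateKey()
	if err != nil {
		return nil, err
	}

	cfg := &ChanStatusConfig{
		OurPubKey:                privKey.PubKey(),
		ChanEnableTimeout:        19 * time.Minute,
		ChanStatusSampleInterval: time.Minute,
		ChanDisableTimeout:       20 * time.Minute,
	}

	// Shrinking ChanDisableTimeout below 20 minutes here would make this
	// return ErrInvalidTimeoutConstraints.
	return NewChanStatusManager(cfg)
}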
// Start safely starts the ChanStatusManager.
func (m *ChanStatusManager) Start() error {
var err error
m.started.Do(func() {
log.Info("Channel Status Manager starting")
err = m.start()
})
return err
}
func (m *ChanStatusManager) start() error {
channels, err := m.fetchChannels()
if err != nil {
return err
}
// Populate the initial states of all confirmed, public channels.
for _, c := range channels {
_, err := m.getOrInitChanStatus(c.FundingOutpoint)
switch {
// If we can't retrieve the edge info for this channel, it may
// have been pruned from the channel graph but not yet from our
// set of channels. We'll skip it as we can't determine its
// initial state.
case errors.Is(err, graphdb.ErrEdgeNotFound):
log.Warnf("Unable to find channel policies for %v, "+
"skipping. This is typical if the channel is "+
"in the process of closing.", c.FundingOutpoint)
continue
// If we are in the process of opening a channel, the funding
// manager might not have added the ChannelUpdate to the graph
// yet. We'll ignore the channel for now.
		case errors.Is(err, ErrUnableToExtractChanUpdate):
log.Warnf("Unable to find channel policies for %v, "+
"skipping. This is typical if the channel is "+
"in the process of being opened.",
c.FundingOutpoint)
continue
case err != nil:
return err
}
}
m.wg.Add(1)
go m.statusManager()
return nil
}
// Stop safely shuts down the ChanStatusManager.
func (m *ChanStatusManager) Stop() error {
m.stopped.Do(func() {
log.Info("Channel Status Manager shutting down...")
defer log.Debug("Channel Status Manager shutdown complete")
close(m.quit)
m.wg.Wait()
})
return nil
}
// RequestEnable submits a request to immediately enable a channel identified by
// the provided outpoint. If the channel is already enabled, no action will be
// taken. If the channel is marked pending-disable, it will be returned to an
// active status as the scheduled disable was never sent. Otherwise, if the
// channel is found to be disabled, a new announcement will be signed with the
// disabled bit cleared and broadcast to the network.
//
// If the channel was manually disabled and RequestEnable is called with
// manual = false, then the request will be ignored.
//
// NOTE: RequestEnable should only be called after a stable connection with the
// channel's peer has lasted at least the ChanEnableTimeout. Failure to do so
// may result in behavior that deviates from the expected behavior of the state
// machine.
func (m *ChanStatusManager) RequestEnable(outpoint wire.OutPoint,
manual bool) error {
return m.submitRequest(m.enableRequests, outpoint, manual)
}
// RequestDisable submits a request to immediately disable a channel identified
// by the provided outpoint. If the channel is already disabled, no action will
// be taken. Otherwise, a new announcement will be signed with the disabled bit
// set and broadcast to the network.
//
// The channel state will be changed to either ChanStatusDisabled or
// ChanStatusManuallyDisabled, depending on the passed-in value of manual. In
// particular, note the following state transitions:
//
// current state | manual | new state
// ---------------------------------------------------
// Disabled | false | Disabled
// ManuallyDisabled | false | ManuallyDisabled (*)
// Disabled | true | ManuallyDisabled
// ManuallyDisabled | true | ManuallyDisabled
//
// (*) If a channel was manually disabled, subsequent automatic / background
// requests to disable the channel do not change the fact that the channel
// was manually disabled.
func (m *ChanStatusManager) RequestDisable(outpoint wire.OutPoint,
manual bool) error {
return m.submitRequest(m.disableRequests, outpoint, manual)
}
// RequestAuto submits a request to restore automatic channel state management.
// If the channel is in the state ChanStatusManuallyDisabled, it will be moved
// back to the state ChanStatusDisabled. Otherwise, no action will be taken.
func (m *ChanStatusManager) RequestAuto(outpoint wire.OutPoint) error {
return m.submitRequest(m.autoRequests, outpoint, true)
}
// statusRequest is passed to the statusManager to request a change in status
// for a particular channel point. The exact action is governed by passing the
// request through one of the enableRequests or disableRequests channels.
type statusRequest struct {
outpoint wire.OutPoint
manual bool
errChan chan error
}
// submitRequest sends a request for either enabling or disabling a particular
// outpoint and awaits an error response. The request type is dictated by the
// reqChan passed in, which can be either of the enableRequests or
// disableRequests channels.
func (m *ChanStatusManager) submitRequest(reqChan chan statusRequest,
outpoint wire.OutPoint, manual bool) error {
req := statusRequest{
outpoint: outpoint,
manual: manual,
errChan: make(chan error, 1),
}
select {
case reqChan <- req:
case <-m.quit:
return ErrChanStatusManagerExiting
}
select {
case err := <-req.errChan:
return err
case <-m.quit:
return ErrChanStatusManagerExiting
}
}
// statusManager is the primary event loop for the ChanStatusManager, providing
// the necessary synchronization primitive to protect access to the chanStates
// map. All requests to explicitly enable or disable a channel are processed
// within this method. The statusManager will also periodically poll the active
// status of channels within the htlcswitch to see if a disable announcement
// should be scheduled or broadcast.
//
// NOTE: This method MUST be run as a goroutine.
func (m *ChanStatusManager) statusManager() {
defer m.wg.Done()
for {
select {
		// Process any requests to mark a channel as enabled.
case req := <-m.enableRequests:
req.errChan <- m.processEnableRequest(req.outpoint, req.manual)
		// Process any requests to mark a channel as disabled.
case req := <-m.disableRequests:
req.errChan <- m.processDisableRequest(req.outpoint, req.manual)
// Process any requests to restore automatic channel state management.
case req := <-m.autoRequests:
req.errChan <- m.processAutoRequest(req.outpoint)
// Use long-polling to detect when channels become inactive.
case <-m.statusSampleTicker.C:
// First, do a sweep and mark any ChanStatusEnabled
// channels that are not active within the htlcswitch as
// ChanStatusPendingDisabled. The channel will then be
// disabled if no request to enable is received before
// the ChanDisableTimeout expires.
m.markPendingInactiveChannels()
// Now, do another sweep to disable any channels that
// were marked in a prior iteration as pending inactive
// if the inactive chan timeout has elapsed.
m.disableInactiveChannels()
case <-m.quit:
return
}
}
}
// processEnableRequest attempts to enable the given outpoint.
//
// - If the channel is not active at the time of the request,
// ErrEnableInactiveChan will be returned.
// - If the channel was in the ManuallyDisabled state and manual = false,
// the request will be ignored and ErrEnableManuallyDisabledChan will be
// returned.
// - Otherwise, the status of the channel in chanStates will be
// ChanStatusEnabled and the method will return nil.
//
// An update will be broadcast only if the channel is currently disabled,
// otherwise no update will be sent on the network.
func (m *ChanStatusManager) processEnableRequest(outpoint wire.OutPoint,
manual bool) error {
curState, err := m.getOrInitChanStatus(outpoint)
if err != nil {
return err
}
// Quickly check to see if the requested channel is active within the
// htlcswitch and return an error if it isn't.
chanID := lnwire.NewChanIDFromOutPoint(outpoint)
if !m.cfg.IsChannelActive(chanID) {
return ErrEnableInactiveChan
}
switch curState.Status {
// Channel is already enabled, nothing to do.
case ChanStatusEnabled:
log.Debugf("Channel(%v) already enabled, skipped "+
"announcement", outpoint)
return nil
// The channel is enabled, though we are now canceling the scheduled
// disable.
case ChanStatusPendingDisabled:
log.Debugf("Channel(%v) already enabled, canceling scheduled "+
"disable", outpoint)
// We'll sign a new update if the channel is still disabled.
case ChanStatusManuallyDisabled:
if !manual {
return ErrEnableManuallyDisabledChan
}
fallthrough
case ChanStatusDisabled:
log.Infof("Announcing channel(%v) enabled", outpoint)
err := m.signAndSendNextUpdate(outpoint, false)
if err != nil {
return err
}
}
m.chanStates.markEnabled(outpoint)
return nil
}
// processDisableRequest attempts to disable the given outpoint. If the method
// returns nil, the status of the channel in chanStates will be either
// ChanStatusDisabled or ChanStatusManuallyDisabled, depending on the
// passed-in value of manual.
//
// An update will only be sent if the channel has a status other than
// ChanStatusEnabled, otherwise no update will be sent on the network.
func (m *ChanStatusManager) processDisableRequest(outpoint wire.OutPoint,
manual bool) error {
curState, err := m.getOrInitChanStatus(outpoint)
if err != nil {
return err
}
status := curState.Status
if status == ChanStatusEnabled || status == ChanStatusPendingDisabled {
log.Infof("Announcing channel(%v) disabled [requested]",
outpoint)
err := m.signAndSendNextUpdate(outpoint, true)
if err != nil {
return err
}
}
// Typically, a request to disable a channel via the manager's public
// interface signals that the channel is being closed.
//
// If we don't need to keep track of a manual request to disable the
// channel, then we can remove the outpoint to free up space in the map
// of channel states. If for some reason the channel isn't closed, the
// state will be repopulated on subsequent calls to the manager's public
// interface via a db lookup, or on startup.
if manual {
m.chanStates.markManuallyDisabled(outpoint)
} else if status != ChanStatusManuallyDisabled {
delete(m.chanStates, outpoint)
}
return nil
}
// processAutoRequest attempts to restore automatic channel state management
// for the given outpoint. If the method returns nil, the state of the channel
// will no longer be ChanStatusManuallyDisabled (currently the only state in
// which automatic / background requests are ignored).
//
// No update will be sent on the network.
func (m *ChanStatusManager) processAutoRequest(outpoint wire.OutPoint) error {
curState, err := m.getOrInitChanStatus(outpoint)
if err != nil {
return err
}
if curState.Status == ChanStatusManuallyDisabled {
log.Debugf("Restoring automatic control for manually disabled "+
"channel(%v)", outpoint)
m.chanStates.markDisabled(outpoint)
}
return nil
}
// markPendingInactiveChannels performs a sweep of the database's active
// channels and determines which, if any, should have a disable announcement
// scheduled. Once an active channel is determined to be pending-inactive, one
// of two transitions can follow. Either the channel is disabled because no
// request to enable is received before the scheduled disable is broadcast, or
// the channel is successfully re-enabled and returned to an active
// state from the POV of the ChanStatusManager.
func (m *ChanStatusManager) markPendingInactiveChannels() {
channels, err := m.fetchChannels()
if err != nil {
log.Errorf("Unable to load active channels: %v", err)
return
}
for _, c := range channels {
// Determine the initial status of the active channel, and
// populate the entry in the chanStates map.
curState, err := m.getOrInitChanStatus(c.FundingOutpoint)
if err != nil {
log.Errorf("Unable to retrieve chan status for "+
"Channel(%v): %v", c.FundingOutpoint, err)
continue
}
// If the channel's status is not ChanStatusEnabled, we are
// done. Either it is already disabled, or it has been marked
// ChanStatusPendingDisable meaning that we have already
// scheduled the time at which it will be disabled.
if curState.Status != ChanStatusEnabled {
continue
}
// If our bookkeeping shows the channel as active, sample the
// htlcswitch to see if it believes the link is also active. If
// so, we will skip marking it as ChanStatusPendingDisabled.
chanID := lnwire.NewChanIDFromOutPoint(c.FundingOutpoint)
if m.cfg.IsChannelActive(chanID) {
continue
}
// Otherwise, we discovered that this link was inactive within
// the switch. Compute the time at which we will send out a
// disable if the peer is unable to reestablish a stable
// connection.
disableTime := time.Now().Add(m.cfg.ChanDisableTimeout)
log.Debugf("Marking channel(%v) pending-inactive",
c.FundingOutpoint)
m.chanStates.markPendingDisabled(c.FundingOutpoint, disableTime)
}
}
// disableInactiveChannels scans through the set of monitored channels, and
// broadcasts a disable update for any pending-inactive channels whose
// SendDisableTime has passed.
func (m *ChanStatusManager) disableInactiveChannels() {
// Now, disable any channels whose inactive chan timeout has elapsed.
now := time.Now()
for outpoint, state := range m.chanStates {
// Ignore statuses that are not in the pending-inactive state.
if state.Status != ChanStatusPendingDisabled {
continue
}
// Ignore statuses for which the disable timeout has not
// expired.
if state.SendDisableTime.After(now) {
continue
}
log.Infof("Announcing channel(%v) disabled "+
"[detected]", outpoint)
// Sign an update disabling the channel.
err := m.signAndSendNextUpdate(outpoint, true)
if err != nil {
log.Errorf("Unable to sign update disabling "+
"channel(%v): %v", outpoint, err)
// If the edge was not found, this is a likely indicator
// that the channel has been closed. Thus we remove the
// outpoint from the set of tracked outpoints to prevent
// further attempts.
if errors.Is(err, graphdb.ErrEdgeNotFound) {
log.Debugf("Removing channel(%v) from "+
"consideration for passive disabling",
outpoint)
delete(m.chanStates, outpoint)
}
continue
}
// Record that the channel has now been disabled.
m.chanStates.markDisabled(outpoint)
}
}
// fetchChannels returns the working set of channels managed by the
// ChanStatusManager. The returned channels are filtered to only contain public
// channels.
func (m *ChanStatusManager) fetchChannels() ([]*channeldb.OpenChannel, error) {
allChannels, err := m.cfg.DB.FetchAllOpenChannels()
if err != nil {
return nil, err
}
// Filter out private channels.
var channels []*channeldb.OpenChannel
for _, c := range allChannels {
// We'll skip any private channels, as they aren't used for
// routing within the network by other nodes.
if c.ChannelFlags&lnwire.FFAnnounceChannel == 0 {
continue
}
channels = append(channels, c)
}
return channels, nil
}
// signAndSendNextUpdate computes and signs a valid update for the passed
// outpoint, with the ability to toggle the disabled bit. The new update will
// use the current time as the update's timestamp, or increment the old
// timestamp by 1 to ensure the update can propagate. If signing is successful,
// the new update will be sent out on the network.
func (m *ChanStatusManager) signAndSendNextUpdate(outpoint wire.OutPoint,
disabled bool) error {
// Retrieve the latest update for this channel. We'll use this
// as our starting point to send the new update.
chanUpdate, private, err := m.fetchLastChanUpdateByOutPoint(outpoint)
if err != nil {
return err
}
err = SignChannelUpdate(
m.cfg.MessageSigner, m.cfg.OurKeyLoc, chanUpdate,
ChanUpdSetDisable(disabled), ChanUpdSetTimestamp,
)
if err != nil {
return err
}
return m.cfg.ApplyChannelUpdate(chanUpdate, &outpoint, private)
}
// fetchLastChanUpdateByOutPoint fetches the latest policy for our direction of
// a channel, and crafts a new ChannelUpdate with this policy. Returns an error
// in case our ChannelEdgePolicy is not found in the database. It also returns
// whether the channel is private, determined by checking AuthProof for nil.
func (m *ChanStatusManager) fetchLastChanUpdateByOutPoint(op wire.OutPoint) (
*lnwire.ChannelUpdate1, bool, error) {
// Get the edge info and policies for this channel from the graph.
info, edge1, edge2, err := m.cfg.Graph.FetchChannelEdgesByOutpoint(&op)
if err != nil {
return nil, false, err
}
update, err := ExtractChannelUpdate(
m.ourPubKeyBytes, info, edge1, edge2,
)
return update, info.AuthProof == nil, err
}
// loadInitialChanState determines the initial ChannelState for a particular
// outpoint. The initial ChanStatus for a given outpoint will either be
// ChanStatusEnabled or ChanStatusDisabled, determined by inspecting the bits on
// the most recent announcement. An error is returned if the latest update could
// not be retrieved.
func (m *ChanStatusManager) loadInitialChanState(
outpoint *wire.OutPoint) (ChannelState, error) {
lastUpdate, _, err := m.fetchLastChanUpdateByOutPoint(*outpoint)
if err != nil {
return ChannelState{}, err
}
// Determine the channel's starting status by inspecting the disable bit
	// on the last announcement we sent out.
var initialStatus ChanStatus
if lastUpdate.ChannelFlags&lnwire.ChanUpdateDisabled == 0 {
initialStatus = ChanStatusEnabled
} else {
initialStatus = ChanStatusDisabled
}
return ChannelState{
Status: initialStatus,
}, nil
}
// getOrInitChanStatus retrieves the current ChannelState for a particular
// outpoint. If the chanStates map already contains an entry for the outpoint,
// the value in the map is returned. Otherwise, the outpoint's initial status is
// computed and updated in the chanStates map before being returned.
func (m *ChanStatusManager) getOrInitChanStatus(
outpoint wire.OutPoint) (ChannelState, error) {
// Return the current ChannelState from the chanStates map if it is
// already known to the ChanStatusManager.
if curState, ok := m.chanStates[outpoint]; ok {
return curState, nil
}
// Otherwise, determine the initial state based on the last update we
// sent for the outpoint.
initialState, err := m.loadInitialChanState(&outpoint)
if err != nil {
return ChannelState{}, err
}
// Finally, store the initial state in the chanStates map. This will
	// serve as our up-to-date view of the outpoint's current status, in
// addition to making the channel eligible for detecting inactivity.
m.chanStates[outpoint] = initialState
return initialState, nil
}
package netann
import (
"bytes"
"errors"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// chanAnn2MsgName is a string representing the name of the
// ChannelAnnouncement2 message. This string will be used during the
// construction of the tagged hash message to be signed when producing
// the signature for the ChannelAnnouncement2 message.
chanAnn2MsgName = "channel_announcement_2"
// chanAnn2SigFieldName is the name of the signature field of the
// ChannelAnnouncement2 message. This string will be used during the
// construction of the tagged hash message to be signed when producing
// the signature for the ChannelAnnouncement2 message.
chanAnn2SigFieldName = "signature"
)
// CreateChanAnnouncement is a helper function which creates all channel
// announcements given the necessary channel related database items. This
// function is used to transform our database structs into the corresponding wire
// structs for announcing new channels to other peers, or simply syncing up a
// peer's initial routing table upon connect.
func CreateChanAnnouncement(chanProof *models.ChannelAuthProof,
chanInfo *models.ChannelEdgeInfo,
e1, e2 *models.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement1,
*lnwire.ChannelUpdate1, *lnwire.ChannelUpdate1, error) {
// First, using the parameters of the channel, along with the channel
	// authentication chanProof, we'll re-create the original
// authenticated channel announcement.
chanID := lnwire.NewShortChanIDFromInt(chanInfo.ChannelID)
chanAnn := &lnwire.ChannelAnnouncement1{
ShortChannelID: chanID,
NodeID1: chanInfo.NodeKey1Bytes,
NodeID2: chanInfo.NodeKey2Bytes,
ChainHash: chanInfo.ChainHash,
BitcoinKey1: chanInfo.BitcoinKey1Bytes,
BitcoinKey2: chanInfo.BitcoinKey2Bytes,
Features: lnwire.NewRawFeatureVector(),
ExtraOpaqueData: chanInfo.ExtraOpaqueData,
}
err := chanAnn.Features.Decode(bytes.NewReader(chanInfo.Features))
if err != nil {
return nil, nil, nil, err
}
chanAnn.BitcoinSig1, err = lnwire.NewSigFromECDSARawSignature(
chanProof.BitcoinSig1Bytes,
)
if err != nil {
return nil, nil, nil, err
}
chanAnn.BitcoinSig2, err = lnwire.NewSigFromECDSARawSignature(
chanProof.BitcoinSig2Bytes,
)
if err != nil {
return nil, nil, nil, err
}
chanAnn.NodeSig1, err = lnwire.NewSigFromECDSARawSignature(
chanProof.NodeSig1Bytes,
)
if err != nil {
return nil, nil, nil, err
}
chanAnn.NodeSig2, err = lnwire.NewSigFromECDSARawSignature(
chanProof.NodeSig2Bytes,
)
if err != nil {
return nil, nil, nil, err
}
	// We'll unconditionally queue the channel's existence proof (chanProof)
	// as it will need to be processed before either of the channel update
	// messages.
// Since it's up to a node's policy as to whether they advertise the
// edge in a direction, we don't create an advertisement if the edge is
// nil.
var edge1Ann, edge2Ann *lnwire.ChannelUpdate1
if e1 != nil {
edge1Ann, err = ChannelUpdateFromEdge(chanInfo, e1)
if err != nil {
return nil, nil, nil, err
}
}
if e2 != nil {
edge2Ann, err = ChannelUpdateFromEdge(chanInfo, e2)
if err != nil {
return nil, nil, nil, err
}
}
return chanAnn, edge1Ann, edge2Ann, nil
}
// FetchPkScript defines a function that can be used to fetch the output script
// for the transaction with the given SCID.
type FetchPkScript func(*lnwire.ShortChannelID) ([]byte, error)
// ValidateChannelAnn validates the channel announcement.
func ValidateChannelAnn(a lnwire.ChannelAnnouncement,
fetchPkScript FetchPkScript) error {
switch ann := a.(type) {
case *lnwire.ChannelAnnouncement1:
return validateChannelAnn1(ann)
case *lnwire.ChannelAnnouncement2:
return validateChannelAnn2(ann, fetchPkScript)
default:
return fmt.Errorf("unhandled implementation of "+
"lnwire.ChannelAnnouncement: %T", a)
}
}
// validateChannelAnn1 validates the channel announcement message and checks
// that the node signatures cover the announcement message, and that the
// bitcoin signatures cover the node keys.
func validateChannelAnn1(a *lnwire.ChannelAnnouncement1) error {
	// First, we'll compute the digest (h) which is to be signed by each of
	// the keys included within the channel announcement message. This hash
	// digest includes all the keys, so the (up to four) signatures will
	// attest to the validity of each of the keys.
data, err := a.DataToSign()
if err != nil {
return err
}
dataHash := chainhash.DoubleHashB(data)
// First we'll verify that the passed bitcoin key signature is indeed a
// signature over the computed hash digest.
bitcoinSig1, err := a.BitcoinSig1.ToSignature()
if err != nil {
return err
}
bitcoinKey1, err := btcec.ParsePubKey(a.BitcoinKey1[:])
if err != nil {
return err
}
if !bitcoinSig1.Verify(dataHash, bitcoinKey1) {
return errors.New("can't verify first bitcoin signature")
}
// If that checks out, then we'll verify that the second bitcoin
// signature is a valid signature of the bitcoin public key over hash
// digest as well.
bitcoinSig2, err := a.BitcoinSig2.ToSignature()
if err != nil {
return err
}
bitcoinKey2, err := btcec.ParsePubKey(a.BitcoinKey2[:])
if err != nil {
return err
}
if !bitcoinSig2.Verify(dataHash, bitcoinKey2) {
return errors.New("can't verify second bitcoin signature")
}
	// Both node signatures attached should indeed be valid signatures over
	// the selected digest of the channel announcement.
nodeSig1, err := a.NodeSig1.ToSignature()
if err != nil {
return err
}
nodeKey1, err := btcec.ParsePubKey(a.NodeID1[:])
if err != nil {
return err
}
if !nodeSig1.Verify(dataHash, nodeKey1) {
return errors.New("can't verify data in first node signature")
}
nodeSig2, err := a.NodeSig2.ToSignature()
if err != nil {
return err
}
nodeKey2, err := btcec.ParsePubKey(a.NodeID2[:])
if err != nil {
return err
}
if !nodeSig2.Verify(dataHash, nodeKey2) {
return errors.New("can't verify data in second node signature")
}
return nil
}
// validateChannelAnn2 validates the channel announcement message and checks
// that the message signature covers the announcement message.
func validateChannelAnn2(a *lnwire.ChannelAnnouncement2,
fetchPkScript FetchPkScript) error {
dataHash, err := ChanAnn2DigestToSign(a)
if err != nil {
return err
}
sig, err := a.Signature.ToSignature()
if err != nil {
return err
}
nodeKey1, err := btcec.ParsePubKey(a.NodeID1.Val[:])
if err != nil {
return err
}
nodeKey2, err := btcec.ParsePubKey(a.NodeID2.Val[:])
if err != nil {
return err
}
keys := []*btcec.PublicKey{
nodeKey1, nodeKey2,
}
// If the bitcoin keys are provided in the announcement, then it is
// assumed that the signature of the announcement is a 4-of-4 MuSig2
// over the bitcoin keys and node ID keys.
if a.BitcoinKey1.IsSome() && a.BitcoinKey2.IsSome() {
var (
btcKey1 tlv.RecordT[tlv.TlvType12, [33]byte]
btcKey2 tlv.RecordT[tlv.TlvType14, [33]byte]
)
btcKey1 = a.BitcoinKey1.UnwrapOr(btcKey1)
btcKey2 = a.BitcoinKey2.UnwrapOr(btcKey2)
bitcoinKey1, err := btcec.ParsePubKey(btcKey1.Val[:])
if err != nil {
return err
}
bitcoinKey2, err := btcec.ParsePubKey(btcKey2.Val[:])
if err != nil {
return err
}
keys = append(keys, bitcoinKey1, bitcoinKey2)
} else {
// If bitcoin keys are not provided, then we need to get the
// on-chain output key since this will be the 3rd key in the
// 3-of-3 MuSig2 signature.
pkScript, err := fetchPkScript(&a.ShortChannelID.Val)
if err != nil {
return err
}
outputKey, err := schnorr.ParsePubKey(pkScript[2:])
if err != nil {
return err
}
keys = append(keys, outputKey)
}
aggKey, _, _, err := musig2.AggregateKeys(keys, true)
if err != nil {
return err
}
if !sig.Verify(dataHash.CloneBytes(), aggKey.FinalKey) {
return fmt.Errorf("invalid sig")
}
return nil
}
// ChanAnn2DigestToSign computes the digest of the message to be signed.
func ChanAnn2DigestToSign(a *lnwire.ChannelAnnouncement2) (*chainhash.Hash,
error) {
data, err := a.DataToSign()
if err != nil {
return nil, err
}
return MsgHash(chanAnn2MsgName, chanAnn2SigFieldName, data), nil
}
package netann
import (
"time"
"github.com/btcsuite/btcd/wire"
)
// ChanStatus is a type that enumerates the possible states a ChanStatusManager
// tracks for its known channels.
type ChanStatus uint8
const (
// ChanStatusEnabled indicates that the channel's last announcement has
// the disabled bit cleared.
ChanStatusEnabled ChanStatus = iota
// ChanStatusPendingDisabled indicates that the channel's last
// announcement has the disabled bit cleared, but that the channel was
	// detected in an inactive state. Channels in this state will have a
	// disabling announcement sent after the ChanDisableTimeout expires
	// from the time of the first detection, unless the channel is
	// explicitly re-enabled before the disabling occurs.
ChanStatusPendingDisabled
// ChanStatusDisabled indicates that the channel's last announcement has
// the disabled bit set.
ChanStatusDisabled
// ChanStatusManuallyDisabled indicates that the channel's last
// announcement had the disabled bit set, and that a user manually
// requested disabling the channel. Channels in this state will ignore
// automatic / background attempts to re-enable the channel.
//
// Note that there's no corresponding ChanStatusManuallyEnabled state
// because even if a user manually requests enabling a channel, we still
// DO want to allow automatic / background processes to disable it.
// Otherwise, the network might be cluttered with channels that are
// advertised as enabled, but don't actually work or even exist.
ChanStatusManuallyDisabled
)
// ChannelState describes the ChanStatusManager's view of a channel, and
// describes the current state of the channel's disabled status on the network.
type ChannelState struct {
// Status is the channel's current ChanStatus from the POV of the
// ChanStatusManager.
Status ChanStatus
// SendDisableTime is the earliest time at which the ChanStatusManager
// will passively send a new disable announcement on behalf of this
// channel.
//
// NOTE: This field is only non-zero if status is
// ChanStatusPendingDisabled.
SendDisableTime time.Time
}
// channelStates is a map of channel outpoints to their ChannelState. All
// changes made after setting an entry initially should be made using receiver
// methods below.
type channelStates map[wire.OutPoint]ChannelState
// markEnabled creates a channelState using ChanStatusEnabled.
func (s *channelStates) markEnabled(outpoint wire.OutPoint) {
(*s)[outpoint] = ChannelState{
Status: ChanStatusEnabled,
}
}
// markDisabled creates a channelState using ChanStatusDisabled.
func (s *channelStates) markDisabled(outpoint wire.OutPoint) {
(*s)[outpoint] = ChannelState{
Status: ChanStatusDisabled,
}
}
// markManuallyDisabled creates a channelState using
// ChanStatusManuallyDisabled.
func (s *channelStates) markManuallyDisabled(outpoint wire.OutPoint) {
(*s)[outpoint] = ChannelState{
Status: ChanStatusManuallyDisabled,
}
}
// markPendingDisabled creates a channelState using ChanStatusPendingDisabled
// and sets the ChannelState's SendDisableTime to sendDisableTime.
func (s *channelStates) markPendingDisabled(outpoint wire.OutPoint,
sendDisableTime time.Time) {
(*s)[outpoint] = ChannelState{
Status: ChanStatusPendingDisabled,
SendDisableTime: sendDisableTime,
}
}
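// stateTransitionSketch is an illustrative sketch (not part of the
// original file): it walks a zero-value outpoint through the passive
// disable flow tracked by these helpers, mirroring what the
// ChanStatusManager does when a link goes inactive.
func stateTransitionSketch() {
	states := make(channelStates)
	var op wire.OutPoint

	// The channel's last announcement had the disabled bit cleared.
	states.markEnabled(op)

	// The link went inactive; schedule a disable one minute from now.
	states.markPendingDisabled(op, time.Now().Add(time.Minute))

	// The deadline elapsed without the link recovering.
	states.markDisabled(op)
}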
package netann
import (
"bytes"
"fmt"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/pkg/errors"
)
const (
// chanUpdate2MsgName is a string representing the name of the
// ChannelUpdate2 message. This string will be used during the
// construction of the tagged hash message to be signed when producing
// the signature for the ChannelUpdate2 message.
chanUpdate2MsgName = "channel_update_2"
// chanUpdate2SigField is the name of the signature field of the
// ChannelUpdate2 message. This string will be used during the
// construction of the tagged hash message to be signed when producing
// the signature for the ChannelUpdate2 message.
chanUpdate2SigField = "signature"
)
// ErrUnableToExtractChanUpdate is returned when a channel update cannot be
// found for one of our active channels.
var ErrUnableToExtractChanUpdate = fmt.Errorf("unable to extract ChannelUpdate")
// ChannelUpdateModifier is a closure that makes in-place modifications to an
// lnwire.ChannelUpdate.
type ChannelUpdateModifier func(*lnwire.ChannelUpdate1)
// ChanUpdSetDisable is a functional option that sets the disabled channel flag
// if disabled is true, and clears the bit otherwise.
func ChanUpdSetDisable(disabled bool) ChannelUpdateModifier {
return func(update *lnwire.ChannelUpdate1) {
if disabled {
// Set the bit responsible for marking a channel as
// disabled.
update.ChannelFlags |= lnwire.ChanUpdateDisabled
} else {
// Clear the bit responsible for marking a channel as
// disabled.
update.ChannelFlags &= ^lnwire.ChanUpdateDisabled
}
}
}
// ChanUpdSetTimestamp is a functional option that sets the timestamp of the
// update to the current time, or increments it if the timestamp is already in
// the future.
func ChanUpdSetTimestamp(update *lnwire.ChannelUpdate1) {
newTimestamp := uint32(time.Now().Unix())
if newTimestamp <= update.Timestamp {
// Increment the prior value to ensure the timestamp
// monotonically increases, otherwise the update won't
// propagate.
newTimestamp = update.Timestamp + 1
}
update.Timestamp = newTimestamp
}
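// modifierSketch is an illustrative sketch (not part of the original
// file): it applies both modifiers to a stale update. Because the made-up
// timestamp is ahead of the clock, ChanUpdSetTimestamp increments it by
// one instead of using time.Now, keeping the sequence monotonic.
func modifierSketch() *lnwire.ChannelUpdate1 {
	update := &lnwire.ChannelUpdate1{
		Timestamp: uint32(time.Now().Add(time.Hour).Unix()),
	}

	// Set the disabled bit, then bump the timestamp so the update
	// propagates.
	ChanUpdSetDisable(true)(update)
	ChanUpdSetTimestamp(update)

	return update
}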
// SignChannelUpdate applies the given modifiers to the passed
// lnwire.ChannelUpdate, then signs the resulting update. The provided update
// should be the most recent, valid update, otherwise the timestamp may not
// monotonically increase from the prior.
//
// NOTE: This method modifies the given update.
func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator,
update *lnwire.ChannelUpdate1, mods ...ChannelUpdateModifier) error {
// Apply the requested changes to the channel update.
for _, modifier := range mods {
modifier(update)
}
// Create the DER-encoded ECDSA signature over the message digest.
sig, err := SignAnnouncement(signer, keyLoc, update)
if err != nil {
return err
}
// Parse the DER-encoded signature into a fixed-size 64-byte array.
update.Signature, err = lnwire.NewSigFromSignature(sig)
if err != nil {
return err
}
return nil
}
// ExtractChannelUpdate attempts to retrieve an lnwire.ChannelUpdate message from
// an edge's info and a set of routing policies.
//
// NOTE: The passed policies can be nil.
func ExtractChannelUpdate(ownerPubKey []byte,
info *models.ChannelEdgeInfo,
policies ...*models.ChannelEdgePolicy) (
*lnwire.ChannelUpdate1, error) {
// Helper function to extract the owner of the given policy.
owner := func(edge *models.ChannelEdgePolicy) []byte {
var pubKey *btcec.PublicKey
if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
pubKey, _ = info.NodeKey1()
} else {
pubKey, _ = info.NodeKey2()
}
// If pubKey was not found, just return nil.
if pubKey == nil {
return nil
}
return pubKey.SerializeCompressed()
}
// Extract the channel update from the policy we own, if any.
for _, edge := range policies {
if edge != nil && bytes.Equal(ownerPubKey, owner(edge)) {
return ChannelUpdateFromEdge(info, edge)
}
}
return nil, ErrUnableToExtractChanUpdate
}
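// extractFallbackSketch is an illustrative sketch (not part of the
// original file) of the fallback behavior: when no non-nil policy is owned
// by ourPubKey (here, both policies are nil), ExtractChannelUpdate returns
// ErrUnableToExtractChanUpdate rather than fabricating an update.
func extractFallbackSketch(info *models.ChannelEdgeInfo,
	ourPubKey []byte) error {

	_, err := ExtractChannelUpdate(ourPubKey, info, nil, nil)

	// err is ErrUnableToExtractChanUpdate.
	return err
}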
// UnsignedChannelUpdateFromEdge reconstructs an unsigned ChannelUpdate from the
// given edge info and policy.
func UnsignedChannelUpdateFromEdge(info *models.ChannelEdgeInfo,
policy *models.ChannelEdgePolicy) *lnwire.ChannelUpdate1 {
return &lnwire.ChannelUpdate1{
ChainHash: info.ChainHash,
ShortChannelID: lnwire.NewShortChanIDFromInt(policy.ChannelID),
Timestamp: uint32(policy.LastUpdate.Unix()),
ChannelFlags: policy.ChannelFlags,
MessageFlags: policy.MessageFlags,
TimeLockDelta: policy.TimeLockDelta,
HtlcMinimumMsat: policy.MinHTLC,
HtlcMaximumMsat: policy.MaxHTLC,
BaseFee: uint32(policy.FeeBaseMSat),
FeeRate: uint32(policy.FeeProportionalMillionths),
ExtraOpaqueData: policy.ExtraOpaqueData,
}
}
// ChannelUpdateFromEdge reconstructs a signed ChannelUpdate from the given edge
// info and policy.
func ChannelUpdateFromEdge(info *models.ChannelEdgeInfo,
policy *models.ChannelEdgePolicy) (*lnwire.ChannelUpdate1, error) {
update := UnsignedChannelUpdateFromEdge(info, policy)
var err error
update.Signature, err = lnwire.NewSigFromECDSARawSignature(
policy.SigBytes,
)
if err != nil {
return nil, err
}
return update, nil
}
// ValidateChannelUpdateAnn validates the channel update announcement by
// checking (1) that the included signature covers the announcement and has been
// signed by the node's private key, and (2) that the announcement's message
// flags and optional fields are sane.
func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, capacity btcutil.Amount,
a lnwire.ChannelUpdate) error {
if err := ValidateChannelUpdateFields(capacity, a); err != nil {
return err
}
return VerifyChannelUpdateSignature(a, pubKey)
}
// VerifyChannelUpdateSignature verifies that the channel update message was
// signed by the party with the given node public key.
func VerifyChannelUpdateSignature(msg lnwire.ChannelUpdate,
pubKey *btcec.PublicKey) error {
switch u := msg.(type) {
case *lnwire.ChannelUpdate1:
return verifyChannelUpdate1Signature(u, pubKey)
case *lnwire.ChannelUpdate2:
return verifyChannelUpdate2Signature(u, pubKey)
default:
return fmt.Errorf("unhandled implementation of "+
"lnwire.ChannelUpdate: %T", msg)
}
}
// verifyChannelUpdate1Signature verifies that the channel update message was
// signed by the party with the given node public key.
func verifyChannelUpdate1Signature(msg *lnwire.ChannelUpdate1,
pubKey *btcec.PublicKey) error {
data, err := msg.DataToSign()
if err != nil {
return fmt.Errorf("unable to reconstruct message data: %w", err)
}
dataHash := chainhash.DoubleHashB(data)
nodeSig, err := msg.Signature.ToSignature()
if err != nil {
return err
}
if !nodeSig.Verify(dataHash, pubKey) {
return fmt.Errorf("invalid signature for channel update %v",
spew.Sdump(msg))
}
return nil
}
// verifyChannelUpdate2Signature verifies that the channel update message was
// signed by the party with the given node public key.
func verifyChannelUpdate2Signature(c *lnwire.ChannelUpdate2,
pubKey *btcec.PublicKey) error {
digest, err := chanUpdate2DigestToSign(c)
if err != nil {
return fmt.Errorf("unable to reconstruct message data: %w", err)
}
nodeSig, err := c.Signature.ToSignature()
if err != nil {
return err
}
if !nodeSig.Verify(digest, pubKey) {
return fmt.Errorf("invalid signature for channel update %v",
spew.Sdump(c))
}
return nil
}
// ValidateChannelUpdateFields validates a channel update's message flags and
// corresponding update fields.
func ValidateChannelUpdateFields(capacity btcutil.Amount,
msg lnwire.ChannelUpdate) error {
switch u := msg.(type) {
case *lnwire.ChannelUpdate1:
return validateChannelUpdate1Fields(capacity, u)
case *lnwire.ChannelUpdate2:
return validateChannelUpdate2Fields(capacity, u)
default:
return fmt.Errorf("unhandled implementation of "+
"lnwire.ChannelUpdate: %T", msg)
}
}
// validateChannelUpdate1Fields validates a channel update's message flags and
// corresponding update fields.
func validateChannelUpdate1Fields(capacity btcutil.Amount,
msg *lnwire.ChannelUpdate1) error {
// The maxHTLC flag is mandatory.
if !msg.MessageFlags.HasMaxHtlc() {
return errors.Errorf("max htlc flag not set for channel "+
"update %v", spew.Sdump(msg))
}
maxHtlc := msg.HtlcMaximumMsat
if maxHtlc == 0 || maxHtlc < msg.HtlcMinimumMsat {
return errors.Errorf("invalid max htlc for channel "+
"update %v", spew.Sdump(msg))
}
// For light clients, the capacity will not be set so we'll skip
// checking whether the MaxHTLC value respects the channel's
// capacity.
capacityMsat := lnwire.NewMSatFromSatoshis(capacity)
if capacityMsat != 0 && maxHtlc > capacityMsat {
return errors.Errorf("max_htlc (%v) for channel update "+
"greater than capacity (%v)", maxHtlc, capacityMsat)
}
return nil
}
// validateChannelUpdate2Fields validates a channel update's message flags and
// corresponding update fields.
func validateChannelUpdate2Fields(capacity btcutil.Amount,
c *lnwire.ChannelUpdate2) error {
maxHtlc := c.HTLCMaximumMsat.Val
if maxHtlc == 0 || maxHtlc < c.HTLCMinimumMsat.Val {
return fmt.Errorf("invalid max htlc for channel update %v",
spew.Sdump(c))
}
// Checking whether the MaxHTLC value respects the channel's capacity.
capacityMsat := lnwire.NewMSatFromSatoshis(capacity)
if maxHtlc > capacityMsat {
return fmt.Errorf("max_htlc (%v) for channel update greater "+
"than capacity (%v)", maxHtlc, capacityMsat)
}
return nil
}
// ChanUpdate2DigestTag returns the tag to be used when signing the digest of
// a channel_update_2 message.
func ChanUpdate2DigestTag() []byte {
return MsgTag(chanUpdate2MsgName, chanUpdate2SigField)
}
// chanUpdate2DigestToSign computes the digest of the ChannelUpdate2 message to
// be signed.
func chanUpdate2DigestToSign(c *lnwire.ChannelUpdate2) ([]byte, error) {
data, err := c.DataToSign()
if err != nil {
return nil, err
}
hash := MsgHash(chanUpdate2MsgName, chanUpdate2SigField, data)
return hash[:], nil
}
package netann
import (
"net"
"sync"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker"
)
// HostAnnouncerConfig is the main config for the HostAnnouncer.
type HostAnnouncerConfig struct {
// Hosts is the set of hosts we should watch for IP changes.
Hosts []string
// RefreshTicker ticks each time we should check for any address
// changes.
RefreshTicker ticker.Ticker
// LookupHost performs DNS resolution on a given host and returns its
// addresses.
LookupHost func(string) (net.Addr, error)
// AdvertisedIPs is the set of IPs that we've already announced with
// our current NodeAnnouncement. This set will be constructed to avoid
// unnecessary NodeAnnouncement updates.
AdvertisedIPs map[string]struct{}
// AnnounceNewIPs announces a new set of IP addresses for the backing
// Lightning node. The first set of addresses is the new set of
// addresses that we should advertise, while the other set contains the
// stale addresses that we should no longer advertise.
AnnounceNewIPs func([]net.Addr, map[string]struct{}) error
}
// HostAnnouncer is a sub-system that allows a user to specify a set of hosts
// for lnd that will be continually resolved to notice any IP address changes.
// If the target IP address for a host changes, then we'll generate a new
// NodeAnnouncement that includes these new IPs.
type HostAnnouncer struct {
cfg HostAnnouncerConfig
quit chan struct{}
wg sync.WaitGroup
startOnce sync.Once
stopOnce sync.Once
}
// NewHostAnnouncer returns a new instance of the HostAnnouncer.
func NewHostAnnouncer(cfg HostAnnouncerConfig) *HostAnnouncer {
return &HostAnnouncer{
cfg: cfg,
quit: make(chan struct{}),
}
}
// Start starts the HostAnnouncer.
func (h *HostAnnouncer) Start() error {
h.startOnce.Do(func() {
log.Info("HostAnnouncer starting")
h.wg.Add(1)
go h.hostWatcher()
})
return nil
}
// Stop signals the HostAnnouncer for a graceful stop.
func (h *HostAnnouncer) Stop() error {
h.stopOnce.Do(func() {
log.Info("HostAnnouncer shutting down...")
defer log.Debug("HostAnnouncer shutdown complete")
close(h.quit)
h.wg.Wait()
})
return nil
}
// hostWatcher periodically attempts to resolve the IP for each host, updating
// them if they change within the interval.
func (h *HostAnnouncer) hostWatcher() {
defer h.wg.Done()
ipMapping := make(map[string]net.Addr)
refreshHosts := func() {
// We'll now run through each of our hosts to check if they had
// their backing IPs changed. If so, we'll want to re-announce
// them.
var addrsToUpdate []net.Addr
addrsToRemove := make(map[string]struct{})
for _, host := range h.cfg.Hosts {
newAddr, err := h.cfg.LookupHost(host)
if err != nil {
log.Warnf("unable to resolve IP for "+
"host %v: %v", host, err)
continue
}
// If nothing has changed since the last time we
// checked, then we don't need to do any updates.
oldAddr, oldAddrFound := ipMapping[host]
if oldAddrFound && oldAddr.String() == newAddr.String() {
continue
}
// Update the IP mapping now so subsequent refreshes
// compare against the address we just resolved.
ipMapping[host] = newAddr
// If this IP has already been announced, then we'll
// skip it to avoid triggering an unnecessary node
// announcement update.
_, ipAnnounced := h.cfg.AdvertisedIPs[newAddr.String()]
if ipAnnounced {
continue
}
// If we've reached this point, then either this is the
// first time we've resolved this host, or the address we
// just looked up differs from the old one.
log.Debugf("IP change detected! %v: %v -> %v", host,
oldAddr, newAddr)
// If we had already advertised an addr for this host,
// then we'll need to remove that old stale address.
if oldAddr != nil {
addrsToRemove[oldAddr.String()] = struct{}{}
}
addrsToUpdate = append(addrsToUpdate, newAddr)
}
// If we don't have any addresses to update, then there's
// nothing more to do until the next round.
if len(addrsToUpdate) == 0 {
log.Debugf("No IP changes detected for hosts: %v",
h.cfg.Hosts)
return
}
// Now that we know the set of IPs we need to update, we'll do
// them all in a single batch.
err := h.cfg.AnnounceNewIPs(addrsToUpdate, addrsToRemove)
if err != nil {
log.Warnf("unable to announce new IPs: %v", err)
}
}
refreshHosts()
h.cfg.RefreshTicker.Resume()
for {
select {
case <-h.cfg.RefreshTicker.Ticks():
log.Debugf("HostAnnouncer checking for any IP " +
"changes...")
refreshHosts()
case <-h.quit:
return
}
}
}
// NodeAnnUpdater describes a function that's able to update our current node
// announcement on disk. It returns the updated node announcement given a set
// of updates to be applied to the current node announcement.
type NodeAnnUpdater func(modifier ...NodeAnnModifier,
) (lnwire.NodeAnnouncement, error)
// IPAnnouncer is a factory function that generates a new function that uses
// the passed annUpdater function to announce new IP changes for a given
// host.
func IPAnnouncer(annUpdater NodeAnnUpdater) func([]net.Addr,
map[string]struct{}) error {
return func(newAddrs []net.Addr, oldAddrs map[string]struct{}) error {
_, err := annUpdater(func(
currentNodeAnn *lnwire.NodeAnnouncement) {
// To ensure we don't duplicate any addresses, we'll
// filter out the set of addresses we should no longer
// advertise.
filteredAddrs := make(
[]net.Addr, 0, len(currentNodeAnn.Addresses),
)
for _, addr := range currentNodeAnn.Addresses {
if _, ok := oldAddrs[addr.String()]; ok {
continue
}
filteredAddrs = append(filteredAddrs, addr)
}
filteredAddrs = append(filteredAddrs, newAddrs...)
currentNodeAnn.Addresses = filteredAddrs
})
return err
}
}
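// A minimal wiring sketch (illustrative only): the host string, refresh
// interval, advertisedIPs set and nodeAnnUpdater below are assumed to be
// supplied by the caller. net.ResolveTCPAddr satisfies the LookupHost
// signature, and IPAnnouncer adapts a NodeAnnUpdater into the AnnounceNewIPs
// callback.
//
//	announcer := NewHostAnnouncer(HostAnnouncerConfig{
//		Hosts:         []string{"node.example.com:9735"},
//		RefreshTicker: ticker.New(15 * time.Minute),
//		LookupHost: func(host string) (net.Addr, error) {
//			return net.ResolveTCPAddr("tcp", host)
//		},
//		AdvertisedIPs:  advertisedIPs,
//		AnnounceNewIPs: IPAnnouncer(nodeAnnUpdater),
//	})
//	if err := announcer.Start(); err != nil {
//		// Handle startup failure.
//	}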
package netann
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
// log is a logger that is initialized with no output filters. This means the
// package will not perform any logging by default until the caller requests
// it.
var log btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("NANN", nil))
}
// DisableLog disables all library log output. Logging output is disabled by
// default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info. This
// should be used in preference to SetLogWriter if the caller is also using
// btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
package netann
import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// MsgHashTag will prefix the message name and the field name in order to
// construct the message tag.
const MsgHashTag = "lightning"
// MsgTag computes the full tag that will be used to prefix a message before
// calculating the tagged hash. The tag is constructed as follows:
//
// tag = "lightning"||"msg_name"||"field_name"
func MsgTag(msgName, fieldName string) []byte {
tag := []byte(MsgHashTag)
tag = append(tag, []byte(msgName)...)
return append(tag, []byte(fieldName)...)
}
// MsgHash computes the tagged hash of the given message as follows:
//
// tag = "lightning"||"msg_name"||"field_name"
// hash = sha256(sha256(tag) || sha256(tag) || msg)
func MsgHash(msgName, fieldName string, msg []byte) *chainhash.Hash {
tag := MsgTag(msgName, fieldName)
return chainhash.TaggedHash(tag, msg)
}
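// A short sketch of the tagged-hash flow (the message and field names here
// are placeholders, not the package's unexported constants):
//
//	tag := MsgTag("msg_name", "field_name")
//	// tag = "lightning" || "msg_name" || "field_name"
//
//	digest := MsgHash("msg_name", "field_name", msgBytes)
//	// digest = sha256(sha256(tag) || sha256(tag) || msgBytes)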
package netann
import (
"bytes"
"image/color"
"net"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
)
// NodeAnnModifier is a closure that makes in-place modifications to an
// lnwire.NodeAnnouncement.
type NodeAnnModifier func(*lnwire.NodeAnnouncement)
// NodeAnnSetAlias is a functional option that sets the alias of the
// given node announcement.
func NodeAnnSetAlias(alias lnwire.NodeAlias) func(*lnwire.NodeAnnouncement) {
return func(nodeAnn *lnwire.NodeAnnouncement) {
nodeAnn.Alias = alias
}
}
// NodeAnnSetAddrs is a functional option that allows updating the addresses of
// the given node announcement.
func NodeAnnSetAddrs(addrs []net.Addr) func(*lnwire.NodeAnnouncement) {
return func(nodeAnn *lnwire.NodeAnnouncement) {
nodeAnn.Addresses = addrs
}
}
// NodeAnnSetColor is a functional option that sets the color of the
// given node announcement.
func NodeAnnSetColor(newColor color.RGBA) func(*lnwire.NodeAnnouncement) {
return func(nodeAnn *lnwire.NodeAnnouncement) {
nodeAnn.RGBColor = newColor
}
}
// NodeAnnSetFeatures is a functional option that allows updating the features of
// the given node announcement.
func NodeAnnSetFeatures(features *lnwire.RawFeatureVector) func(*lnwire.NodeAnnouncement) {
return func(nodeAnn *lnwire.NodeAnnouncement) {
nodeAnn.Features = features
}
}
// NodeAnnSetTimestamp is a functional option that sets the timestamp of the
// announcement to the current time, or increments it if the timestamp is
// already in the future.
func NodeAnnSetTimestamp(nodeAnn *lnwire.NodeAnnouncement) {
newTimestamp := uint32(time.Now().Unix())
if newTimestamp <= nodeAnn.Timestamp {
// Increment the prior value to ensure the timestamp
// monotonically increases, otherwise the announcement won't
// propagate.
newTimestamp = nodeAnn.Timestamp + 1
}
nodeAnn.Timestamp = newTimestamp
}
// SignNodeAnnouncement signs the lnwire.NodeAnnouncement provided, which
// should be the most recent, valid update, otherwise the timestamp may not
// monotonically increase from the prior.
func SignNodeAnnouncement(signer lnwallet.MessageSigner,
keyLoc keychain.KeyLocator, nodeAnn *lnwire.NodeAnnouncement) error {
// Create the DER-encoded ECDSA signature over the message digest.
sig, err := SignAnnouncement(signer, keyLoc, nodeAnn)
if err != nil {
return err
}
// Parse the DER-encoded signature into a fixed-size 64-byte array.
nodeAnn.Signature, err = lnwire.NewSigFromSignature(sig)
return err
}
// ValidateNodeAnn validates the node announcement by ensuring that the
// attached signature is indeed a signature of the node announcement under the
// specified node public key.
func ValidateNodeAnn(a *lnwire.NodeAnnouncement) error {
// Reconstruct the data of the announcement which should be covered by
// the signature so we can verify it shortly below.
data, err := a.DataToSign()
if err != nil {
return err
}
nodeSig, err := a.Signature.ToSignature()
if err != nil {
return err
}
nodeKey, err := btcec.ParsePubKey(a.NodeID[:])
if err != nil {
return err
}
// Finally ensure that the passed signature is valid, if not we'll
// return an error so this node announcement can be rejected.
dataHash := chainhash.DoubleHashB(data)
if !nodeSig.Verify(dataHash, nodeKey) {
var msgBuf bytes.Buffer
if _, err := lnwire.WriteMessage(&msgBuf, a, 0); err != nil {
return err
}
return errors.Errorf("signature on NodeAnnouncement(%x) is "+
"invalid: %x", nodeKey.SerializeCompressed(),
msgBuf.Bytes())
}
return nil
}
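// Taken together, a refresh-and-verify pass over our own announcement might
// look like the following sketch (signer, keyLoc, nodeAnn and alias are
// assumed to be held by the caller):
//
//	NodeAnnSetAlias(alias)(nodeAnn)
//	NodeAnnSetTimestamp(nodeAnn)
//	if err := SignNodeAnnouncement(signer, keyLoc, nodeAnn); err != nil {
//		return err
//	}
//	if err := ValidateNodeAnn(nodeAnn); err != nil {
//		return err
//	}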
package netann
import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
)
// NodeSigner is an implementation of the MessageSigner interface backed by the
// identity private key of the running lnd node.
type NodeSigner struct {
keySigner keychain.SingleKeyMessageSigner
}
// NewNodeSigner creates a new instance of the NodeSigner backed by the target
// private key.
func NewNodeSigner(keySigner keychain.SingleKeyMessageSigner) *NodeSigner {
return &NodeSigner{
keySigner: keySigner,
}
}
// SignMessage signs a double-sha256 digest of the passed msg under the
// resident node's private key described in the key locator. If the target key
// locator is _not_ the node's private key, then an error will be returned.
func (n *NodeSigner) SignMessage(keyLoc keychain.KeyLocator,
msg []byte, doubleHash bool) (*ecdsa.Signature, error) {
// If this isn't our identity public key, then we'll exit early with an
// error as we can't sign with this key.
if keyLoc != n.keySigner.KeyLocator() {
return nil, fmt.Errorf("unknown public key locator")
}
// Otherwise, we'll sign the double-sha256 of the target message.
sig, err := n.keySigner.SignMessage(msg, doubleHash)
if err != nil {
return nil, fmt.Errorf("can't sign the message: %w", err)
}
return sig, nil
}
// SignMessageCompact signs a single or double sha256 digest of the msg
// parameter under the resident node's private key. The returned signature is a
// pubkey-recoverable signature.
func (n *NodeSigner) SignMessageCompact(msg []byte, doubleHash bool) ([]byte,
error) {
return n.keySigner.SignMessageCompact(msg, doubleHash)
}
// A compile time check to ensure that NodeSigner implements the MessageSigner
// interface.
var _ lnwallet.MessageSigner = (*NodeSigner)(nil)
package netann
import (
"fmt"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
)
// SignAnnouncement signs any type of gossip message that is announced on the
// network.
func SignAnnouncement(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator,
msg lnwire.Message) (input.Signature, error) {
var (
data []byte
err error
)
switch m := msg.(type) {
case *lnwire.ChannelAnnouncement1:
data, err = m.DataToSign()
case *lnwire.ChannelUpdate1:
data, err = m.DataToSign()
case *lnwire.NodeAnnouncement:
data, err = m.DataToSign()
default:
return nil, fmt.Errorf("can't sign %T message", m)
}
if err != nil {
return nil, fmt.Errorf("unable to get data to sign: %w", err)
}
return signer.SignMessage(keyLoc, data, true)
}
package record
import (
"fmt"
"io"
"github.com/lightningnetwork/lnd/tlv"
)
// AMPOnionType is the type used in the onion to reference the AMP fields:
// root_share, set_id, and child_index.
const AMPOnionType tlv.Type = 14
// AMP is a record that encodes the fields necessary for atomic multi-path
// payments.
type AMP struct {
rootShare [32]byte
setID [32]byte
childIndex uint32
}
// MaxAmpPayLoadSize is an AMP record which, when serialized to a tlv record,
// uses the maximum payload size. The `childIndex` is created randomly and is
// a truncated uint32 type, so we make sure to use an index which will be
// encoded in the full 4 bytes.
var MaxAmpPayLoadSize = AMP{
rootShare: [32]byte{},
setID: [32]byte{},
childIndex: 0x80000000,
}
// NewAMP generates a new AMP record with the given root_share, set_id, and
// child_index.
func NewAMP(rootShare, setID [32]byte, childIndex uint32) *AMP {
return &AMP{
rootShare: rootShare,
setID: setID,
childIndex: childIndex,
}
}
// RootShare returns the root share contained in the AMP record.
func (a *AMP) RootShare() [32]byte {
return a.rootShare
}
// SetID returns the set id contained in the AMP record.
func (a *AMP) SetID() [32]byte {
return a.setID
}
// ChildIndex returns the child index contained in the AMP record.
func (a *AMP) ChildIndex() uint32 {
return a.childIndex
}
// AMPEncoder writes the AMP record to the provided io.Writer.
func AMPEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*AMP); ok {
if err := tlv.EBytes32(w, &v.rootShare, buf); err != nil {
return err
}
if err := tlv.EBytes32(w, &v.setID, buf); err != nil {
return err
}
return tlv.ETUint32T(w, v.childIndex, buf)
}
return tlv.NewTypeForEncodingErr(val, "AMP")
}
const (
// minAMPLength is the minimum length of a serialized AMP TLV record,
// which occurs when the truncated encoding of child_index takes 0
// bytes, leaving only the root_share and set_id.
minAMPLength = 64
// maxAMPLength is the maximum length of a serialized AMP TLV record,
// which occurs when the truncated encoding of a child_index takes 4
// bytes.
maxAMPLength = 68
)
// AMPDecoder reads the AMP record from the provided io.Reader.
func AMPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*AMP); ok && minAMPLength <= l && l <= maxAMPLength {
if err := tlv.DBytes32(r, &v.rootShare, buf, 32); err != nil {
return err
}
if err := tlv.DBytes32(r, &v.setID, buf, 32); err != nil {
return err
}
return tlv.DTUint32(r, &v.childIndex, buf, l-minAMPLength)
}
return tlv.NewTypeForDecodingErr(val, "AMP", l, maxAMPLength)
}
// Record returns a tlv.Record that can be used to encode or decode this record.
func (a *AMP) Record() tlv.Record {
return tlv.MakeDynamicRecord(
AMPOnionType, a, a.PayloadSize, AMPEncoder, AMPDecoder,
)
}
// PayloadSize returns the size this record takes up in encoded form.
func (a *AMP) PayloadSize() uint64 {
return 32 + 32 + tlv.SizeTUint32(a.childIndex)
}
// String returns a human-readable description of the amp payload fields.
func (a *AMP) String() string {
if a == nil {
return "<nil>"
}
return fmt.Sprintf("root_share=%x set_id=%x child_index=%d",
a.rootShare, a.setID, a.childIndex)
}
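// An encode/decode round trip for the AMP record using the tlv stream API,
// sketched for illustration (error handling elided; rootShare and setID are
// assumed 32-byte values):
//
//	amp := NewAMP(rootShare, setID, 1)
//	stream, _ := tlv.NewStream(amp.Record())
//
//	var b bytes.Buffer
//	_ = stream.Encode(&b)
//
//	var decoded AMP
//	decStream, _ := tlv.NewStream(decoded.Record())
//	_ = decStream.Decode(&b)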
package record
import (
"bytes"
"encoding/binary"
"io"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
// AverageDummyHopPayloadSize is the size of a standard blinded path dummy hop
// payload. In most cases, this is larger than the other payload types and so
// to make sure that a sender cannot use this fact to know if a dummy hop is
// present or not, we'll make sure to always pad all payloads to at least this
// size.
const AverageDummyHopPayloadSize = 51
// BlindedRouteData contains the information that is included in a blinded
// route encrypted data blob that is created by the recipient to provide
// forwarding information.
type BlindedRouteData struct {
// Padding is an optional set of bytes that a recipient can use to pad
// the data so that the encrypted recipient data blobs are all the same
// length.
Padding tlv.OptionalRecordT[tlv.TlvType1, []byte]
// ShortChannelID is the channel ID of the next hop.
ShortChannelID tlv.OptionalRecordT[tlv.TlvType2, lnwire.ShortChannelID]
// NextNodeID is the node ID of the next node on the path. In the
// context of blinded path payments, this is used to indicate the
// presence of dummy hops that need to be peeled from the onion.
NextNodeID tlv.OptionalRecordT[tlv.TlvType4, *btcec.PublicKey]
// PathID is a secret set of bytes that the blinded path creator will
// set so that they can check the value on decryption to ensure that the
// path they created was used for the intended purpose.
PathID tlv.OptionalRecordT[tlv.TlvType6, []byte]
// NextBlindingOverride is a blinding point that should be switched
// in for the next hop. This is used to combine two blinded paths into
// one (which primarily is used in onion messaging, but in theory
// could be used for payments as well).
NextBlindingOverride tlv.OptionalRecordT[tlv.TlvType8, *btcec.PublicKey]
// RelayInfo provides the relay parameters for the hop.
RelayInfo tlv.OptionalRecordT[tlv.TlvType10, PaymentRelayInfo]
// Constraints provides the payment relay constraints for the hop.
Constraints tlv.OptionalRecordT[tlv.TlvType12, PaymentConstraints]
// Features is the set of features the payment requires.
Features tlv.OptionalRecordT[tlv.TlvType14, lnwire.FeatureVector]
}
// NewNonFinalBlindedRouteData creates the data that's provided for hops within
// a blinded route.
func NewNonFinalBlindedRouteData(chanID lnwire.ShortChannelID,
blindingOverride *btcec.PublicKey, relayInfo PaymentRelayInfo,
constraints *PaymentConstraints,
features *lnwire.FeatureVector) *BlindedRouteData {
info := &BlindedRouteData{
ShortChannelID: tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType2](chanID),
),
RelayInfo: tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType10](relayInfo),
),
}
if blindingOverride != nil {
info.NextBlindingOverride = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType8](blindingOverride))
}
if constraints != nil {
info.Constraints = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType12](*constraints))
}
if features != nil {
info.Features = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType14](*features),
)
}
return info
}
// NewFinalHopBlindedRouteData creates the data that's provided for the final
// hop in a blinded route.
func NewFinalHopBlindedRouteData(constraints *PaymentConstraints,
pathID []byte) *BlindedRouteData {
var data BlindedRouteData
if pathID != nil {
data.PathID = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType6](pathID),
)
}
if constraints != nil {
data.Constraints = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType12](*constraints))
}
return &data
}
// NewDummyHopRouteData creates the data that's provided for any hop preceding
// a dummy hop. The presence of such a payload indicates to the reader that
// they are the intended recipient and should peel the remainder of the onion.
func NewDummyHopRouteData(ourPubKey *btcec.PublicKey,
relayInfo PaymentRelayInfo,
constraints PaymentConstraints) *BlindedRouteData {
return &BlindedRouteData{
NextNodeID: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType4](ourPubKey),
),
RelayInfo: tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType10](relayInfo),
),
Constraints: tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType12](constraints),
),
}
}
// DecodeBlindedRouteData decodes the data provided within a blinded route.
func DecodeBlindedRouteData(r io.Reader) (*BlindedRouteData, error) {
var (
d BlindedRouteData
padding = d.Padding.Zero()
scid = d.ShortChannelID.Zero()
nextNodeID = d.NextNodeID.Zero()
pathID = d.PathID.Zero()
blindingOverride = d.NextBlindingOverride.Zero()
relayInfo = d.RelayInfo.Zero()
constraints = d.Constraints.Zero()
features = d.Features.Zero()
)
var tlvRecords lnwire.ExtraOpaqueData
if err := lnwire.ReadElements(r, &tlvRecords); err != nil {
return nil, err
}
typeMap, err := tlvRecords.ExtractRecords(
&padding, &scid, &nextNodeID, &pathID, &blindingOverride,
&relayInfo, &constraints, &features,
)
if err != nil {
return nil, err
}
val, ok := typeMap[d.Padding.TlvType()]
if ok && val == nil {
d.Padding = tlv.SomeRecordT(padding)
}
if val, ok := typeMap[d.ShortChannelID.TlvType()]; ok && val == nil {
d.ShortChannelID = tlv.SomeRecordT(scid)
}
if val, ok := typeMap[d.NextNodeID.TlvType()]; ok && val == nil {
d.NextNodeID = tlv.SomeRecordT(nextNodeID)
}
if val, ok := typeMap[d.PathID.TlvType()]; ok && val == nil {
d.PathID = tlv.SomeRecordT(pathID)
}
val, ok = typeMap[d.NextBlindingOverride.TlvType()]
if ok && val == nil {
d.NextBlindingOverride = tlv.SomeRecordT(blindingOverride)
}
if val, ok := typeMap[d.RelayInfo.TlvType()]; ok && val == nil {
d.RelayInfo = tlv.SomeRecordT(relayInfo)
}
if val, ok := typeMap[d.Constraints.TlvType()]; ok && val == nil {
d.Constraints = tlv.SomeRecordT(constraints)
}
if val, ok := typeMap[d.Features.TlvType()]; ok && val == nil {
d.Features = tlv.SomeRecordT(features)
}
return &d, nil
}
// EncodeBlindedRouteData encodes the blinded route data provided.
func EncodeBlindedRouteData(data *BlindedRouteData) ([]byte, error) {
var (
e lnwire.ExtraOpaqueData
recordProducers = make([]tlv.RecordProducer, 0, 5)
)
data.Padding.WhenSome(func(p tlv.RecordT[tlv.TlvType1, []byte]) {
recordProducers = append(recordProducers, &p)
})
data.ShortChannelID.WhenSome(func(scid tlv.RecordT[tlv.TlvType2,
lnwire.ShortChannelID]) {
recordProducers = append(recordProducers, &scid)
})
data.NextNodeID.WhenSome(func(f tlv.RecordT[tlv.TlvType4,
*btcec.PublicKey]) {
recordProducers = append(recordProducers, &f)
})
data.PathID.WhenSome(func(pathID tlv.RecordT[tlv.TlvType6, []byte]) {
recordProducers = append(recordProducers, &pathID)
})
data.NextBlindingOverride.WhenSome(func(pk tlv.RecordT[tlv.TlvType8,
*btcec.PublicKey]) {
recordProducers = append(recordProducers, &pk)
})
data.RelayInfo.WhenSome(func(r tlv.RecordT[tlv.TlvType10,
PaymentRelayInfo]) {
recordProducers = append(recordProducers, &r)
})
data.Constraints.WhenSome(func(cs tlv.RecordT[tlv.TlvType12,
PaymentConstraints]) {
recordProducers = append(recordProducers, &cs)
})
data.Features.WhenSome(func(f tlv.RecordT[tlv.TlvType14,
lnwire.FeatureVector]) {
recordProducers = append(recordProducers, &f)
})
if err := e.PackRecords(recordProducers...); err != nil {
return nil, err
}
return e[:], nil
}
// PadBy adds "n" padding bytes to the BlindedRouteData using the Padding field.
// Callers should be aware that the total payload size will change by more than
// "n" since the "n" bytes will be prefixed by BigSize type and length fields.
// Callers may need to call PadBy iteratively until each encrypted data packet
// is the same size; each call therefore overwrites any existing Padding
// record.
// Note that calling PadBy with an n value of 0 will still result in a zero
// length TLV entry being added.
func (b *BlindedRouteData) PadBy(n int) {
b.Padding = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType1](make([]byte, n)),
)
}
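// One possible equalisation scheme built on PadBy, sketched here for
// illustration (it assumes the loop converges once every encoding reaches the
// largest observed size; pads tracks the cumulative padding per blob since
// PadBy overwrites rather than appends):
//
//	pads := make([]int, len(datas))
//	for {
//		encs := make([][]byte, len(datas))
//		maxLen := 0
//		for i, d := range datas {
//			enc, err := EncodeBlindedRouteData(d)
//			if err != nil {
//				return err
//			}
//			encs[i] = enc
//			if len(enc) > maxLen {
//				maxLen = len(enc)
//			}
//		}
//		done := true
//		for i, d := range datas {
//			if deficit := maxLen - len(encs[i]); deficit > 0 {
//				pads[i] += deficit
//				d.PadBy(pads[i])
//				done = false
//			}
//		}
//		if done {
//			break
//		}
//	}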
// PaymentRelayInfo describes the relay policy for a blinded path.
type PaymentRelayInfo struct {
// CltvExpiryDelta is the expiry delta for the payment.
CltvExpiryDelta uint16
// FeeRate is the fee rate that will be charged per millionth of a
// satoshi.
FeeRate uint32
// BaseFee is the per-htlc fee charged in milli-satoshis.
BaseFee lnwire.MilliSatoshi
}
// Record creates a tlv.Record that encodes the payment relay (type 10) record
// for an encrypted blob payload.
func (i *PaymentRelayInfo) Record() tlv.Record {
return tlv.MakeDynamicRecord(
10, &i, func() uint64 {
// uint16 + uint32 + tuint32
return 2 + 4 + tlv.SizeTUint32(uint32(i.BaseFee))
}, encodePaymentRelay, decodePaymentRelay,
)
}
func encodePaymentRelay(w io.Writer, val interface{}, buf *[8]byte) error {
if t, ok := val.(**PaymentRelayInfo); ok {
relayInfo := *t
// Just write our first 6 bytes directly.
binary.BigEndian.PutUint16(buf[:2], relayInfo.CltvExpiryDelta)
binary.BigEndian.PutUint32(buf[2:6], relayInfo.FeeRate)
if _, err := w.Write(buf[0:6]); err != nil {
return err
}
baseFee := uint32(relayInfo.BaseFee)
// We can safely reuse buf here because we overwrite its
// contents.
return tlv.ETUint32(w, &baseFee, buf)
}
return tlv.NewTypeForEncodingErr(val, "**hop.PaymentRelayInfo")
}
func decodePaymentRelay(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if t, ok := val.(**PaymentRelayInfo); ok && l <= 10 {
scratch := make([]byte, l)
n, err := io.ReadFull(r, scratch)
if err != nil {
return err
}
// We expect at least 6 bytes, because we have 2 bytes for
// cltv delta and 4 bytes for fee rate.
if n < 6 {
return tlv.NewTypeForDecodingErr(val,
"*hop.paymentRelayInfo", uint64(n), 6)
}
relayInfo := *t
relayInfo.CltvExpiryDelta = binary.BigEndian.Uint16(
scratch[0:2],
)
relayInfo.FeeRate = binary.BigEndian.Uint32(scratch[2:6])
// To be able to re-use the DTUint32 function we create a
// buffer with just the bytes holding the variable length u32.
// If the base fee is zero, this will be an empty buffer, which
// is okay.
b := bytes.NewBuffer(scratch[6:])
var baseFee uint32
err = tlv.DTUint32(b, &baseFee, buf, l-6)
if err != nil {
return err
}
relayInfo.BaseFee = lnwire.MilliSatoshi(baseFee)
return nil
}
return tlv.NewTypeForDecodingErr(val, "*hop.paymentRelayInfo", l, 10)
}
// PaymentConstraints is a set of restrictions on a payment.
type PaymentConstraints struct {
// MaxCltvExpiry is the maximum expiry height for the payment.
MaxCltvExpiry uint32
// HtlcMinimumMsat is the minimum htlc size for the payment.
HtlcMinimumMsat lnwire.MilliSatoshi
}
func (p *PaymentConstraints) Record() tlv.Record {
return tlv.MakeDynamicRecord(
12, &p, func() uint64 {
// uint32 + tuint64.
return 4 + tlv.SizeTUint64(uint64(
p.HtlcMinimumMsat,
))
},
encodePaymentConstraints, decodePaymentConstraints,
)
}
func encodePaymentConstraints(w io.Writer, val interface{},
buf *[8]byte) error {
if c, ok := val.(**PaymentConstraints); ok {
constraints := *c
binary.BigEndian.PutUint32(buf[:4], constraints.MaxCltvExpiry)
if _, err := w.Write(buf[:4]); err != nil {
return err
}
// We can safely re-use buf here because we overwrite its
// contents.
htlcMsat := uint64(constraints.HtlcMinimumMsat)
return tlv.ETUint64(w, &htlcMsat, buf)
}
return tlv.NewTypeForEncodingErr(val, "**PaymentConstraints")
}
func decodePaymentConstraints(r io.Reader, val interface{}, buf *[8]byte,
l uint64) error {
if c, ok := val.(**PaymentConstraints); ok && l <= 12 {
scratch := make([]byte, l)
n, err := io.ReadFull(r, scratch)
if err != nil {
return err
}
// We expect at least 4 bytes for our uint32.
if n < 4 {
return tlv.NewTypeForDecodingErr(val,
"*paymentConstraints", uint64(n), 4)
}
payConstraints := *c
payConstraints.MaxCltvExpiry = binary.BigEndian.Uint32(
scratch[:4],
)
// This could be empty if our minimum is zero, that's okay.
var (
b = bytes.NewBuffer(scratch[4:])
minHtlc uint64
)
err = tlv.DTUint64(b, &minHtlc, buf, l-4)
if err != nil {
return err
}
payConstraints.HtlcMinimumMsat = lnwire.MilliSatoshi(minHtlc)
return nil
}
return tlv.NewTypeForDecodingErr(val, "**PaymentConstraints", l, l)
}
package record
import (
"fmt"
)
const (
// CustomTypeStart is the start of the custom tlv type range as defined
// in BOLT 01.
CustomTypeStart = 65536
)
// CustomSet stores a set of custom key/value pairs.
type CustomSet map[uint64][]byte
// Validate checks that all custom records are in the custom type range.
func (c CustomSet) Validate() error {
for key := range c {
if key < CustomTypeStart {
return fmt.Errorf("no custom records with types "+
"below %v allowed", CustomTypeStart)
}
}
return nil
}
// IsKeysend checks if the custom records contain the key send type.
func (c CustomSet) IsKeysend() bool {
return c[KeySendType] != nil
}
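// For example, a keysend-style custom record set passes validation because
// its type sits at or above CustomTypeStart (KeySendType is defined elsewhere
// in this package; preimage is an assumed 32-byte value):
//
//	records := CustomSet{
//		KeySendType: preimage[:],
//	}
//	if err := records.Validate(); err != nil {
//		// A type below 65536 was used.
//	}
//	isKeysend := records.IsKeysend() // true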
package record
import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// AmtOnionType is the type used in the onion to reference the amount to
// send to the next hop.
AmtOnionType tlv.Type = 2
// LockTimeOnionType is the type used in the onion to reference the CLTV
// value that should be used for the next hop's HTLC.
LockTimeOnionType tlv.Type = 4
// NextHopOnionType is the type used in the onion to reference the ID
// of the next hop.
NextHopOnionType tlv.Type = 6
// EncryptedDataOnionType is the type used to include encrypted data
// provided by the receiver in the onion for use in blinded paths.
EncryptedDataOnionType tlv.Type = 10
// BlindingPointOnionType is the type used to include receiver provided
// ephemeral keys in the onion that are used in blinded paths.
BlindingPointOnionType tlv.Type = 12
// MetadataOnionType is the type used in the onion for the payment
// metadata.
MetadataOnionType tlv.Type = 16
// TotalAmtMsatBlindedType is the type used in the onion for the total
// amount field that is included in the final hop for blinded payments.
TotalAmtMsatBlindedType tlv.Type = 18
)
// NewAmtToFwdRecord creates a tlv.Record that encodes the amount_to_forward
// (type 2) for an onion payload.
func NewAmtToFwdRecord(amt *uint64) tlv.Record {
return tlv.MakeDynamicRecord(
AmtOnionType, amt, func() uint64 {
return tlv.SizeTUint64(*amt)
},
tlv.ETUint64, tlv.DTUint64,
)
}
// NewLockTimeRecord creates a tlv.Record that encodes the outgoing_cltv_value
// (type 4) for an onion payload.
func NewLockTimeRecord(lockTime *uint32) tlv.Record {
return tlv.MakeDynamicRecord(
LockTimeOnionType, lockTime, func() uint64 {
return tlv.SizeTUint32(*lockTime)
},
tlv.ETUint32, tlv.DTUint32,
)
}
// NewNextHopIDRecord creates a tlv.Record that encodes the short_channel_id
// (type 6) for an onion payload.
func NewNextHopIDRecord(cid *uint64) tlv.Record {
return tlv.MakePrimitiveRecord(NextHopOnionType, cid)
}
// NewEncryptedDataRecord creates a tlv.Record that encodes the encrypted_data
// (type 10) record for an onion payload.
func NewEncryptedDataRecord(data *[]byte) tlv.Record {
return tlv.MakePrimitiveRecord(EncryptedDataOnionType, data)
}
// NewBlindingPointRecord creates a tlv.Record that encodes the blinding_point
// (type 12) record for an onion payload.
func NewBlindingPointRecord(point **btcec.PublicKey) tlv.Record {
return tlv.MakePrimitiveRecord(BlindingPointOnionType, point)
}
// NewMetadataRecord creates a tlv.Record that encodes the metadata (type 16)
// for an onion payload.
func NewMetadataRecord(metadata *[]byte) tlv.Record {
return tlv.MakeDynamicRecord(
MetadataOnionType, metadata,
func() uint64 {
return uint64(len(*metadata))
},
tlv.EVarBytes, tlv.DVarBytes,
)
}
// NewTotalAmtMsatBlinded creates a tlv.Record that encodes the
// total_amount_msat field for the final hop of an onion payload within a
// blinded route.
func NewTotalAmtMsatBlinded(amt *uint64) tlv.Record {
return tlv.MakeDynamicRecord(
TotalAmtMsatBlindedType, amt, func() uint64 {
return tlv.SizeTUint64(*amt)
},
tlv.ETUint64, tlv.DTUint64,
)
}
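// The record constructors above plug straight into a tlv.Stream. A sketch of
// assembling (part of) an intermediate hop's onion payload, with illustrative
// values (records must be supplied in ascending type order):
//
//	amt := uint64(1_000_000)
//	lockTime := uint32(800_000)
//	nextHop := uint64(0x0102030405060708)
//
//	stream, err := tlv.NewStream(
//		NewAmtToFwdRecord(&amt),
//		NewLockTimeRecord(&lockTime),
//		NewNextHopIDRecord(&nextHop),
//	)
//	if err != nil {
//		return err
//	}
//	var b bytes.Buffer
//	if err := stream.Encode(&b); err != nil {
//		return err
//	}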
package record
import (
"fmt"
"io"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
// MPPOnionType is the type used in the onion to reference the MPP fields:
// total_amt and payment_addr.
const MPPOnionType tlv.Type = 8
// MPP is a record that encodes the fields necessary for multi-path payments.
type MPP struct {
// paymentAddr is a random, receiver-generated value used to avoid
// collisions with concurrent payers.
paymentAddr [32]byte
// totalMsat is the total value of the payment, potentially spread
// across more than one HTLC.
totalMsat lnwire.MilliSatoshi
}
// NewMPP generates a new MPP record with the given total and payment address.
func NewMPP(total lnwire.MilliSatoshi, addr [32]byte) *MPP {
return &MPP{
paymentAddr: addr,
totalMsat: total,
}
}
// PaymentAddr returns the payment address contained in the MPP record.
func (r *MPP) PaymentAddr() [32]byte {
return r.paymentAddr
}
// TotalMsat returns the total value of an MPP payment in msats.
func (r *MPP) TotalMsat() lnwire.MilliSatoshi {
return r.totalMsat
}
// MPPEncoder writes the MPP record to the provided io.Writer.
func MPPEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*MPP); ok {
err := tlv.EBytes32(w, &v.paymentAddr, buf)
if err != nil {
return err
}
return tlv.ETUint64T(w, uint64(v.totalMsat), buf)
}
return tlv.NewTypeForEncodingErr(val, "MPP")
}
const (
// minMPPLength is the minimum length of a serialized MPP TLV record,
// which occurs when the truncated encoding of total_amt_msat takes 0
// bytes, leaving only the payment_addr.
minMPPLength = 32
// maxMPPLength is the maximum length of a serialized MPP TLV record,
// which occurs when the truncated encoding of total_amt_msat takes 8
// bytes.
maxMPPLength = 40
)
// MPPDecoder reads the MPP record from the provided io.Reader.
func MPPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*MPP); ok && minMPPLength <= l && l <= maxMPPLength {
if err := tlv.DBytes32(r, &v.paymentAddr, buf, 32); err != nil {
return err
}
var total uint64
if err := tlv.DTUint64(r, &total, buf, l-32); err != nil {
return err
}
v.totalMsat = lnwire.MilliSatoshi(total)
return nil
}
return tlv.NewTypeForDecodingErr(val, "MPP", l, maxMPPLength)
}
// Record returns a tlv.Record that can be used to encode or decode this record.
func (r *MPP) Record() tlv.Record {
// Fixed-size, 32 byte payment address followed by truncated 64-bit
// total msat.
size := func() uint64 {
return 32 + tlv.SizeTUint64(uint64(r.totalMsat))
}
return tlv.MakeDynamicRecord(
MPPOnionType, r, size, MPPEncoder, MPPDecoder,
)
}
// PayloadSize returns the size this record takes up in encoded form.
func (r *MPP) PayloadSize() uint64 {
return 32 + tlv.SizeTUint64(uint64(r.totalMsat))
}
// String returns a human-readable representation of the mpp payload field.
func (r *MPP) String() string {
if r == nil {
return "<nil>"
}
return fmt.Sprintf("total=%v, addr=%x", r.totalMsat, r.paymentAddr)
}
package routing
import (
"errors"
"fmt"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
var (
// ErrNoPayLoadSizeFunc is returned when no payload size function is
// defined.
ErrNoPayLoadSizeFunc = errors.New("no payloadSizeFunc defined for " +
"additional edge")
)
// AdditionalEdge is an interface which specifies additional edges which can
// be appended to an existing route. Compared to normal edges of a route they
// provide an explicit payload size function and are introduced because blinded
// paths differ in their payload structure.
type AdditionalEdge interface {
// IntermediatePayloadSize returns the size of the payload for the
// additional edge when it is an intermediate hop in a route, NOT the
// final hop.
IntermediatePayloadSize(amount lnwire.MilliSatoshi, expiry uint32,
channelID uint64) uint64
// EdgePolicy returns the policy of the additional edge.
EdgePolicy() *models.CachedEdgePolicy
// BlindedPayment returns the BlindedPayment that this additional edge
// info was derived from. It will return nil if this edge was not
// derived from a blinded route.
BlindedPayment() *BlindedPayment
}
// PayloadSizeFunc defines the interface for the payload size function.
type PayloadSizeFunc func(amount lnwire.MilliSatoshi, expiry uint32,
channelID uint64) uint64
// PrivateEdge implements the AdditionalEdge interface. As the name implies it
// is used for private route hints that the receiver adds for example to an
// invoice.
type PrivateEdge struct {
policy *models.CachedEdgePolicy
}
// EdgePolicy returns the policy of the PrivateEdge.
func (p *PrivateEdge) EdgePolicy() *models.CachedEdgePolicy {
return p.policy
}
// IntermediatePayloadSize returns the sphinx payload size defined in BOLT04 if
// this edge were to be included in a route.
func (p *PrivateEdge) IntermediatePayloadSize(amount lnwire.MilliSatoshi,
expiry uint32, channelID uint64) uint64 {
hop := route.Hop{
AmtToForward: amount,
OutgoingTimeLock: expiry,
}
return hop.PayloadSize(channelID)
}
// BlindedPayment is a no-op for a PrivateEdge since it is not associated with
// a blinded payment. This will thus return nil.
func (p *PrivateEdge) BlindedPayment() *BlindedPayment {
return nil
}
// BlindedEdge implements the AdditionalEdge interface. Blinded hops are viewed
// as additional edges because they are appended at the end of a normal route.
type BlindedEdge struct {
policy *models.CachedEdgePolicy
// blindedPayment is the BlindedPayment that this blinded edge was
// derived from.
blindedPayment *BlindedPayment
// hopIndex is the index of the hop in the blinded payment path that
// this edge is associated with.
hopIndex int
}
// NewBlindedEdge constructs a new BlindedEdge which packages the policy info
// for a specific hop within the given blinded payment path. The hop index
// should correspond to the hop within the blinded payment that this edge is
// associated with.
func NewBlindedEdge(policy *models.CachedEdgePolicy, payment *BlindedPayment,
hopIndex int) (*BlindedEdge, error) {
if payment == nil {
return nil, fmt.Errorf("blinded payment cannot be nil for " +
"blinded edge")
}
if hopIndex < 0 || hopIndex >= len(payment.BlindedPath.BlindedHops) {
return nil, fmt.Errorf("the hop index %d is outside the "+
"valid range between 0 and %d", hopIndex,
len(payment.BlindedPath.BlindedHops)-1)
}
return &BlindedEdge{
policy: policy,
hopIndex: hopIndex,
blindedPayment: payment,
}, nil
}
// EdgePolicy returns the policy of the BlindedEdge.
func (b *BlindedEdge) EdgePolicy() *models.CachedEdgePolicy {
return b.policy
}
// IntermediatePayloadSize returns the sphinx payload size defined in BOLT04 if
// this edge were to be included in a route.
func (b *BlindedEdge) IntermediatePayloadSize(_ lnwire.MilliSatoshi, _ uint32,
_ uint64) uint64 {
blindedPath := b.blindedPayment.BlindedPath
hop := route.Hop{
BlindingPoint: blindedPath.BlindingPoint,
EncryptedData: blindedPath.BlindedHops[b.hopIndex].CipherText,
}
// For blinded paths the next chanID is in the encrypted data tlv.
return hop.PayloadSize(0)
}
// BlindedPayment returns the blinded payment that this edge is associated
// with.
func (b *BlindedEdge) BlindedPayment() *BlindedPayment {
return b.blindedPayment
}
// Compile-time constraints to ensure the PrivateEdge and the BlindedEdge
// implement the AdditionalEdge interface.
var _ AdditionalEdge = (*PrivateEdge)(nil)
var _ AdditionalEdge = (*BlindedEdge)(nil)
// defaultHopPayloadSize is the default payload size of a normal (not-blinded)
// hop in the route.
func defaultHopPayloadSize(amount lnwire.MilliSatoshi, expiry uint32,
channelID uint64) uint64 {
// The payload size of a cleartext intermediate hop is equal to the
// payload size of a private edge therefore we reuse its size function.
edge := PrivateEdge{}
return edge.IntermediatePayloadSize(amount, expiry, channelID)
}
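// As a sketch (policy and payment are assumed to be held by the caller):
// constructing a BlindedEdge for the first hop of a blinded payment and
// querying its payload size. The amount, expiry and channel ID arguments are
// ignored for blinded hops since the forwarding data lives in the encrypted
// cipher text.
//
//	edge, err := NewBlindedEdge(policy, payment, 0)
//	if err != nil {
//		return err
//	}
//	size := edge.IntermediatePayloadSize(0, 0, 0)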
package routing
import (
"fmt"
"github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)
// bandwidthHints provides hints about the currently available balance in our
// channels.
type bandwidthHints interface {
// availableChanBandwidth returns the total available bandwidth for a
// channel and a bool indicating whether the channel hint was found.
// The amount parameter is used to validate the outgoing htlc amount
// that we wish to add to the channel against its flow restrictions. If
// a zero amount is provided, the minimum htlc value for the channel
// will be used. If the channel is unavailable, a zero amount is
// returned.
availableChanBandwidth(channelID uint64,
amount lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool)
// firstHopCustomBlob returns the custom blob for the first hop of the
// payment, if available.
firstHopCustomBlob() fn.Option[tlv.Blob]
}
// getLinkQuery is the function signature used to lookup a link.
type getLinkQuery func(lnwire.ShortChannelID) (
htlcswitch.ChannelLink, error)
// bandwidthManager is an implementation of the bandwidthHints interface which
// uses the link lookup provided to query the link for our latest local channel
// balances.
type bandwidthManager struct {
getLink getLinkQuery
localChans map[lnwire.ShortChannelID]struct{}
firstHopBlob fn.Option[tlv.Blob]
trafficShaper fn.Option[htlcswitch.AuxTrafficShaper]
}
// newBandwidthManager creates a bandwidth manager for the source node provided
// which is used to obtain hints from the lower layer w.r.t the available
// bandwidth of edges on the network. Currently, we'll only obtain bandwidth
// hints for the edges we directly have open ourselves. Obtaining these hints
// allows us to reduce the number of extraneous attempts as we can skip channels
// that are inactive, or just don't have enough bandwidth to carry the payment.
func newBandwidthManager(graph Graph, sourceNode route.Vertex,
linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob],
ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager,
error) {
manager := &bandwidthManager{
getLink: linkQuery,
localChans: make(map[lnwire.ShortChannelID]struct{}),
firstHopBlob: firstHopBlob,
trafficShaper: ts,
}
// First, we'll collect the set of outbound edges from the target
// source node and add them to our bandwidth manager's map of channels.
err := graph.ForEachNodeDirectedChannel(sourceNode,
func(channel *graphdb.DirectedChannel) error {
shortID := lnwire.NewShortChanIDFromInt(
channel.ChannelID,
)
manager.localChans[shortID] = struct{}{}
return nil
})
if err != nil {
return nil, err
}
return manager, nil
}
// getBandwidth queries the current state of a link and gets its currently
// available bandwidth. Note that this function assumes that the channel being
// queried is one of our local channels, so any failure to retrieve the link
// is interpreted as the link being offline.
func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID,
amount lnwire.MilliSatoshi) lnwire.MilliSatoshi {
link, err := b.getLink(cid)
if err != nil {
// If the link isn't online, then we'll report that it has
// zero bandwidth.
log.Warnf("ShortChannelID=%v: link not found: %v", cid, err)
return 0
}
// If the link is found within the switch, but it isn't yet eligible
// to forward any HTLCs, then we'll treat it as if it isn't online in
// the first place.
if !link.EligibleToForward() {
log.Warnf("ShortChannelID=%v: not eligible to forward", cid)
return 0
}
// bandwidthResult is an inline type that we'll use to pass the
// bandwidth result from the external traffic shaper to the main logic
// below.
type bandwidthResult struct {
// bandwidth is the available bandwidth for the channel as
// reported by the external traffic shaper. If the external
// traffic shaper is not handling the channel, this value will
// be fn.None.
bandwidth fn.Option[lnwire.MilliSatoshi]
// htlcAmount is the amount we're going to use to check if we
// can add another HTLC to the channel. If the external traffic
// shaper is handling the channel, we'll use 0 to just sanity
// check the number of HTLCs on the channel, since we don't know
// the actual HTLC amount that will be sent.
htlcAmount fn.Option[lnwire.MilliSatoshi]
}
var (
// We will pass the link bandwidth to the external traffic
// shaper. This is the current best estimate for the available
// bandwidth for the link.
linkBandwidth = link.Bandwidth()
bandwidthErr = func(err error) fn.Result[bandwidthResult] {
return fn.Err[bandwidthResult](err)
}
)
result, err := fn.MapOptionZ(
b.trafficShaper,
func(s htlcswitch.AuxTrafficShaper) fn.Result[bandwidthResult] {
auxBandwidth, err := link.AuxBandwidth(
amount, cid, b.firstHopBlob, s,
).Unpack()
if err != nil {
return bandwidthErr(fmt.Errorf("failed to get "+
"auxiliary bandwidth: %w", err))
}
// If the external traffic shaper is not handling the
// channel, we'll just return the original bandwidth and
// no custom amount.
if !auxBandwidth.IsHandled {
return fn.Ok(bandwidthResult{})
}
// We don't know the actual HTLC amount that will be
// sent using the custom channel. But we'll still want
// to make sure we can add another HTLC, using the
// MayAddOutgoingHtlc method below. Passing 0 into that
// method will use the minimum HTLC value for the
// channel, which is okay to just check we don't exceed
// the max number of HTLCs on the channel. A proper
// balance check is done elsewhere.
return fn.Ok(bandwidthResult{
bandwidth: auxBandwidth.Bandwidth,
htlcAmount: fn.Some[lnwire.MilliSatoshi](0),
})
},
).Unpack()
if err != nil {
log.Errorf("ShortChannelID=%v: failed to get bandwidth from "+
"external traffic shaper: %v", cid, err)
return 0
}
htlcAmount := result.htlcAmount.UnwrapOr(amount)
// If our link isn't currently in a state where it can add another
// outgoing htlc, treat the link as unusable.
if err := link.MayAddOutgoingHtlc(htlcAmount); err != nil {
log.Warnf("ShortChannelID=%v: cannot add outgoing "+
"htlc with amount %v: %v", cid, htlcAmount, err)
return 0
}
// If the external traffic shaper determined the bandwidth, we'll return
// that value, even if it is zero (which would mean no bandwidth is
// available on that channel).
reportedBandwidth := result.bandwidth.UnwrapOr(linkBandwidth)
return reportedBandwidth
}
// availableChanBandwidth returns the total available bandwidth for a channel
// and a bool indicating whether the channel hint was found. If the channel is
// unavailable, a zero amount is returned.
func (b *bandwidthManager) availableChanBandwidth(channelID uint64,
amount lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool) {
shortID := lnwire.NewShortChanIDFromInt(channelID)
_, ok := b.localChans[shortID]
if !ok {
return 0, false
}
return b.getBandwidth(shortID, amount), true
}
// firstHopCustomBlob returns the custom blob for the first hop of the payment,
// if available.
func (b *bandwidthManager) firstHopCustomBlob() fn.Option[tlv.Blob] {
return b.firstHopBlob
}
package routing
import (
"bytes"
"errors"
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// BlindedPathNUMSHex is the hex encoded version of the blinded path target
// NUMs key (in compressed format) which has no known private key.
// This was generated using the following script:
// https://github.com/lightninglabs/lightning-node-connect/tree/master/
// mailbox/numsgen, with the seed phrase "Lightning Blinded Path".
const BlindedPathNUMSHex = "02667a98ef82ecb522f803b17a74f14508a48b25258f9831" +
"dd6e95f5e299dfd54e"
var (
// ErrNoBlindedPath is returned when the blinded path in a blinded
// payment is missing.
ErrNoBlindedPath = errors.New("blinded path required")
// ErrInsufficientBlindedHops is returned when a blinded path does
// not have enough blinded hops.
ErrInsufficientBlindedHops = errors.New("blinded path requires " +
"at least one hop")
// ErrHTLCRestrictions is returned when a blinded path has invalid
// HTLC maximum and minimum values.
ErrHTLCRestrictions = errors.New("invalid htlc minimum and maximum")
// BlindedPathNUMSKey is a NUMS key (nothing up my sleeves number) that
// has no known private key.
BlindedPathNUMSKey = input.MustParsePubKey(BlindedPathNUMSHex)
// CompressedBlindedPathNUMSKey is the compressed version of the
// BlindedPathNUMSKey.
CompressedBlindedPathNUMSKey = BlindedPathNUMSKey.SerializeCompressed()
)
// BlindedPaymentPathSet groups the data we need to handle sending to a set of
// blinded paths provided by the recipient of a payment.
//
// NOTE: for now this only holds a single BlindedPayment. By the end of the PR
// series, it will handle multiple paths.
type BlindedPaymentPathSet struct {
// paths is the set of blinded payment paths for a single payment.
// NOTE: For now this will always only have a single entry. By the end
// of this PR, it can hold multiple.
paths []*BlindedPayment
// targetPubKey is the ephemeral node pub key that we will inject into
// each path as the last hop. This is only for the sake of path finding.
// Once the path has been found, the original destination pub key is
// used again. In the edge case where there is only a single hop in the
// path (the introduction node is the destination node), then this will
// just be the introduction node's real public key.
targetPubKey *btcec.PublicKey
// features is the set of relay features available for the payment.
// This is extracted from the set of blinded payment paths. At the
// moment we require that all paths for the same payment have the
// same feature set.
features *lnwire.FeatureVector
// finalCLTV is the final hop's expiry delta of _any_ path in the set.
// For any multi-hop path, the final CLTV delta should be seen as zero
// since the final hop's final CLTV delta is accounted for in the
// accumulated path policy values. The only edge case is for when the
// final hop in the path is also the introduction node in which case
// that path's FinalCLTV must be the non-zero min CLTV of the final hop
// so that it is accounted for in path finding. For this reason, if
// we have any single path in the set with only one hop, then we throw
// away all the other paths. This should be fine to do since if there is
// a path where the intro node is also the destination node, then there
// isn't any need to try any other longer blinded path. In other words,
// if this value is non-zero, then there is only one path in this
// blinded path set and that path only has a single hop: the
// introduction node.
finalCLTV uint16
}
// NewBlindedPaymentPathSet constructs a new BlindedPaymentPathSet from a set of
// BlindedPayments. For blinded paths which have more than a single hop, a
// dummy hop via a NUMS key is appended to allow for MPP path finding via
// multiple blinded paths.
func NewBlindedPaymentPathSet(paths []*BlindedPayment) (*BlindedPaymentPathSet,
error) {
if len(paths) == 0 {
return nil, ErrNoBlindedPath
}
// For now, we assert that all the paths have the same set of features.
features := paths[0].Features
noFeatures := features == nil || features.IsEmpty()
for i := 1; i < len(paths); i++ {
noFeats := paths[i].Features == nil ||
paths[i].Features.IsEmpty()
// A mismatch in either direction (only one of the two paths
// lacking features) is an error. Checking both directions also
// guards against a nil dereference below.
if noFeatures != noFeats {
return nil, fmt.Errorf("all blinded paths must have " +
"the same set of features")
}
if noFeatures {
continue
}
if !features.RawFeatureVector.Equals(
paths[i].Features.RawFeatureVector,
) {
return nil, fmt.Errorf("all blinded paths must have " +
"the same set of features")
}
}
// Deep copy the paths to avoid mutating the original paths.
pathSet := make([]*BlindedPayment, len(paths))
for i, path := range paths {
pathSet[i] = path.deepCopy()
}
// For blinded paths we use the NUMS key as a target if the blinded
// path has more hops than just the introduction node.
targetPub := &BlindedPathNUMSKey
var finalCLTVDelta uint16
// In case the paths do NOT include a single hop route we append a
// dummy hop via a NUMS key to allow for MPP path finding via multiple
// blinded paths. A unified target is needed to use all blinded paths
// during the payment lifecycle. A dummy hop is solely added for the
// path finding process and is removed after the path is found. This
// ensures that we still populate the mission control with the correct
// data and also respect these mc entries when looking for a path.
for _, path := range pathSet {
pathLength := len(path.BlindedPath.BlindedHops)
// If any provided blinded path only has a single hop (ie, the
// destination node is also the introduction node), then we
// discard all other paths since we know the real pub key of the
// destination node. We also then set the final CLTV delta to
// the path's delta since there are no other edge hints that
// will account for it.
if pathLength == 1 {
pathSet = []*BlindedPayment{path}
finalCLTVDelta = path.CltvExpiryDelta
targetPub = path.BlindedPath.IntroductionPoint
break
}
lastHop := path.BlindedPath.BlindedHops[pathLength-1]
path.BlindedPath.BlindedHops = append(
path.BlindedPath.BlindedHops,
&sphinx.BlindedHopInfo{
BlindedNodePub: &BlindedPathNUMSKey,
// We add the last hop's cipher text so that
// the payload size of the final hop is equal
// to the real last hop.
CipherText: lastHop.CipherText,
},
)
}
return &BlindedPaymentPathSet{
paths: pathSet,
targetPubKey: targetPub,
features: features,
finalCLTV: finalCLTVDelta,
}, nil
}
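// exampleBlindedPathSetUsage is an illustrative sketch only, not part of the
// production API: it shows how a caller holding decoded blinded payments
// might construct a path set and inspect the resulting pathfinding target.
// The payments slice is assumed to have been decoded from an invoice
// elsewhere.
func exampleBlindedPathSetUsage(payments []*BlindedPayment) error {
	pathSet, err := NewBlindedPaymentPathSet(payments)
	if err != nil {
		return err
	}

	// For multi-hop paths the target is the NUMS key; for an
	// intro-node-only path set it is the introduction node itself.
	target := pathSet.TargetPubKey()
	fmt.Printf("pathfinding target: %x, final cltv delta: %d\n",
		target.SerializeCompressed(), pathSet.FinalCLTVDelta())

	return nil
}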
// TargetPubKey returns the public key to be used as the destination node's
// public key during pathfinding.
func (s *BlindedPaymentPathSet) TargetPubKey() *btcec.PublicKey {
return s.targetPubKey
}
// Features returns the set of relay features available for the payment.
func (s *BlindedPaymentPathSet) Features() *lnwire.FeatureVector {
return s.features
}
// IntroNodeOnlyPath can be called if it is expected that the path set only
// contains a single payment path which itself only has one hop. It errors if
// this is not the case.
func (s *BlindedPaymentPathSet) IntroNodeOnlyPath() (*BlindedPayment, error) {
if len(s.paths) != 1 {
return nil, fmt.Errorf("expected only a single path in the "+
"blinded payment set, got %d", len(s.paths))
}
if len(s.paths[0].BlindedPath.BlindedHops) > 1 {
return nil, fmt.Errorf("an intro node only path cannot have " +
"more than one hop")
}
return s.paths[0], nil
}
// IsIntroNode returns true if the given vertex is an introduction node for one
// of the paths in the blinded payment path set.
func (s *BlindedPaymentPathSet) IsIntroNode(source route.Vertex) bool {
for _, path := range s.paths {
introVertex := route.NewVertex(
path.BlindedPath.IntroductionPoint,
)
if source == introVertex {
return true
}
}
return false
}
// FinalCLTVDelta is the minimum CLTV delta to use for the final hop on the
// route. In most cases this will return zero since the value is accounted for
// in the path's accumulated CLTVExpiryDelta. Only in the edge case of the path
// set only including a single path which only includes an introduction node
// will this return a non-zero value.
func (s *BlindedPaymentPathSet) FinalCLTVDelta() uint16 {
return s.finalCLTV
}
// LargestLastHopPayloadPath returns the BlindedPayment in the set that has the
// largest last-hop payload. This is to be used for onion size estimation in
// path finding.
func (s *BlindedPaymentPathSet) LargestLastHopPayloadPath() (*BlindedPayment,
error) {
var (
largestPath *BlindedPayment
currentMax int
)
if len(s.paths) == 0 {
return nil, fmt.Errorf("no blinded paths in the set")
}
// We set the largest path to make sure we always return a path even
// if the cipher text is empty.
largestPath = s.paths[0]
for _, path := range s.paths {
numHops := len(path.BlindedPath.BlindedHops)
lastHop := path.BlindedPath.BlindedHops[numHops-1]
if len(lastHop.CipherText) > currentMax {
largestPath = path
currentMax = len(lastHop.CipherText)
}
}
return largestPath, nil
}
// ToRouteHints converts the blinded path payment set into a RouteHints map so
// that the blinded payment paths can be treated like route hints throughout the
// code base.
func (s *BlindedPaymentPathSet) ToRouteHints() (RouteHints, error) {
hints := make(RouteHints)
for _, path := range s.paths {
pathHints, err := path.toRouteHints()
if err != nil {
return nil, err
}
for from, edges := range pathHints {
hints[from] = append(hints[from], edges...)
}
}
if len(hints) == 0 {
return nil, nil
}
return hints, nil
}
// IsBlindedRouteNUMSTargetKey returns true if the given public key is the
// NUMS key used as a target for blinded path final hops.
func IsBlindedRouteNUMSTargetKey(pk []byte) bool {
return bytes.Equal(pk, CompressedBlindedPathNUMSKey)
}
// BlindedPayment provides the path and payment parameters required to send a
// payment along a blinded path.
type BlindedPayment struct {
// BlindedPath contains the unblinded introduction point and blinded
// hops for the blinded section of the payment.
BlindedPath *sphinx.BlindedPath
// BaseFee is the total base fee to be paid for payments made over the
// blinded path.
BaseFee uint32
// ProportionalFeeRate is the aggregated proportional fee rate for
// payments made over the blinded path.
ProportionalFeeRate uint32
// CltvExpiryDelta is the total expiry delta for the blinded path. This
// field includes the CLTV for the blinded hops *and* the final cltv
// delta for the receiver.
CltvExpiryDelta uint16
// HtlcMinimum is the highest HTLC minimum supported along the blinded
// path (while some hops may have lower values, we're effectively
// bounded by the highest minimum).
HtlcMinimum uint64
// HtlcMaximum is the lowest HTLC maximum supported along the blinded
// path (while some hops may have higher values, we're effectively
// bounded by the lowest maximum).
HtlcMaximum uint64
// Features is the set of relay features available for the payment.
Features *lnwire.FeatureVector
}
// Validate performs validation on a blinded payment.
func (b *BlindedPayment) Validate() error {
if b.BlindedPath == nil {
return ErrNoBlindedPath
}
// The sphinx library inserts the introduction node as the first hop,
// so we expect at least one hop.
if len(b.BlindedPath.BlindedHops) < 1 {
return fmt.Errorf("%w got: %v", ErrInsufficientBlindedHops,
len(b.BlindedPath.BlindedHops))
}
if b.HtlcMaximum < b.HtlcMinimum {
return fmt.Errorf("%w: %v < %v", ErrHTLCRestrictions,
b.HtlcMaximum, b.HtlcMinimum)
}
for _, hop := range b.BlindedPath.BlindedHops {
// The first hop of the blinded path does not necessarily have
// a blinded node pub key because it is the introduction point.
if hop.BlindedNodePub == nil {
continue
}
if IsBlindedRouteNUMSTargetKey(
hop.BlindedNodePub.SerializeCompressed(),
) {
return fmt.Errorf("blinded path cannot include NUMS "+
"key: %s", BlindedPathNUMSHex)
}
}
return nil
}
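// As a hedged sketch (the payments slice and the surrounding error handling
// are assumptions, not actual call sites), callers typically validate every
// decoded path before handing the set to NewBlindedPaymentPathSet:
//
//	for _, payment := range payments {
//		if err := payment.Validate(); err != nil {
//			return fmt.Errorf("invalid blinded payment: %w", err)
//		}
//	}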
// toRouteHints produces a set of chained route hints that represent a blinded
// path. In the case of a single hop blinded route (which is paying directly
// to the introduction point), no hints will be returned. In this case callers
// *must* account for the blinded route's CLTV delta elsewhere (as this is
// effectively the final_cltv_delta for the receiving introduction node). In
// the case of multiple blinded hops, CLTV delta is fully accounted for in the
// hints (both for intermediate hops and the final_cltv_delta for the receiving
// node).
func (b *BlindedPayment) toRouteHints() (RouteHints, error) {
// If we just have a single hop in our blinded route, it just contains
// an introduction node (this is a valid path according to the spec).
// Since we have the un-blinded node ID for the introduction node, we
// don't need to add any route hints.
if len(b.BlindedPath.BlindedHops) == 1 {
return nil, nil
}
hintCount := len(b.BlindedPath.BlindedHops) - 1
hints := make(RouteHints, hintCount)
// Start at the unblinded introduction node, because our pathfinding
// will be able to locate this point in the graph.
fromNode := route.NewVertex(b.BlindedPath.IntroductionPoint)
features := lnwire.EmptyFeatureVector()
if b.Features != nil {
features = b.Features.Clone()
}
// Use the total aggregate relay parameters for the entire blinded
// route as the policy for the hint from our introduction node. This
// will ensure that pathfinding provides sufficient fees/delay for the
// blinded portion to the introduction node.
firstBlindedHop := b.BlindedPath.BlindedHops[1].BlindedNodePub
edgePolicy := &models.CachedEdgePolicy{
TimeLockDelta: b.CltvExpiryDelta,
MinHTLC: lnwire.MilliSatoshi(b.HtlcMinimum),
MaxHTLC: lnwire.MilliSatoshi(b.HtlcMaximum),
FeeBaseMSat: lnwire.MilliSatoshi(b.BaseFee),
FeeProportionalMillionths: lnwire.MilliSatoshi(
b.ProportionalFeeRate,
),
ToNodePubKey: func() route.Vertex {
return route.NewVertex(
// The first node in this slice is
// the introduction node, so we start
// at index 1 to get the first blinded
// relaying node.
firstBlindedHop,
)
},
ToNodeFeatures: features,
}
lastEdge, err := NewBlindedEdge(edgePolicy, b, 0)
if err != nil {
return nil, err
}
hints[fromNode] = []AdditionalEdge{lastEdge}
// Start at an offset of 1 because the first node in our blinded hops
// is the introduction node and terminate at the second-last node
// because we're dealing with hops as pairs.
for i := 1; i < hintCount; i++ {
// Set our origin node to the current blinded hop.
fromNode = route.NewVertex(
b.BlindedPath.BlindedHops[i].BlindedNodePub,
)
// Create a hint which has no fee or cltv delta. We
// specifically want zero values here because our relay
// parameters are expressed in encrypted blobs rather than the
// route itself for blinded routes.
nextHopIdx := i + 1
nextNode := route.NewVertex(
b.BlindedPath.BlindedHops[nextHopIdx].BlindedNodePub,
)
edgePolicy := &models.CachedEdgePolicy{
ToNodePubKey: func() route.Vertex {
return nextNode
},
ToNodeFeatures: features,
}
lastEdge, err = NewBlindedEdge(edgePolicy, b, i)
if err != nil {
return nil, err
}
hints[fromNode] = []AdditionalEdge{lastEdge}
}
return hints, nil
}
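// To make the chaining concrete: for a blinded path intro -> B1 -> B2, the
// hints produced above look roughly like
//
//	hints[intro] = [edge(intro -> B1, aggregate fee/CLTV policy)]
//	hints[B1]    = [edge(B1 -> B2, zero policy)]
//
// so pathfinding only has to reach the unblinded introduction node, while
// the aggregate policy on the first blinded edge reserves enough fees and
// delay for the whole blinded portion of the route.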
// deepCopy returns a deep copy of the BlindedPayment.
func (b *BlindedPayment) deepCopy() *BlindedPayment {
if b == nil {
return nil
}
cpyPayment := &BlindedPayment{
BaseFee: b.BaseFee,
ProportionalFeeRate: b.ProportionalFeeRate,
CltvExpiryDelta: b.CltvExpiryDelta,
HtlcMinimum: b.HtlcMinimum,
HtlcMaximum: b.HtlcMaximum,
}
// Deep copy the BlindedPath if it exists
if b.BlindedPath != nil {
cpyPayment.BlindedPath = &sphinx.BlindedPath{
BlindedHops: make([]*sphinx.BlindedHopInfo,
len(b.BlindedPath.BlindedHops)),
}
if b.BlindedPath.IntroductionPoint != nil {
cpyPayment.BlindedPath.IntroductionPoint =
copyPublicKey(b.BlindedPath.IntroductionPoint)
}
if b.BlindedPath.BlindingPoint != nil {
cpyPayment.BlindedPath.BlindingPoint =
copyPublicKey(b.BlindedPath.BlindingPoint)
}
// Copy each blinded hop info.
for i, hop := range b.BlindedPath.BlindedHops {
if hop == nil {
continue
}
cpyHop := &sphinx.BlindedHopInfo{}
if hop.BlindedNodePub != nil {
cpyHop.BlindedNodePub =
copyPublicKey(hop.BlindedNodePub)
}
cpyHop.CipherText = make([]byte, len(hop.CipherText))
copy(cpyHop.CipherText, hop.CipherText)
cpyPayment.BlindedPath.BlindedHops[i] = cpyHop
}
}
// Deep copy the Features if they exist
if b.Features != nil {
cpyPayment.Features = b.Features.Clone()
}
return cpyPayment
}
// copyPublicKey makes a deep copy of a public key.
//
// TODO(ziggie): Remove this function if this is available in the btcec library.
func copyPublicKey(pk *btcec.PublicKey) *btcec.PublicKey {
var result secp256k1.JacobianPoint
pk.AsJacobian(&result)
result.ToAffine()
return btcec.NewPublicKey(&result.X, &result.Y)
}
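// exampleDeepCopyIsolation is an illustrative sketch (not production code)
// demonstrating the property that NewBlindedPaymentPathSet relies on:
// mutating a deep copy, including its cipher text backing arrays, is never
// visible in the original payment.
func exampleDeepCopyIsolation(b *BlindedPayment) bool {
	if b == nil {
		return false
	}

	cpy := b.deepCopy()
	cpy.BaseFee++

	// The copied cipher text has its own backing array, so flipping a
	// byte here leaves the original hop untouched.
	if cpy.BlindedPath != nil && len(cpy.BlindedPath.BlindedHops) > 0 {
		if hop := cpy.BlindedPath.BlindedHops[0]; hop != nil &&
			len(hop.CipherText) > 0 {

			hop.CipherText[0] ^= 0xff
		}
	}

	return b.BaseFee != cpy.BaseFee
}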
package chainview
import (
"bytes"
"encoding/hex"
"fmt"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcjson"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/chain"
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/lightningnetwork/lnd/blockcache"
graphdb "github.com/lightningnetwork/lnd/graph/db"
)
// BitcoindFilteredChainView is an implementation of the FilteredChainView
// interface which is backed by bitcoind.
type BitcoindFilteredChainView struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.
// bestHeight is the height of the latest block added to the
// blockQueue from the onFilteredConnectedMethod. It is used to
// determine up to what height we would need to rescan in case
// of a filter update.
bestHeightMtx sync.Mutex
bestHeight uint32
// TODO: Factor out common logic between bitcoind and btcd into a
// NodeFilteredView interface.
chainClient *chain.BitcoindClient
// blockEventQueue is the ordered queue used to keep the order
// of connected and disconnected blocks sent to the reader of the
// chainView.
blockQueue *blockEventQueue
// blockCache is an LRU block cache.
blockCache *blockcache.BlockCache
// filterUpdates is a channel in which updates to the utxo filter
// attached to this instance are sent over.
filterUpdates chan filterUpdate
// chainFilter is the set of UTXOs that we're currently watching
// spends for within the chain.
filterMtx sync.RWMutex
chainFilter map[wire.OutPoint]struct{}
// filterBlockReqs is a channel in which requests to filter select
// blocks will be sent over.
filterBlockReqs chan *filterBlockReq
quit chan struct{}
wg sync.WaitGroup
}
// A compile time check to ensure BitcoindFilteredChainView implements the
// chainview.FilteredChainView.
var _ FilteredChainView = (*BitcoindFilteredChainView)(nil)
// NewBitcoindFilteredChainView creates a new instance of a FilteredChainView
// from RPC credentials and a ZMQ socket address for a bitcoind instance.
func NewBitcoindFilteredChainView(
chainConn *chain.BitcoindConn,
blockCache *blockcache.BlockCache) *BitcoindFilteredChainView {
chainView := &BitcoindFilteredChainView{
chainFilter: make(map[wire.OutPoint]struct{}),
filterUpdates: make(chan filterUpdate),
filterBlockReqs: make(chan *filterBlockReq),
blockCache: blockCache,
quit: make(chan struct{}),
}
chainView.chainClient = chainConn.NewBitcoindClient()
chainView.blockQueue = newBlockEventQueue()
return chainView
}
// Start starts all goroutines necessary for normal operation.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BitcoindFilteredChainView) Start() error {
// Already started?
if atomic.AddInt32(&b.started, 1) != 1 {
return nil
}
log.Infof("FilteredChainView starting")
err := b.chainClient.Start()
if err != nil {
return err
}
err = b.chainClient.NotifyBlocks()
if err != nil {
return err
}
_, bestHeight, err := b.chainClient.GetBestBlock()
if err != nil {
return err
}
b.bestHeightMtx.Lock()
b.bestHeight = uint32(bestHeight)
b.bestHeightMtx.Unlock()
b.blockQueue.Start()
b.wg.Add(1)
go b.chainFilterer()
return nil
}
// Stop stops all goroutines which we launched by the prior call to the Start
// method.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BitcoindFilteredChainView) Stop() error {
log.Debug("BitcoindFilteredChainView stopping")
defer log.Debug("BitcoindFilteredChainView stopped")
// Already shutting down?
if atomic.AddInt32(&b.stopped, 1) != 1 {
return nil
}
// Shutdown the rpc client, this gracefully disconnects from bitcoind's
// zmq socket, and cleans up all related resources.
b.chainClient.Stop()
b.chainClient.WaitForShutdown()
b.blockQueue.Stop()
close(b.quit)
b.wg.Wait()
return nil
}
// onFilteredBlockConnected is called for each block that's connected to the
// end of the main chain. Based on our current chain filter, the block may or
// may not include any relevant transactions.
func (b *BitcoindFilteredChainView) onFilteredBlockConnected(height int32,
hash chainhash.Hash, txns []*wtxmgr.TxRecord) {
mtxs := make([]*wire.MsgTx, len(txns))
b.filterMtx.Lock()
for i, tx := range txns {
mtxs[i] = &tx.MsgTx
for _, txIn := range mtxs[i].TxIn {
// We can delete this outpoint from the chainFilter, as
// we just received a block where it was spent. In case
// of a reorg, this outpoint might get "un-spent", but
// that's okay since it would never be wise to consider
// the channel open again (since a spending transaction
// exists on the network).
delete(b.chainFilter, txIn.PreviousOutPoint)
}
}
b.filterMtx.Unlock()
// We record the height of the last connected block added to the
// blockQueue such that we can scan up to this height in case of
// a rescan. It must be protected by a mutex since a filter update
// might be trying to read it concurrently.
b.bestHeightMtx.Lock()
b.bestHeight = uint32(height)
b.bestHeightMtx.Unlock()
block := &FilteredBlock{
Hash: hash,
Height: uint32(height),
Transactions: mtxs,
}
b.blockQueue.Add(&blockEvent{
eventType: connected,
block: block,
})
}
// onFilteredBlockDisconnected is a callback which is executed once a block is
// disconnected from the end of the main chain.
func (b *BitcoindFilteredChainView) onFilteredBlockDisconnected(height int32,
hash chainhash.Hash) {
log.Debugf("got disconnected block at height %d: %v", height,
hash)
filteredBlock := &FilteredBlock{
Hash: hash,
Height: uint32(height),
}
b.blockQueue.Add(&blockEvent{
eventType: disconnected,
block: filteredBlock,
})
}
// FilterBlock takes a block hash, and returns a FilteredBlock which is the
// result of applying the currently registered UTXO sub-set to the block
// corresponding to that block hash. If any watched UTXOs are spent by the
// selected block, then the internal chainFilter will also be updated.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BitcoindFilteredChainView) FilterBlock(blockHash *chainhash.Hash) (*FilteredBlock, error) {
req := &filterBlockReq{
blockHash: blockHash,
resp: make(chan *FilteredBlock, 1),
err: make(chan error, 1),
}
select {
case b.filterBlockReqs <- req:
case <-b.quit:
return nil, fmt.Errorf("FilteredChainView shutting down")
}
return <-req.resp, <-req.err
}
// chainFilterer is the primary goroutine which: listens for new blocks coming
// and dispatches the relevant FilteredBlock notifications, updates the filter
// due to requests by callers, and finally is able to perform targeted block
// filtration.
//
// TODO(roasbeef): change to use loadfilter RPC's
func (b *BitcoindFilteredChainView) chainFilterer() {
defer b.wg.Done()
// filterBlock is a helper function that scans the given block, and
// notes which transactions spend outputs which are currently being
// watched. Additionally, the chain filter will also be updated by
// removing any spent outputs.
filterBlock := func(blk *wire.MsgBlock) []*wire.MsgTx {
b.filterMtx.Lock()
defer b.filterMtx.Unlock()
var filteredTxns []*wire.MsgTx
for _, tx := range blk.Transactions {
var txAlreadyFiltered bool
for _, txIn := range tx.TxIn {
prevOp := txIn.PreviousOutPoint
if _, ok := b.chainFilter[prevOp]; !ok {
continue
}
delete(b.chainFilter, prevOp)
// Only add this txn to our list of filtered
// txns if it is the first previous outpoint to
// cause a match.
if txAlreadyFiltered {
continue
}
filteredTxns = append(filteredTxns, tx.Copy())
txAlreadyFiltered = true
}
}
return filteredTxns
}
decodeJSONBlock := func(block *btcjson.RescannedBlock,
height uint32) (*FilteredBlock, error) {
hash, err := chainhash.NewHashFromStr(block.Hash)
if err != nil {
return nil, err
}
txs := make([]*wire.MsgTx, 0, len(block.Transactions))
for _, str := range block.Transactions {
b, err := hex.DecodeString(str)
if err != nil {
return nil, err
}
tx := &wire.MsgTx{}
err = tx.Deserialize(bytes.NewReader(b))
if err != nil {
return nil, err
}
txs = append(txs, tx)
}
return &FilteredBlock{
Hash: *hash,
Height: height,
Transactions: txs,
}, nil
}
for {
select {
// The caller has just sent an update to the current chain
// filter, so we'll apply the update, possibly rewinding our
// state partially.
case update := <-b.filterUpdates:
// First, we'll add all the new UTXO's to the set of
// watched UTXO's, eliminating any duplicates in the
// process.
log.Tracef("Updating chain filter with new UTXO's: %v",
update.newUtxos)
b.filterMtx.Lock()
for _, newOp := range update.newUtxos {
b.chainFilter[newOp] = struct{}{}
}
b.filterMtx.Unlock()
// Apply the new TX filter to the chain client, which
// will cause all subsequent notifications from and
// calls to it to return blocks filtered with the new
// filter.
err := b.chainClient.LoadTxFilter(false, update.newUtxos)
if err != nil {
log.Errorf("Unable to update filter: %v", err)
continue
}
// All blocks received after we loaded the filter will
// have the filter applied, but we will need to rescan
// the blocks up to the height of the block we last
// added to the blockQueue.
b.bestHeightMtx.Lock()
bestHeight := b.bestHeight
b.bestHeightMtx.Unlock()
// If the update height matches our best known height,
// then we don't need to do any rewinding.
if update.updateHeight == bestHeight {
continue
}
// Otherwise, we'll rewind the state to ensure the
// caller doesn't miss any relevant notifications.
// Starting from the height _after_ the update height,
// we'll walk forwards, rescanning one block at a time
// with the chain client applying the newly loaded
// filter to each block.
for i := update.updateHeight + 1; i < bestHeight+1; i++ {
blockHash, err := b.chainClient.GetBlockHash(int64(i))
if err != nil {
log.Warnf("Unable to get block hash "+
"for block at height %d: %v",
i, err)
continue
}
// To avoid dealing with the case where a reorg
// is happening while we rescan, we scan one
// block at a time, skipping blocks that might
// have gone missing.
rescanned, err := b.chainClient.RescanBlocks(
[]chainhash.Hash{*blockHash},
)
if err != nil {
log.Warnf("Unable to rescan block "+
"with hash %v at height %d: %v",
blockHash, i, err)
continue
}
// If no block was returned from the rescan, it
// means no matching transactions were found.
if len(rescanned) != 1 {
log.Tracef("rescan of block %v at "+
"height=%d yielded no "+
"transactions", blockHash, i)
continue
}
decoded, err := decodeJSONBlock(
&rescanned[0], i,
)
if err != nil {
log.Errorf("Unable to decode block: %v",
err)
continue
}
b.blockQueue.Add(&blockEvent{
eventType: connected,
block: decoded,
})
}
// We've received a new request to manually filter a block.
case req := <-b.filterBlockReqs:
// First we'll fetch the block itself as well as some
// additional information including its height.
block, err := b.GetBlock(req.blockHash)
if err != nil {
req.err <- err
req.resp <- nil
continue
}
header, err := b.chainClient.GetBlockHeaderVerbose(
req.blockHash)
if err != nil {
req.err <- err
req.resp <- nil
continue
}
// Once we have this info, we can directly filter the
// block and dispatch the proper notification.
req.resp <- &FilteredBlock{
Hash: *req.blockHash,
Height: uint32(header.Height),
Transactions: filterBlock(block),
}
req.err <- err
// We've received a new event from the chain client.
case event := <-b.chainClient.Notifications():
switch e := event.(type) {
case chain.FilteredBlockConnected:
b.onFilteredBlockConnected(
e.Block.Height, e.Block.Hash, e.RelevantTxs,
)
case chain.BlockDisconnected:
b.onFilteredBlockDisconnected(e.Height, e.Hash)
}
case <-b.quit:
return
}
}
}
// UpdateFilter updates the UTXO filter which is to be consulted when creating
// FilteredBlocks to be sent to subscribed clients. This method is cumulative
// meaning repeated calls to this method should _expand_ the size of the UTXO
// sub-set currently being watched. If the set updateHeight is _lower_ than
// the best known height of the implementation, then the state should be
// rewound to ensure all relevant notifications are dispatched.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BitcoindFilteredChainView) UpdateFilter(ops []graphdb.EdgePoint,
updateHeight uint32) error {
newUtxos := make([]wire.OutPoint, len(ops))
for i, op := range ops {
newUtxos[i] = op.OutPoint
}
select {
case b.filterUpdates <- filterUpdate{
newUtxos: newUtxos,
updateHeight: updateHeight,
}:
return nil
case <-b.quit:
return fmt.Errorf("chain filter shutting down")
}
}
// FilteredBlocks returns the channel that filtered blocks are to be sent over.
// Each time a block is connected to the end of the main chain, an appropriate
// FilteredBlock which contains the transactions which mutate our watched UTXO
// set is returned.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BitcoindFilteredChainView) FilteredBlocks() <-chan *FilteredBlock {
return b.blockQueue.newBlocks
}
// DisconnectedBlocks returns a receive only channel which will be sent upon
// with the empty filtered blocks of blocks which are disconnected from the
// main chain in the case of a re-org.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BitcoindFilteredChainView) DisconnectedBlocks() <-chan *FilteredBlock {
return b.blockQueue.staleBlocks
}
// GetBlock is used to retrieve the block with the given hash. This function
// wraps the blockCache's GetBlock function.
func (b *BitcoindFilteredChainView) GetBlock(hash *chainhash.Hash) (
*wire.MsgBlock, error) {
return b.blockCache.GetBlock(hash, b.chainClient.GetBlock)
}
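// exampleBitcoindChainView is an illustrative sketch only (the construction
// of the BitcoindConn, BlockCache, and edge points is assumed to happen
// elsewhere). It shows the intended lifecycle of a FilteredChainView: start
// it, register the watched outpoints, then consume filtered blocks.
func exampleBitcoindChainView(chainConn *chain.BitcoindConn,
	cache *blockcache.BlockCache, ops []graphdb.EdgePoint,
	height uint32) error {

	chainView := NewBitcoindFilteredChainView(chainConn, cache)
	if err := chainView.Start(); err != nil {
		return err
	}
	defer func() {
		if err := chainView.Stop(); err != nil {
			log.Errorf("Unable to stop chain view: %v", err)
		}
	}()

	// Register the watched outpoints, rewinding to the given height so
	// that no spend notifications are missed.
	if err := chainView.UpdateFilter(ops, height); err != nil {
		return err
	}

	// Consume a single filtered block; production code would select on
	// FilteredBlocks and DisconnectedBlocks in a loop.
	block := <-chainView.FilteredBlocks()
	log.Infof("block %v at height %d spent %d watched outputs",
		block.Hash, block.Height, len(block.Transactions))

	return nil
}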
package chainview
import (
"bytes"
"encoding/hex"
"fmt"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcjson"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/blockcache"
graphdb "github.com/lightningnetwork/lnd/graph/db"
)
// BtcdFilteredChainView is an implementation of the FilteredChainView
// interface which is backed by an active websockets connection to btcd.
type BtcdFilteredChainView struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.
// bestHeight is the height of the latest block added to the
// blockQueue from the onFilteredConnectedMethod. It is used to
// determine up to what height we would need to rescan in case
// of a filter update.
bestHeightMtx sync.Mutex
bestHeight uint32
btcdConn *rpcclient.Client
// blockEventQueue is the ordered queue used to keep the order
// of connected and disconnected blocks sent to the reader of the
// chainView.
blockQueue *blockEventQueue
// blockCache is an LRU block cache.
blockCache *blockcache.BlockCache
// filterUpdates is a channel in which updates to the utxo filter
// attached to this instance are sent over.
filterUpdates chan filterUpdate
// chainFilter is the set of UTXOs that we're currently watching
// spends for within the chain.
filterMtx sync.RWMutex
chainFilter map[wire.OutPoint]struct{}
// filterBlockReqs is a channel in which requests to filter select
// blocks will be sent over.
filterBlockReqs chan *filterBlockReq
quit chan struct{}
wg sync.WaitGroup
}
// A compile time check to ensure BtcdFilteredChainView implements the
// chainview.FilteredChainView.
var _ FilteredChainView = (*BtcdFilteredChainView)(nil)
// NewBtcdFilteredChainView creates a new instance of a FilteredChainView from
// RPC credentials for an active btcd instance.
func NewBtcdFilteredChainView(config rpcclient.ConnConfig,
blockCache *blockcache.BlockCache) (*BtcdFilteredChainView, error) {
chainView := &BtcdFilteredChainView{
chainFilter: make(map[wire.OutPoint]struct{}),
filterUpdates: make(chan filterUpdate),
filterBlockReqs: make(chan *filterBlockReq),
blockCache: blockCache,
quit: make(chan struct{}),
}
ntfnCallbacks := &rpcclient.NotificationHandlers{
OnFilteredBlockConnected: chainView.onFilteredBlockConnected,
OnFilteredBlockDisconnected: chainView.onFilteredBlockDisconnected,
}
// Disable connecting to btcd within the rpcclient.New method. We
// defer establishing the connection to our .Start() method.
config.DisableConnectOnNew = true
config.DisableAutoReconnect = false
chainConn, err := rpcclient.New(&config, ntfnCallbacks)
if err != nil {
return nil, err
}
chainView.btcdConn = chainConn
chainView.blockQueue = newBlockEventQueue()
return chainView, nil
}
// Start starts all goroutines necessary for normal operation.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BtcdFilteredChainView) Start() error {
// Already started?
if atomic.AddInt32(&b.started, 1) != 1 {
return nil
}
log.Infof("FilteredChainView starting")
// Connect to btcd, and register for notifications on connected, and
// disconnected blocks.
if err := b.btcdConn.Connect(20); err != nil {
return err
}
if err := b.btcdConn.NotifyBlocks(); err != nil {
return err
}
_, bestHeight, err := b.btcdConn.GetBestBlock()
if err != nil {
return err
}
b.bestHeightMtx.Lock()
b.bestHeight = uint32(bestHeight)
b.bestHeightMtx.Unlock()
b.blockQueue.Start()
b.wg.Add(1)
go b.chainFilterer()
return nil
}
// Stop stops all goroutines which we launched by the prior call to the Start
// method.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BtcdFilteredChainView) Stop() error {
log.Debug("BtcdFilteredChainView stopping")
defer log.Debug("BtcdFilteredChainView stopped")
// Already shutting down?
if atomic.AddInt32(&b.stopped, 1) != 1 {
return nil
}
// Shutdown the rpc client, this gracefully disconnects from btcd, and
// cleans up all related resources.
b.btcdConn.Shutdown()
b.btcdConn.WaitForShutdown()
b.blockQueue.Stop()
close(b.quit)
b.wg.Wait()
return nil
}
// onFilteredBlockConnected is called for each block that's connected to the
// end of the main chain. Based on our current chain filter, the block may or
// may not include any relevant transactions.
func (b *BtcdFilteredChainView) onFilteredBlockConnected(height int32,
header *wire.BlockHeader, txns []*btcutil.Tx) {
mtxs := make([]*wire.MsgTx, len(txns))
b.filterMtx.Lock()
for i, tx := range txns {
mtx := tx.MsgTx()
mtxs[i] = mtx
for _, txIn := range mtx.TxIn {
// We can delete this outpoint from the chainFilter, as
// we just received a block where it was spent. In case
// of a reorg, this outpoint might get "un-spent", but
// that's okay since it would never be wise to consider
// the channel open again (since a spending transaction
// exists on the network).
delete(b.chainFilter, txIn.PreviousOutPoint)
}
}
b.filterMtx.Unlock()
// We record the height of the last connected block added to the
// blockQueue such that we can scan up to this height in case of
// a rescan. It must be protected by a mutex since a filter update
// might be trying to read it concurrently.
b.bestHeightMtx.Lock()
b.bestHeight = uint32(height)
b.bestHeightMtx.Unlock()
block := &FilteredBlock{
Hash: header.BlockHash(),
Height: uint32(height),
Transactions: mtxs,
}
b.blockQueue.Add(&blockEvent{
eventType: connected,
block: block,
})
}
// onFilteredBlockDisconnected is a callback which is executed once a block is
// disconnected from the end of the main chain.
func (b *BtcdFilteredChainView) onFilteredBlockDisconnected(height int32,
header *wire.BlockHeader) {
log.Debugf("got disconnected block at height %d: %v", height,
header.BlockHash())
filteredBlock := &FilteredBlock{
Hash: header.BlockHash(),
Height: uint32(height),
}
b.blockQueue.Add(&blockEvent{
eventType: disconnected,
block: filteredBlock,
})
}
// filterBlockReq houses a request to manually filter a block specified by
// block hash.
type filterBlockReq struct {
blockHash *chainhash.Hash
resp chan *FilteredBlock
err chan error
}
// FilterBlock takes a block hash, and returns a FilteredBlock which is the
// result of applying the currently registered UTXO sub-set to the block
// corresponding to that block hash. If any watched UTXOs are spent by the
// selected block, then the internal chainFilter will also be updated.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BtcdFilteredChainView) FilterBlock(blockHash *chainhash.Hash) (*FilteredBlock, error) {
req := &filterBlockReq{
blockHash: blockHash,
resp: make(chan *FilteredBlock, 1),
err: make(chan error, 1),
}
select {
case b.filterBlockReqs <- req:
case <-b.quit:
return nil, fmt.Errorf("FilteredChainView shutting down")
}
return <-req.resp, <-req.err
}
// chainFilterer is the primary goroutine which: listens for new blocks coming
// and dispatches the relevant FilteredBlock notifications, updates the filter
// due to requests by callers, and finally is able to perform targeted block
// filtration.
//
// TODO(roasbeef): change to use loadfilter RPC's
func (b *BtcdFilteredChainView) chainFilterer() {
defer b.wg.Done()
// filterBlock is a helper function that scans the given block, and
// notes which transactions spend outputs which are currently being
// watched. Additionally, the chain filter will also be updated by
// removing any spent outputs.
filterBlock := func(blk *wire.MsgBlock) []*wire.MsgTx {
b.filterMtx.Lock()
defer b.filterMtx.Unlock()
var filteredTxns []*wire.MsgTx
for _, tx := range blk.Transactions {
var txAlreadyFiltered bool
for _, txIn := range tx.TxIn {
prevOp := txIn.PreviousOutPoint
if _, ok := b.chainFilter[prevOp]; !ok {
continue
}
delete(b.chainFilter, prevOp)
// Only add this txn to our list of filtered
// txns if it is the first previous outpoint to
// cause a match.
if txAlreadyFiltered {
continue
}
filteredTxns = append(filteredTxns, tx.Copy())
txAlreadyFiltered = true
}
}
return filteredTxns
}
decodeJSONBlock := func(block *btcjson.RescannedBlock,
height uint32) (*FilteredBlock, error) {
hash, err := chainhash.NewHashFromStr(block.Hash)
if err != nil {
return nil, err
}
txs := make([]*wire.MsgTx, 0, len(block.Transactions))
for _, str := range block.Transactions {
b, err := hex.DecodeString(str)
if err != nil {
return nil, err
}
tx := &wire.MsgTx{}
err = tx.Deserialize(bytes.NewReader(b))
if err != nil {
return nil, err
}
txs = append(txs, tx)
}
return &FilteredBlock{
Hash: *hash,
Height: height,
Transactions: txs,
}, nil
}
for {
select {
// The caller has just sent an update to the current chain
// filter, so we'll apply the update, possibly rewinding our
// state partially.
case update := <-b.filterUpdates:
// First, we'll add all the new UTXO's to the set of
// watched UTXO's, eliminating any duplicates in the
// process.
log.Tracef("Updating chain filter with new UTXO's: %v",
update.newUtxos)
b.filterMtx.Lock()
for _, newOp := range update.newUtxos {
b.chainFilter[newOp] = struct{}{}
}
b.filterMtx.Unlock()
// Apply the new TX filter to btcd, which will cause
// all subsequent notifications from and calls to it
// to return blocks filtered with the new filter.
err := b.btcdConn.LoadTxFilter(
false, []btcutil.Address{}, update.newUtxos,
)
if err != nil {
log.Errorf("Unable to update filter: %v", err)
continue
}
// All blocks received after we loaded the filter will
// have the filter applied, but we will need to rescan
// the blocks up to the height of the block we last
// added to the blockQueue.
b.bestHeightMtx.Lock()
bestHeight := b.bestHeight
b.bestHeightMtx.Unlock()
// If the update height matches our best known height,
// then we don't need to do any rewinding.
if update.updateHeight == bestHeight {
continue
}
// Otherwise, we'll rewind the state to ensure the
// caller doesn't miss any relevant notifications.
// Starting from the height _after_ the update height,
// we'll walk forwards, rescanning one block at a time
// with btcd applying the newly loaded filter to each
// block.
for i := update.updateHeight + 1; i < bestHeight+1; i++ {
blockHash, err := b.btcdConn.GetBlockHash(int64(i))
if err != nil {
log.Warnf("Unable to get block hash "+
"for block at height %d: %v",
i, err)
continue
}
// To avoid dealing with the case where a reorg
// is happening while we rescan, we scan one
// block at a time, skipping blocks that might
// have gone missing.
rescanned, err := b.btcdConn.RescanBlocks(
[]chainhash.Hash{*blockHash})
if err != nil {
log.Warnf("Unable to rescan block "+
"with hash %v at height %d: %v",
blockHash, i, err)
continue
}
// If no block was returned from the rescan, it
// means no matching transactions were found.
if len(rescanned) != 1 {
log.Tracef("rescan of block %v at "+
"height=%d yielded no "+
"transactions", blockHash, i)
continue
}
decoded, err := decodeJSONBlock(
&rescanned[0], uint32(i))
if err != nil {
log.Errorf("Unable to decode block: %v",
err)
continue
}
b.blockQueue.Add(&blockEvent{
eventType: connected,
block: decoded,
})
}
// We've received a new request to manually filter a block.
case req := <-b.filterBlockReqs:
// First we'll fetch the block itself as well as some
// additional information including its height.
block, err := b.GetBlock(req.blockHash)
if err != nil {
req.err <- err
req.resp <- nil
continue
}
header, err := b.btcdConn.GetBlockHeaderVerbose(req.blockHash)
if err != nil {
req.err <- err
req.resp <- nil
continue
}
// Once we have this info, we can directly filter the
// block and dispatch the proper notification.
req.resp <- &FilteredBlock{
Hash: *req.blockHash,
Height: uint32(header.Height),
Transactions: filterBlock(block),
}
req.err <- err
case <-b.quit:
return
}
}
}
// filterUpdate is a message sent to the chainFilterer to update the current
// chainFilter state.
type filterUpdate struct {
newUtxos []wire.OutPoint
updateHeight uint32
}
// UpdateFilter updates the UTXO filter which is to be consulted when creating
// FilteredBlocks to be sent to subscribed clients. This method is cumulative
// meaning repeated calls to this method should _expand_ the size of the UTXO
// sub-set currently being watched. If the set updateHeight is _lower_ than
// the best known height of the implementation, then the state should be
// rewound to ensure all relevant notifications are dispatched.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BtcdFilteredChainView) UpdateFilter(ops []graphdb.EdgePoint,
updateHeight uint32) error {
newUtxos := make([]wire.OutPoint, len(ops))
for i, op := range ops {
newUtxos[i] = op.OutPoint
}
select {
case b.filterUpdates <- filterUpdate{
newUtxos: newUtxos,
updateHeight: updateHeight,
}:
return nil
case <-b.quit:
return fmt.Errorf("chain filter shutting down")
}
}
// FilteredBlocks returns the channel that filtered blocks are to be sent over.
// Each time a block is connected to the end of the main chain, an appropriate
// FilteredBlock which contains the transactions which mutate our watched UTXO
// set is returned.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BtcdFilteredChainView) FilteredBlocks() <-chan *FilteredBlock {
return b.blockQueue.newBlocks
}
// DisconnectedBlocks returns a receive only channel which will be sent upon
// with the empty filtered blocks of blocks which are disconnected from the
// main chain in the case of a re-org.
//
// NOTE: This is part of the FilteredChainView interface.
func (b *BtcdFilteredChainView) DisconnectedBlocks() <-chan *FilteredBlock {
return b.blockQueue.staleBlocks
}
// GetBlock is used to retrieve the block with the given hash. This function
// wraps the blockCache's GetBlock function.
func (b *BtcdFilteredChainView) GetBlock(hash *chainhash.Hash) (
*wire.MsgBlock, error) {
return b.blockCache.GetBlock(hash, b.btcdConn.GetBlock)
}
package chainview
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
// log is a logger that is initialized with no output filters. This
// means the package will not perform any logging by default until the caller
// requests it.
var log btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("CRTR", nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until either UseLogger or SetLogWriter are called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info.
// This should be used in preference to SetLogWriter if the caller is also
// using btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
package chainview
import (
"fmt"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/gcs/builder"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/neutrino"
"github.com/lightningnetwork/lnd/blockcache"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/lntypes"
)
// CfFilteredChainView is an implementation of the FilteredChainView interface
// which is supported by an underlying Bitcoin light client which supports
// client side filtering of Golomb Coded Sets. Rather than fetching all the
// blocks, the light client is able to query filters locally, to test if an
// item in a block modifies any of our watched set of UTXOs.
type CfFilteredChainView struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.
// p2pNode is a pointer to the running GCS-filter supported Bitcoin
// light client.
p2pNode *neutrino.ChainService
// chainView is the active rescan which only watches our specified
// sub-set of the UTXO set.
chainView *neutrino.Rescan
// rescanErrChan is the channel that any errors encountered during the
// rescan will be sent over.
rescanErrChan <-chan error
// blockEventQueue is the ordered queue used to keep the order
// of connected and disconnected blocks sent to the reader of the
// chainView.
blockQueue *blockEventQueue
// blockCache is an LRU block cache.
blockCache *blockcache.BlockCache
// chainFilter is the set of UTXOs, mapped to their funding pkScripts,
// that we're currently watching spends for within the chain.
filterMtx sync.RWMutex
chainFilter map[wire.OutPoint][]byte
quit chan struct{}
wg sync.WaitGroup
}
// A compile time check to ensure CfFilteredChainView implements the
// chainview.FilteredChainView.
var _ FilteredChainView = (*CfFilteredChainView)(nil)
// NewCfFilteredChainView creates a new instance of the CfFilteredChainView
// which is connected to an active neutrino node.
//
// NOTE: The node should already be running and syncing before being passed into
// this function.
func NewCfFilteredChainView(node *neutrino.ChainService,
blockCache *blockcache.BlockCache) (*CfFilteredChainView, error) {
return &CfFilteredChainView{
blockQueue: newBlockEventQueue(),
quit: make(chan struct{}),
rescanErrChan: make(chan error),
chainFilter: make(map[wire.OutPoint][]byte),
p2pNode: node,
blockCache: blockCache,
}, nil
}
// Start kicks off the FilteredChainView implementation. This function must be
// called before any calls to UpdateFilter can be processed.
//
// NOTE: This is part of the FilteredChainView interface.
func (c *CfFilteredChainView) Start() error {
// Already started?
if atomic.AddInt32(&c.started, 1) != 1 {
return nil
}
log.Infof("FilteredChainView starting")
// First, we'll obtain the latest block height of the p2p node. We'll
// start the auto-rescan from this point. Once a caller actually wishes
// to register a chain view, the rescan state will be rewound
// accordingly.
startingPoint, err := c.p2pNode.BestBlock()
if err != nil {
return err
}
// Next, we'll create our set of rescan options. Currently it's
// required that a user MUST set an addr/outpoint/txid when creating a
// rescan. To get around this, we'll add a "zero" outpoint that won't
// actually be matched.
var zeroPoint neutrino.InputWithScript
rescanOptions := []neutrino.RescanOption{
neutrino.StartBlock(startingPoint),
neutrino.QuitChan(c.quit),
neutrino.NotificationHandlers(
rpcclient.NotificationHandlers{
OnFilteredBlockConnected: c.onFilteredBlockConnected,
OnFilteredBlockDisconnected: c.onFilteredBlockDisconnected,
},
),
neutrino.WatchInputs(zeroPoint),
}
// Finally, we'll create our rescan struct, start it, and launch all
// the goroutines we need to operate this FilteredChainView instance.
c.chainView = neutrino.NewRescan(
&neutrino.RescanChainSource{
ChainService: c.p2pNode,
},
rescanOptions...,
)
c.rescanErrChan = c.chainView.Start()
c.blockQueue.Start()
c.wg.Add(1)
go c.chainFilterer()
return nil
}
// Stop signals all active goroutines for a graceful shutdown.
//
// NOTE: This is part of the FilteredChainView interface.
func (c *CfFilteredChainView) Stop() error {
log.Debug("CfFilteredChainView stopping")
defer log.Debug("CfFilteredChainView stopped")
// Already shutting down?
if atomic.AddInt32(&c.stopped, 1) != 1 {
return nil
}
close(c.quit)
c.blockQueue.Stop()
c.wg.Wait()
return nil
}
// onFilteredBlockConnected is called for each block that's connected to the
// end of the main chain. Based on our current chain filter, the block may or
// may not include any relevant transactions.
func (c *CfFilteredChainView) onFilteredBlockConnected(height int32,
header *wire.BlockHeader, txns []*btcutil.Tx) {
mtxs := make([]*wire.MsgTx, len(txns))
for i, tx := range txns {
mtx := tx.MsgTx()
mtxs[i] = mtx
for _, txIn := range mtx.TxIn {
c.filterMtx.Lock()
delete(c.chainFilter, txIn.PreviousOutPoint)
c.filterMtx.Unlock()
}
}
block := &FilteredBlock{
Hash: header.BlockHash(),
Height: uint32(height),
Transactions: mtxs,
}
c.blockQueue.Add(&blockEvent{
eventType: connected,
block: block,
})
}
// onFilteredBlockDisconnected is a callback which is executed once a block is
// disconnected from the end of the main chain.
func (c *CfFilteredChainView) onFilteredBlockDisconnected(height int32,
header *wire.BlockHeader) {
log.Debugf("got disconnected block at height %d: %v", height,
header.BlockHash())
filteredBlock := &FilteredBlock{
Hash: header.BlockHash(),
Height: uint32(height),
}
c.blockQueue.Add(&blockEvent{
eventType: disconnected,
block: filteredBlock,
})
}
// chainFilterer is the primary coordination goroutine within the
// CfFilteredChainView. This goroutine handles errors from the running rescan.
func (c *CfFilteredChainView) chainFilterer() {
defer c.wg.Done()
for {
select {
case err := <-c.rescanErrChan:
log.Errorf("Error encountered during rescan: %v", err)
case <-c.quit:
return
}
}
}
// FilterBlock takes a block hash, and returns a FilteredBlock which is the
// result of applying the currently registered UTXO sub-set to the block
// corresponding to that block hash. If any watched UTXOs are spent by the
// selected block, then the internal chainFilter will also be updated.
//
// NOTE: This is part of the FilteredChainView interface.
func (c *CfFilteredChainView) FilterBlock(blockHash *chainhash.Hash) (*FilteredBlock, error) {
// First, we'll fetch the block header itself so we can obtain the
// height which is part of our return value.
blockHeight, err := c.p2pNode.GetBlockHeight(blockHash)
if err != nil {
return nil, err
}
filteredBlock := &FilteredBlock{
Hash: *blockHash,
Height: uint32(blockHeight),
}
// If we don't have any items within our current chain filter, then we
// can exit early as we don't need to fetch the filter.
c.filterMtx.RLock()
if len(c.chainFilter) == 0 {
c.filterMtx.RUnlock()
return filteredBlock, nil
}
c.filterMtx.RUnlock()
// Next, using the block hash, we'll fetch the compact filter for this
// block. We only require the regular filter as we're just looking for
// outpoints that have been spent.
filter, err := c.p2pNode.GetCFilter(*blockHash, wire.GCSFilterRegular)
if err != nil {
return nil, fmt.Errorf("unable to fetch filter: %w", err)
}
// Before we can match the filter, we'll need to map each item in our
// chain filter to the representation that is included in the compact
// filters.
c.filterMtx.RLock()
relevantPoints := make([][]byte, 0, len(c.chainFilter))
for _, filterEntry := range c.chainFilter {
relevantPoints = append(relevantPoints, filterEntry)
}
c.filterMtx.RUnlock()
// With our relevant points constructed, we can finally match against
// the retrieved filter.
matched, err := filter.MatchAny(builder.DeriveKey(blockHash),
relevantPoints)
if err != nil {
return nil, err
}
// If there wasn't a match, then we'll return the filtered block as is
// (void of any transactions).
if !matched {
return filteredBlock, nil
}
// If we reach this point, then there was a match, so we'll need to
// fetch the block itself so we can scan it for any actual matches (as
// compact filters have a false-positive rate).
block, err := c.GetBlock(*blockHash)
if err != nil {
return nil, err
}
// Finally, we'll step through the block, input by input, to see if any
// transactions spend any outputs from our watched sub-set of the UTXO
// set.
for _, tx := range block.Transactions() {
for _, txIn := range tx.MsgTx().TxIn {
prevOp := txIn.PreviousOutPoint
c.filterMtx.RLock()
_, ok := c.chainFilter[prevOp]
c.filterMtx.RUnlock()
if ok {
filteredBlock.Transactions = append(
filteredBlock.Transactions,
tx.MsgTx().Copy(),
)
c.filterMtx.Lock()
delete(c.chainFilter, prevOp)
c.filterMtx.Unlock()
break
}
}
}
return filteredBlock, nil
}
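// In compact terms, the flow above is:
//
//	filter  <- GetCFilter(blockHash, regular)
//	matched <- filter.MatchAny(DeriveKey(blockHash), watchedScripts)
//	if matched: fetch the full block and confirm spends input by input
//
// The final input-by-input pass is required because GCS filters have a
// false-positive rate, so a filter match alone does not prove that a watched
// outpoint was actually spent.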
// UpdateFilter updates the UTXO filter which is to be consulted when creating
// FilteredBlocks to be sent to subscribed clients. This method is cumulative
// meaning repeated calls to this method should _expand_ the size of the UTXO
// sub-set currently being watched. If the set updateHeight is _lower_ than
// the best known height of the implementation, then the state should be
// rewound to ensure all relevant notifications are dispatched.
//
// NOTE: This is part of the FilteredChainView interface.
func (c *CfFilteredChainView) UpdateFilter(ops []graphdb.EdgePoint,
updateHeight uint32) error {
log.Tracef("Updating chain filter with new UTXO's: %v", ops)
// First, we'll update the current chain view, by adding any new
// UTXO's, ignoring duplicates in the process.
c.filterMtx.Lock()
for _, op := range ops {
c.chainFilter[op.OutPoint] = op.FundingPkScript
}
c.filterMtx.Unlock()
inputs := make([]neutrino.InputWithScript, len(ops))
for i, op := range ops {
inputs[i] = neutrino.InputWithScript{
PkScript: op.FundingPkScript,
OutPoint: op.OutPoint,
}
}
// With our internal chain view update, we'll craft a new update to the
// chainView which includes our new UTXO's, and current update height.
rescanUpdate := []neutrino.UpdateOption{
neutrino.AddInputs(inputs...),
neutrino.Rewind(updateHeight),
neutrino.DisableDisconnectedNtfns(true),
}
err := c.chainView.Update(rescanUpdate...)
if err != nil {
return fmt.Errorf("unable to update rescan: %w", err)
}
return nil
}
// FilteredBlocks returns the channel that filtered blocks are to be sent over.
// Each time a block is connected to the end of the main chain, an appropriate
// FilteredBlock which contains the transactions which mutate our watched UTXO
// set is returned.
//
// NOTE: This is part of the FilteredChainView interface.
func (c *CfFilteredChainView) FilteredBlocks() <-chan *FilteredBlock {
return c.blockQueue.newBlocks
}
// DisconnectedBlocks returns a receive only channel which will be sent upon
// with the empty filtered blocks of blocks which are disconnected from the
// main chain in the case of a re-org.
//
// NOTE: This is part of the FilteredChainView interface.
func (c *CfFilteredChainView) DisconnectedBlocks() <-chan *FilteredBlock {
return c.blockQueue.staleBlocks
}
// GetBlock is used to retrieve the block with the given hash. Because the
// block cache used by neutrino is the same as that used by LND (it is
// passed to neutrino on initialisation), the neutrino GetBlock method can be
// called directly since it already uses the block cache. However, neutrino
// does not lock the block cache mutex for the given block hash and so that is
// done here.
func (c *CfFilteredChainView) GetBlock(hash chainhash.Hash) (
*btcutil.Block, error) {
c.blockCache.HashMutex.Lock(lntypes.Hash(hash))
defer c.blockCache.HashMutex.Unlock(lntypes.Hash(hash))
return c.p2pNode.GetBlock(hash)
}
package chainview
import "sync"
// blockEventType is the possible types of a blockEvent.
type blockEventType uint8
const (
// connected is the type of a blockEvent representing a block
// that was connected to our current chain.
connected blockEventType = iota
// disconnected is the type of a blockEvent representing a
// block that is stale/disconnected from our current chain.
disconnected
)
// blockEvent represent a block that was either connected
// or disconnected from the current chain.
type blockEvent struct {
eventType blockEventType
block *FilteredBlock
}
// blockEventQueue is an ordered queue for block events sent from a
// FilteredChainView. The two types of possible block events are
// connected/new blocks, and disconnected/stale blocks. The
// blockEventQueue keeps the order of these events intact, while
// still being non-blocking. This is important in order for the
// chainView's call to onBlockConnected/onBlockDisconnected to not
// get blocked, and for the consumer of the block events to always
// get the events in the correct order.
type blockEventQueue struct {
queueCond *sync.Cond
queueMtx sync.Mutex
queue []*blockEvent
// newBlocks is the channel where the consumer of the queue
// will receive connected/new blocks from the FilteredChainView.
newBlocks chan *FilteredBlock
// staleBlocks is the channel where the consumer of the queue will
// receive disconnected/stale blocks from the FilteredChainView.
staleBlocks chan *FilteredBlock
wg sync.WaitGroup
quit chan struct{}
}
// newBlockEventQueue creates a new blockEventQueue.
func newBlockEventQueue() *blockEventQueue {
b := &blockEventQueue{
newBlocks: make(chan *FilteredBlock),
staleBlocks: make(chan *FilteredBlock),
quit: make(chan struct{}),
}
b.queueCond = sync.NewCond(&b.queueMtx)
return b
}
// Start starts the blockEventQueue coordinator such that it can start handling
// events.
func (b *blockEventQueue) Start() {
b.wg.Add(1)
go b.queueCoordinator()
}
// Stop signals the queue coordinator to stop, such that the queue can be
// shut down.
func (b *blockEventQueue) Stop() {
close(b.quit)
b.queueCond.Signal()
}
// queueCoordinator is the queue's main loop, handling incoming block events
// and handing them off to the correct output channel.
//
// NB: MUST be run as a goroutine from the Start() method.
func (b *blockEventQueue) queueCoordinator() {
defer b.wg.Done()
for {
// First, we'll check our condition. If the queue of events is
// empty, then we'll wait until a new item is added.
b.queueCond.L.Lock()
for len(b.queue) == 0 {
b.queueCond.Wait()
// If we were woken up in order to exit, then we'll do
// so. Otherwise, we'll check the queue for any new
// items.
select {
case <-b.quit:
b.queueCond.L.Unlock()
return
default:
}
}
// Grab the first element in the queue, and nil the index to
// avoid gc leak.
event := b.queue[0]
b.queue[0] = nil
b.queue = b.queue[1:]
b.queueCond.L.Unlock()
// In the case this is a connected block, we'll send it on the
// newBlocks channel. In case it is a disconnected block, we'll
// send it on the staleBlocks channel. This send will block
// until it is received by the consumer on the other end, making
// sure we won't try to send any other block event before the
// consumer is aware of this one.
switch event.eventType {
case connected:
select {
case b.newBlocks <- event.block:
case <-b.quit:
return
}
case disconnected:
select {
case b.staleBlocks <- event.block:
case <-b.quit:
return
}
}
}
}
// Add puts the provided blockEvent at the end of the event queue, making sure
// it will first be received after all previous events. This method is
// non-blocking, in the sense that it will never wait for the consumer of the
// queue to read from the other end, making it safe to call from the
// FilteredChainView's onBlockConnected/onBlockDisconnected.
func (b *blockEventQueue) Add(event *blockEvent) {
// Lock the condition, and add the event to the end of queue.
b.queueCond.L.Lock()
b.queue = append(b.queue, event)
b.queueCond.L.Unlock()
// With the event added, we signal to the queueCoordinator that
// there are new events to handle.
b.queueCond.Signal()
}
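// exampleQueueOrdering is an illustrative sketch (not production code)
// showing the queue's contract: Add never blocks the producer, while the
// consumer receives events in insertion order, split across the two output
// channels by event type.
func exampleQueueOrdering() {
	q := newBlockEventQueue()
	q.Start()
	defer q.Stop()

	// Safe to call from notification callbacks: neither Add blocks.
	q.Add(&blockEvent{
		eventType: connected,
		block:     &FilteredBlock{Height: 100},
	})
	q.Add(&blockEvent{
		eventType: disconnected,
		block:     &FilteredBlock{Height: 100},
	})

	// The coordinator delivers the connected block first and will not
	// hand over the stale block until the first send is consumed.
	newBlock := <-q.newBlocks
	staleBlock := <-q.staleBlocks
	log.Debugf("connected height=%d, stale height=%d", newBlock.Height,
		staleBlock.Height)
}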
package routing
import (
"sync"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/multimutex"
"github.com/lightningnetwork/lnd/queue"
)
// DBMPPayment is an interface derived from channeldb.MPPayment that is used by
// the payment lifecycle.
type DBMPPayment interface {
// GetState returns the current state of the payment.
GetState() *channeldb.MPPaymentState
// Terminated returns true if the payment is in a final state.
Terminated() bool
// GetStatus returns the current status of the payment.
GetStatus() channeldb.PaymentStatus
// NeedWaitAttempts specifies whether the payment needs to wait for the
// outcome of an attempt.
NeedWaitAttempts() (bool, error)
// GetHTLCs returns all HTLCs of this payment.
GetHTLCs() []channeldb.HTLCAttempt
// InFlightHTLCs returns all HTLCs that are in flight.
InFlightHTLCs() []channeldb.HTLCAttempt
// AllowMoreAttempts is used to decide whether we can safely attempt
// more HTLCs for a given payment state. Return an error if the payment
// is in an unexpected state.
AllowMoreAttempts() (bool, error)
// TerminalInfo returns the settled HTLC attempt or the payment's
// failure reason.
TerminalInfo() (*channeldb.HTLCAttempt, *channeldb.FailureReason)
}
// ControlTower tracks all outgoing payments made; its primary purpose is to
// prevent duplicate payments to the same payment hash. In production, a
// persistent implementation is preferred so that tracking can survive across
// restarts. Payments are transitioned through various payment states, and the
// ControlTower interface provides access to driving the state transitions.
type ControlTower interface {
// InitPayment checks that no succeeded payment exists for this payment
// hash.
InitPayment(lntypes.Hash, *channeldb.PaymentCreationInfo) error
// DeleteFailedAttempts removes all failed HTLCs from the db. It should
// be called for a given payment whenever all inflight htlcs are
// completed, and the payment has reached a final settled state.
DeleteFailedAttempts(lntypes.Hash) error
// RegisterAttempt atomically records the provided HTLCAttemptInfo.
RegisterAttempt(lntypes.Hash, *channeldb.HTLCAttemptInfo) error
// SettleAttempt marks the given attempt settled with the preimage. If
// this is a multi-shard payment, this might implicitly mean that the
// full payment succeeded.
//
// After invoking this method, InitPayment should always return an
// error to prevent us from making duplicate payments to the same
// payment hash. The provided preimage is atomically saved to the DB
// for record keeping.
SettleAttempt(lntypes.Hash, uint64, *channeldb.HTLCSettleInfo) (
*channeldb.HTLCAttempt, error)
// FailAttempt marks the given payment attempt failed.
FailAttempt(lntypes.Hash, uint64, *channeldb.HTLCFailInfo) (
*channeldb.HTLCAttempt, error)
// FetchPayment fetches the payment corresponding to the given payment
// hash.
FetchPayment(paymentHash lntypes.Hash) (DBMPPayment, error)
// FailPayment transitions a payment into the Failed state, and records
// the ultimate reason the payment failed. Note that this should only
// be called when all active attempts are already failed. After
// invoking this method, InitPayment should return nil on its next call
// for this payment hash, allowing the user to make a subsequent
// payment.
FailPayment(lntypes.Hash, channeldb.FailureReason) error
// FetchInFlightPayments returns all payments with status InFlight.
FetchInFlightPayments() ([]*channeldb.MPPayment, error)
// SubscribePayment subscribes to updates for the payment with the given
// hash. A first update with the current state of the payment is always
// sent out immediately.
SubscribePayment(paymentHash lntypes.Hash) (ControlTowerSubscriber,
error)
// SubscribeAllPayments subscribes to updates for all payments. A first
// update with the current state of every inflight payment is always
// sent out immediately.
SubscribeAllPayments() (ControlTowerSubscriber, error)
}
// ControlTowerSubscriber contains the state for a payment update subscriber.
type ControlTowerSubscriber interface {
// Updates is the channel over which *channeldb.MPPayment updates can be
// received.
Updates() <-chan interface{}
// Close signals that the subscriber is no longer interested in updates.
Close()
}
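// drainPaymentUpdates is a minimal sketch (not part of the original source)
// showing how a caller might consume a single-payment subscription. It assumes
// the updates carry *channeldb.MPPayment values, as documented on Updates; the
// channel is closed once the payment reaches a terminal state, so the range
// loop ends naturally.
func drainPaymentUpdates(tower ControlTower, hash lntypes.Hash) error {
	subscriber, err := tower.SubscribePayment(hash)
	if err != nil {
		return err
	}
	defer subscriber.Close()

	// The first update always carries the current payment state; later
	// updates follow each state transition.
	for update := range subscriber.Updates() {
		payment, ok := update.(*channeldb.MPPayment)
		if !ok {
			continue
		}

		_ = payment // Inspect status, HTLCs, etc.
	}

	return nil
}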
// controlTowerSubscriberImpl contains the state for a payment update
// subscriber.
type controlTowerSubscriberImpl struct {
updates <-chan interface{}
queue *queue.ConcurrentQueue
quit chan struct{}
}
// newControlTowerSubscriber instantiates a new subscriber state object.
func newControlTowerSubscriber() *controlTowerSubscriberImpl {
// Create a queue for payment updates.
queue := queue.NewConcurrentQueue(20)
queue.Start()
return &controlTowerSubscriberImpl{
updates: queue.ChanOut(),
queue: queue,
quit: make(chan struct{}),
}
}
// Close signals that the subscriber is no longer interested in updates.
func (s *controlTowerSubscriberImpl) Close() {
// Close quit channel so that any pending writes to the queue are
// cancelled.
close(s.quit)
// Stop the queue goroutine so that it won't leak.
s.queue.Stop()
}
// Updates is the channel over which *channeldb.MPPayment updates can be
// received.
func (s *controlTowerSubscriberImpl) Updates() <-chan interface{} {
return s.updates
}
// controlTower is a persistent implementation of ControlTower that prevents
// sending duplicate payments.
type controlTower struct {
db *channeldb.PaymentControl
// subscriberIndex is used to provide a unique id for each subscriber
// to all payments. This is used to easily remove the subscriber when
// necessary.
subscriberIndex uint64
subscribersAllPayments map[uint64]*controlTowerSubscriberImpl
subscribers map[lntypes.Hash][]*controlTowerSubscriberImpl
subscribersMtx sync.Mutex
// paymentsMtx provides synchronization on the payment level to ensure
// that no race conditions occur in between updating the database and
// sending a notification.
paymentsMtx *multimutex.Mutex[lntypes.Hash]
}
// NewControlTower creates a new instance of the controlTower.
func NewControlTower(db *channeldb.PaymentControl) ControlTower {
return &controlTower{
db: db,
subscribersAllPayments: make(
map[uint64]*controlTowerSubscriberImpl,
),
subscribers: make(map[lntypes.Hash][]*controlTowerSubscriberImpl),
paymentsMtx: multimutex.NewMutex[lntypes.Hash](),
}
}
// InitPayment checks or records the given PaymentCreationInfo with the DB,
// making sure it does not already exist as an in-flight payment. When this
// method returns successfully, the payment is guaranteed to be in the
// Initiated state.
func (p *controlTower) InitPayment(paymentHash lntypes.Hash,
info *channeldb.PaymentCreationInfo) error {
err := p.db.InitPayment(paymentHash, info)
if err != nil {
return err
}
// Take lock before querying the db to prevent missing or duplicating
// an update.
p.paymentsMtx.Lock(paymentHash)
defer p.paymentsMtx.Unlock(paymentHash)
payment, err := p.db.FetchPayment(paymentHash)
if err != nil {
return err
}
p.notifySubscribers(paymentHash, payment)
return nil
}
// DeleteFailedAttempts deletes all failed htlcs if the payment was
// successfully settled.
func (p *controlTower) DeleteFailedAttempts(paymentHash lntypes.Hash) error {
return p.db.DeleteFailedAttempts(paymentHash)
}
// RegisterAttempt atomically records the provided HTLCAttemptInfo to the
// DB.
func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash,
attempt *channeldb.HTLCAttemptInfo) error {
p.paymentsMtx.Lock(paymentHash)
defer p.paymentsMtx.Unlock(paymentHash)
payment, err := p.db.RegisterAttempt(paymentHash, attempt)
if err != nil {
return err
}
// Notify subscribers of the attempt registration.
p.notifySubscribers(paymentHash, payment)
return nil
}
// SettleAttempt marks the given attempt settled with the preimage. If
// this is a multi-shard payment, this might implicitly mean that the
// full payment succeeded.
func (p *controlTower) SettleAttempt(paymentHash lntypes.Hash,
attemptID uint64, settleInfo *channeldb.HTLCSettleInfo) (
*channeldb.HTLCAttempt, error) {
p.paymentsMtx.Lock(paymentHash)
defer p.paymentsMtx.Unlock(paymentHash)
payment, err := p.db.SettleAttempt(paymentHash, attemptID, settleInfo)
if err != nil {
return nil, err
}
// Notify subscribers of success event.
p.notifySubscribers(paymentHash, payment)
return payment.GetAttempt(attemptID)
}
// FailAttempt marks the given payment attempt failed.
func (p *controlTower) FailAttempt(paymentHash lntypes.Hash,
attemptID uint64, failInfo *channeldb.HTLCFailInfo) (
*channeldb.HTLCAttempt, error) {
p.paymentsMtx.Lock(paymentHash)
defer p.paymentsMtx.Unlock(paymentHash)
payment, err := p.db.FailAttempt(paymentHash, attemptID, failInfo)
if err != nil {
return nil, err
}
// Notify subscribers of failed attempt.
p.notifySubscribers(paymentHash, payment)
return payment.GetAttempt(attemptID)
}
// FetchPayment fetches the payment corresponding to the given payment hash.
func (p *controlTower) FetchPayment(paymentHash lntypes.Hash) (
DBMPPayment, error) {
return p.db.FetchPayment(paymentHash)
}
// FailPayment transitions a payment into the Failed state, and records the
// reason the payment failed. After invoking this method, InitPayment should
// return nil on its next call for this payment hash, allowing the switch to
// make a subsequent payment.
//
// NOTE: This method will overwrite the failure reason if the payment is already
// failed.
func (p *controlTower) FailPayment(paymentHash lntypes.Hash,
reason channeldb.FailureReason) error {
p.paymentsMtx.Lock(paymentHash)
defer p.paymentsMtx.Unlock(paymentHash)
payment, err := p.db.Fail(paymentHash, reason)
if err != nil {
return err
}
// Notify subscribers of fail event.
p.notifySubscribers(paymentHash, payment)
return nil
}
// FetchInFlightPayments returns all payments with status InFlight.
func (p *controlTower) FetchInFlightPayments() ([]*channeldb.MPPayment, error) {
return p.db.FetchInFlightPayments()
}
// SubscribePayment subscribes to updates for the payment with the given hash. A
// first update with the current state of the payment is always sent out
// immediately.
func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) (
ControlTowerSubscriber, error) {
// Take lock before querying the db to prevent missing or duplicating an
// update.
p.paymentsMtx.Lock(paymentHash)
defer p.paymentsMtx.Unlock(paymentHash)
payment, err := p.db.FetchPayment(paymentHash)
if err != nil {
return nil, err
}
subscriber := newControlTowerSubscriber()
// Always write current payment state to the channel.
subscriber.queue.ChanIn() <- payment
// Payment is currently in flight. Register this subscriber for further
// updates. Otherwise this update is the final update and the incoming
// channel can be closed. This will close the queue's outgoing channel
// when all updates have been written.
if !payment.Terminated() {
p.subscribersMtx.Lock()
p.subscribers[paymentHash] = append(
p.subscribers[paymentHash], subscriber,
)
p.subscribersMtx.Unlock()
} else {
close(subscriber.queue.ChanIn())
}
return subscriber, nil
}
// SubscribeAllPayments subscribes to updates for all inflight payments. A first
// update with the current state of every inflight payment is always sent out
// immediately.
// Note: If payments are in-flight while starting a new subscription, the start
// of the payment stream could produce out-of-order and/or duplicate events. In
// order to get updates for every in-flight payment attempt, make sure to
// subscribe to this method before initiating any payments.
func (p *controlTower) SubscribeAllPayments() (ControlTowerSubscriber, error) {
subscriber := newControlTowerSubscriber()
// Add the subscriber to the list before fetching in-flight payments, so
// no events are missed. If a payment attempt update occurs after
// appending and before fetching in-flight payments, an out-of-order
// duplicate may be produced, because it is then fetched in the call below
// and notified through the subscription.
p.subscribersMtx.Lock()
p.subscribersAllPayments[p.subscriberIndex] = subscriber
p.subscriberIndex++
p.subscribersMtx.Unlock()
log.Debugf("Scanning for inflight payments")
inflightPayments, err := p.db.FetchInFlightPayments()
if err != nil {
return nil, err
}
log.Debugf("Scanning for inflight payments finished",
len(inflightPayments))
for index := range inflightPayments {
// Always write current payment state to the channel.
subscriber.queue.ChanIn() <- inflightPayments[index]
}
return subscriber, nil
}
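// trackAllPayments is a minimal sketch (not part of the original source)
// demonstrating the ordering caveat above: subscribe to all payments first,
// before any payments are initiated, so that no attempt updates are missed.
func trackAllPayments(tower ControlTower) error {
	// Subscribe before initiating payments elsewhere to avoid missing or
	// reordered events at the start of the stream.
	subscriber, err := tower.SubscribeAllPayments()
	if err != nil {
		return err
	}
	defer subscriber.Close()

	for update := range subscriber.Updates() {
		payment, ok := update.(*channeldb.MPPayment)
		if !ok {
			continue
		}

		_ = payment // Process the payment update.
	}

	return nil
}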
// notifySubscribers sends a payment event to all subscribers of this
// payment. If the event is terminal, the subscribers' channels are closed
// afterwards. Note that this function must
// be executed atomically (by means of a lock) with the database update to
// guarantee consistency of the notifications.
func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash,
event *channeldb.MPPayment) {
// Get all subscribers for this payment.
p.subscribersMtx.Lock()
subscribersPaymentHash, ok := p.subscribers[paymentHash]
if !ok && len(p.subscribersAllPayments) == 0 {
p.subscribersMtx.Unlock()
return
}
// If the payment reached a terminal state, the subscriber list can be
// cleared. There won't be any more updates.
terminal := event.Terminated()
if terminal {
delete(p.subscribers, paymentHash)
}
// Copy subscribers to all payments locally while holding the lock in
// order to avoid concurrency issues while reading/writing the map.
subscribersAllPayments := make(map[uint64]*controlTowerSubscriberImpl)
for k, v := range p.subscribersAllPayments {
subscribersAllPayments[k] = v
}
p.subscribersMtx.Unlock()
// Notify all subscribers that subscribed to the current payment hash.
for _, subscriber := range subscribersPaymentHash {
select {
case subscriber.queue.ChanIn() <- event:
// If this event is the last, close the incoming channel
// of the queue. This will signal the subscriber that
// there won't be any more updates.
if terminal {
close(subscriber.queue.ChanIn())
}
// If subscriber disappeared, skip notification. For further
// notifications, we'll keep skipping over this subscriber.
case <-subscriber.quit:
}
}
// Notify all subscribers that subscribed to all payments.
for key, subscriber := range subscribersAllPayments {
select {
case subscriber.queue.ChanIn() <- event:
// If subscriber disappeared, remove it from the subscribers
// list.
case <-subscriber.quit:
p.subscribersMtx.Lock()
delete(p.subscribersAllPayments, key)
p.subscribersMtx.Unlock()
}
}
}
package routing
import (
"fmt"
"github.com/btcsuite/btcd/btcutil"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// Graph is an abstract interface that provides information about nodes and
// edges to pathfinding.
type Graph interface {
// ForEachNodeDirectedChannel calls the callback for every channel of
// the given node.
ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *graphdb.DirectedChannel) error) error
// FetchNodeFeatures returns the features of the given node.
FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
}
// GraphSessionFactory can be used to gain access to a graphdb.NodeTraverser
// instance which can then be used for a path-finding session. Depending on the
// implementation, the session will represent a DB connection where a read-lock
// is being held across calls to the backing graph.
type GraphSessionFactory interface {
// GraphSession will provide the call-back with access to a
// graphdb.NodeTraverser instance which can be used to perform queries
// against the channel graph.
GraphSession(cb func(graph graphdb.NodeTraverser) error) error
}
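// graphSessionExample is a minimal sketch (not part of the original source) of
// how a path-finding session might be wrapped in GraphSession so that all
// graph reads happen under one session (e.g. a single read transaction,
// depending on the implementation). It assumes graphdb.NodeTraverser exposes
// FetchNodeFeatures analogous to the Graph interface above.
func graphSessionExample(factory GraphSessionFactory,
	target route.Vertex) error {

	return factory.GraphSession(func(graph graphdb.NodeTraverser) error {
		// All queries inside this callback share the same session.
		features, err := graph.FetchNodeFeatures(target)
		if err != nil {
			return err
		}
		_ = features // Feed into path finding.

		return nil
	})
}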
// FetchAmountPairCapacity determines the maximal public capacity between two
// nodes depending on the amount we try to send.
func FetchAmountPairCapacity(graph Graph, source, nodeFrom, nodeTo route.Vertex,
amount lnwire.MilliSatoshi) (btcutil.Amount, error) {
// Create unified edges for all incoming connections.
//
// Note: Inbound fees are not used here because this method is only used
// by a deprecated router rpc.
u := newNodeEdgeUnifier(source, nodeTo, false, nil)
err := u.addGraphPolicies(graph)
if err != nil {
return 0, err
}
edgeUnifier, ok := u.edgeUnifiers[nodeFrom]
if !ok {
return 0, fmt.Errorf("no edge info for node pair %v -> %v",
nodeFrom, nodeTo)
}
edge := edgeUnifier.getEdgeNetwork(amount, 0)
if edge == nil {
return 0, fmt.Errorf("no edge for node pair %v -> %v "+
"(amount %v)", nodeFrom, nodeTo, amount)
}
return edge.capacity, nil
}
package routing
import (
"container/heap"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// nodeWithDist is a helper struct that couples the distance from the current
// source to a node with a pointer to the node itself.
type nodeWithDist struct {
// dist is the distance to this node from the source node in our
// current context.
dist int64
// node is the vertex itself. This can be used to explore all the
// outgoing edges (channels) emanating from a node.
node route.Vertex
// netAmountReceived is the amount that should be received by this node.
// Either as the final payment to the final node or as an intermediate
// amount that also includes the fees for subsequent hops. This node's
// inbound fee is already subtracted from the htlc amount - if
// applicable.
netAmountReceived lnwire.MilliSatoshi
// outboundFee is the fee that this node charges on the outgoing
// channel.
outboundFee lnwire.MilliSatoshi
// incomingCltv is the expected absolute expiry height for the incoming
// htlc of this node. This value should already include the final cltv
// delta.
incomingCltv int32
// probability is the probability that from this node onward the route
// is successful.
probability float64
// weight is the cost of the route from this node to the destination.
// Includes the routing fees and a virtual cost factor to account for
// time locks.
weight int64
// nextHop is the edge this route comes from.
nextHop *unifiedEdge
// routingInfoSize is the total size requirement for the payloads field
// in the onion packet from this hop towards the final destination.
routingInfoSize uint64
}
// distanceHeap is a min-distance heap that's used within our path finding
// algorithm to keep track of the "closest" node to our source node.
type distanceHeap struct {
nodes []*nodeWithDist
// pubkeyIndices maps public keys of nodes to their respective index in
// the heap. This is used as a way to avoid db lookups by using heap.Fix
// instead of having duplicate entries on the heap.
pubkeyIndices map[route.Vertex]int
}
// newDistanceHeap initializes a new distance heap. This is required because
// we must initialize the pubkeyIndices map for path-finding optimizations.
func newDistanceHeap(numNodes int) distanceHeap {
distHeap := distanceHeap{
pubkeyIndices: make(map[route.Vertex]int, numNodes),
nodes: make([]*nodeWithDist, 0, numNodes),
}
return distHeap
}
// Len returns the number of nodes in the priority queue.
//
// NOTE: This is part of the heap.Interface implementation.
func (d *distanceHeap) Len() int { return len(d.nodes) }
// Less returns whether the item in the priority queue with index i should sort
// before the item with index j.
//
// NOTE: This is part of the heap.Interface implementation.
func (d *distanceHeap) Less(i, j int) bool {
// If distances are equal, tie break on probability.
if d.nodes[i].dist == d.nodes[j].dist {
return d.nodes[i].probability > d.nodes[j].probability
}
return d.nodes[i].dist < d.nodes[j].dist
}
// Swap swaps the nodes at the passed indices in the priority queue.
//
// NOTE: This is part of the heap.Interface implementation.
func (d *distanceHeap) Swap(i, j int) {
d.nodes[i], d.nodes[j] = d.nodes[j], d.nodes[i]
d.pubkeyIndices[d.nodes[i].node] = i
d.pubkeyIndices[d.nodes[j].node] = j
}
// Push pushes the passed item onto the priority queue.
//
// NOTE: This is part of the heap.Interface implementation.
func (d *distanceHeap) Push(x interface{}) {
n := x.(*nodeWithDist)
d.nodes = append(d.nodes, n)
d.pubkeyIndices[n.node] = len(d.nodes) - 1
}
// Pop removes the highest priority item (according to Less) from the priority
// queue and returns it.
//
// NOTE: This is part of the heap.Interface implementation.
func (d *distanceHeap) Pop() interface{} {
n := len(d.nodes)
x := d.nodes[n-1]
d.nodes[n-1] = nil
d.nodes = d.nodes[0 : n-1]
delete(d.pubkeyIndices, x.node)
return x
}
// PushOrFix attempts to adjust the position of a given node in the heap.
// If the vertex already exists in the heap, then we must call heap.Fix to
// modify its position and reorder the heap. If the vertex does not already
// exist in the heap, then it is pushed onto the heap. Without this
// de-duplication, we would end up performing duplicate db lookups on the same
// node in the pathfinding algorithm.
func (d *distanceHeap) PushOrFix(dist *nodeWithDist) {
index, ok := d.pubkeyIndices[dist.node]
if !ok {
heap.Push(d, dist)
return
}
// Change the value at the specified index.
d.nodes[index] = dist
// Call heap.Fix to reorder the heap.
heap.Fix(d, index)
}
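// exploreNearest is a minimal sketch (not part of the original source) of the
// Dijkstra-style loop the distanceHeap is built for: repeatedly pop the node
// closest to the source, relax its edges via a caller-supplied function, and
// PushOrFix the improved candidates so that each vertex keeps at most one
// entry on the heap.
func exploreNearest(d *distanceHeap,
	relax func(*nodeWithDist) []*nodeWithDist) {

	for d.Len() > 0 {
		// Pop the unvisited node with the smallest distance.
		closest := heap.Pop(d).(*nodeWithDist)

		// Any neighbour whose distance improved is inserted, or
		// repositioned if it is already on the heap.
		for _, updated := range relax(closest) {
			d.PushOrFix(updated)
		}
	}
}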
package routing
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/routing/chainview"
)
// log is a logger that is initialized with no output filters. This means the
// package will not perform any logging by default until the caller requests
// it.
var log btclog.Logger
const Subsystem = "CRTR"
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger(Subsystem, nil))
}
// DisableLog disables all library log output. Logging output is disabled
// by default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info. This
// should be used in preference to SetLogWriter if the caller is also using
// btclog.
func UseLogger(logger btclog.Logger) {
log = logger
chainview.UseLogger(logger)
}
package routing
import (
"bytes"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btclog/v2"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// DefaultPenaltyHalfLife is the default half-life duration. The
// half-life duration defines after how much time a penalized node or
// channel is back at 50% probability.
DefaultPenaltyHalfLife = time.Hour
// minSecondChanceInterval is the minimum time required between
// second-chance failures.
//
// If nodes return a channel policy related failure, they may get a
// second chance to forward the payment. It could be that the channel
// policy that we are aware of is not up to date. This is especially
// important in case of mobile apps that are mostly offline.
//
// However, we don't want to give nodes the option to endlessly return
// new channel updates so that we are kept busy trying to route through
// that node until the payment loop times out.
//
// Therefore we only grant a second chance to a node if the previous
// second chance is sufficiently long ago. This is what
// minSecondChanceInterval defines. If a second policy failure comes in
// within that interval, we will apply a penalty.
//
// Second chances granted are tracked on the level of node pairs. This
// means that if a node has multiple channels to the same peer, they
// will only get a single second chance to route to that peer again.
// Nodes forward non-strict, so it isn't necessary to apply a less
// restrictive channel level tracking scheme here.
minSecondChanceInterval = time.Minute
// DefaultMaxMcHistory is the default maximum history size.
DefaultMaxMcHistory = 1000
// DefaultMcFlushInterval is the default interval we use to flush MC state
// to the database.
DefaultMcFlushInterval = time.Second
// prevSuccessProbability is the assumed probability for node pairs that
// successfully relayed the previous attempt.
prevSuccessProbability = 0.95
// DefaultAprioriWeight is the default a priori weight. See
// MissionControlConfig for further explanation.
DefaultAprioriWeight = 0.5
// DefaultMinFailureRelaxInterval is the default minimum time that must
// have passed since the previously recorded failure before the failure
// amount may be raised.
DefaultMinFailureRelaxInterval = time.Minute
// DefaultFeeEstimationTimeout is the default value for
// FeeEstimationTimeout. It defines the maximum duration that the
// probing fee estimation is allowed to take.
DefaultFeeEstimationTimeout = time.Minute
// DefaultMissionControlNamespace is the name of the default mission
// control namespace. This is used as the sub-bucket key within the
// top level DB bucket to store mission control results.
DefaultMissionControlNamespace = "default"
)
var (
// ErrInvalidMcHistory is returned if we get a negative mission control
// history count.
ErrInvalidMcHistory = errors.New("mission control history must be " +
">= 0")
// ErrInvalidFailureInterval is returned if we get an invalid failure
// interval.
ErrInvalidFailureInterval = errors.New("failure interval must be >= 0")
)
// NodeResults contains previous results from a node to its peers.
type NodeResults map[route.Vertex]TimedPairResult
// mcConfig holds various config members that will be required by all
// MissionControl instances and will be the same regardless of namespace.
type mcConfig struct {
// clock is a time source used by mission control.
clock clock.Clock
// selfNode is our pubkey.
selfNode route.Vertex
}
// MissionControl contains state which summarizes the past attempts of HTLC
// routing by external callers when sending payments throughout the network. It
// acts as a shared memory during routing attempts with the goal to optimize the
// payment attempt success rate.
//
// Failed payment attempts are reported to mission control. These reports are
// used to track the time of the last node or channel level failure. The time
// since the last failure is used to estimate a success probability that is fed
// into the path finding process for subsequent payment attempts.
type MissionControl struct {
cfg *mcConfig
// state is the internal mission control state that is input for
// probability estimation.
state *missionControlState
store *missionControlStore
// estimator is the probability estimator that is used with the payment
// results that mission control collects.
estimator Estimator
// onConfigUpdate is a function that is called whenever the
// mission control config is updated.
onConfigUpdate fn.Option[func(cfg *MissionControlConfig)]
log btclog.Logger
mu sync.Mutex
}
// MissionController manages MissionControl instances in various namespaces.
type MissionController struct {
db kvdb.Backend
cfg *mcConfig
defaultMCCfg *MissionControlConfig
mc map[string]*MissionControl
mu sync.Mutex
// TODO(roasbeef): further counters, if vertex continually unavailable,
// add to another generation
// TODO(roasbeef): also add favorable metrics for nodes
}
// GetNamespacedStore returns the MissionControl in the given namespace. If one
// does not yet exist, then it is initialised.
func (m *MissionController) GetNamespacedStore(ns string) (*MissionControl,
error) {
m.mu.Lock()
defer m.mu.Unlock()
if mc, ok := m.mc[ns]; ok {
return mc, nil
}
return m.initMissionControl(ns)
}
// ListNamespaces returns a list of the namespaces that the MissionController
// is aware of.
func (m *MissionController) ListNamespaces() []string {
m.mu.Lock()
defer m.mu.Unlock()
namespaces := make([]string, 0, len(m.mc))
for ns := range m.mc {
namespaces = append(namespaces, ns)
}
return namespaces
}
// MissionControlConfig defines parameters that control mission control
// behaviour.
type MissionControlConfig struct {
// Estimator gives probability estimates for node pairs.
Estimator Estimator
// OnConfigUpdate is a function that is called whenever the
// mission control config is updated.
OnConfigUpdate fn.Option[func(cfg *MissionControlConfig)]
// MaxMcHistory defines the maximum number of payment results that are
// held on disk.
MaxMcHistory int
// McFlushInterval defines the ticker interval when we flush the
// accumulated state to the DB.
McFlushInterval time.Duration
// MinFailureRelaxInterval is the minimum time that must have passed
// since the previously recorded failure before the failure amount may
// be raised.
MinFailureRelaxInterval time.Duration
}
func (c *MissionControlConfig) validate() error {
if c.MaxMcHistory < 0 {
return ErrInvalidMcHistory
}
if c.MinFailureRelaxInterval < 0 {
return ErrInvalidFailureInterval
}
return nil
}
// String returns a string representation of a mission control config.
func (c *MissionControlConfig) String() string {
return fmt.Sprintf("maximum history: %v, minimum failure relax "+
"interval: %v", c.MaxMcHistory, c.MinFailureRelaxInterval)
}
// TimedPairResult describes a timestamped pair result.
type TimedPairResult struct {
// FailTime is the time of the last failure.
FailTime time.Time
// FailAmt is the amount of the last failure. This amount may be pushed
// up if a later success is higher than the last failed amount.
FailAmt lnwire.MilliSatoshi
// SuccessTime is the time of the last success.
SuccessTime time.Time
// SuccessAmt is the highest amount that was successfully forwarded. This
// isn't necessarily the last success amount. The value of this field
// may also be pushed down if a later failure is lower than the highest
// success amount. Because of this, SuccessAmt may not match
// SuccessTime.
SuccessAmt lnwire.MilliSatoshi
}
// MissionControlSnapshot contains a snapshot of the current state of mission
// control.
type MissionControlSnapshot struct {
// Pairs is a list of node pairs for which specific information is
// logged.
Pairs []MissionControlPairSnapshot
}
// MissionControlPairSnapshot contains a snapshot of the current node pair
// state in mission control.
type MissionControlPairSnapshot struct {
// Pair is the node pair of which the state is described.
Pair DirectedNodePair
// TimedPairResult contains the data for this pair.
TimedPairResult
}
// paymentResult is the information that becomes available when a payment
// attempt completes.
type paymentResult struct {
id uint64
timeFwd tlv.RecordT[tlv.TlvType0, uint64]
timeReply tlv.RecordT[tlv.TlvType1, uint64]
route tlv.RecordT[tlv.TlvType2, mcRoute]
// failure holds information related to the failure of a payment. The
// presence of this record indicates a payment failure. The absence of
// this record indicates a successful payment.
failure tlv.OptionalRecordT[tlv.TlvType3, paymentFailure]
}
// newPaymentResult constructs a new paymentResult.
func newPaymentResult(id uint64, rt *mcRoute, timeFwd, timeReply time.Time,
failure *paymentFailure) *paymentResult {
result := &paymentResult{
id: id,
timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint64(timeFwd.UnixNano()),
),
timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1](
uint64(timeReply.UnixNano()),
),
route: tlv.NewRecordT[tlv.TlvType2](*rt),
}
if failure != nil {
result.failure = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3](*failure),
)
}
return result
}
// NewMissionController returns a new instance of MissionController.
func NewMissionController(db kvdb.Backend, self route.Vertex,
cfg *MissionControlConfig) (*MissionController, error) {
log.Debugf("Instantiating mission control with config: %v, %v", cfg,
cfg.Estimator)
if err := cfg.validate(); err != nil {
return nil, err
}
mcCfg := &mcConfig{
clock: clock.NewDefaultClock(),
selfNode: self,
}
mgr := &MissionController{
db: db,
defaultMCCfg: cfg,
cfg: mcCfg,
mc: make(map[string]*MissionControl),
}
if err := mgr.loadMissionControls(); err != nil {
return nil, err
}
for _, mc := range mgr.mc {
if err := mc.init(); err != nil {
return nil, err
}
}
return mgr, nil
}
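// newDefaultMissionControl is a minimal sketch (not part of the original
// source) showing how a caller might wire up a MissionController with the
// package defaults and obtain the default-namespace MissionControl instance.
// The estimator argument stands in for any concrete Estimator implementation
// available to the caller.
func newDefaultMissionControl(db kvdb.Backend, self route.Vertex,
	estimator Estimator) (*MissionControl, error) {

	mgr, err := NewMissionController(db, self, &MissionControlConfig{
		Estimator:               estimator,
		MaxMcHistory:            DefaultMaxMcHistory,
		McFlushInterval:         DefaultMcFlushInterval,
		MinFailureRelaxInterval: DefaultMinFailureRelaxInterval,
	})
	if err != nil {
		return nil, err
	}

	// Start the periodic flush of accumulated results to the DB.
	mgr.RunStoreTickers()

	return mgr.GetNamespacedStore(DefaultMissionControlNamespace)
}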
// loadMissionControls initialises a MissionControl in the default namespace if
// one does not yet exist. It then initialises a MissionControl for all other
// namespaces found in the DB.
//
// NOTE: this should only be called once during MissionController construction.
func (m *MissionController) loadMissionControls() error {
m.mu.Lock()
defer m.mu.Unlock()
// Always initialise the default namespace.
_, err := m.initMissionControl(DefaultMissionControlNamespace)
if err != nil {
return err
}
namespaces := make(map[string]struct{})
err = m.db.View(func(tx walletdb.ReadTx) error {
mcStoreBkt := tx.ReadBucket(resultsKey)
if mcStoreBkt == nil {
return fmt.Errorf("top level mission control bucket " +
"not found")
}
// Iterate through all the keys in the bucket and collect the
// namespaces.
return mcStoreBkt.ForEach(func(k, _ []byte) error {
// We've already initialised the default namespace so
// we can skip it.
if string(k) == DefaultMissionControlNamespace {
return nil
}
namespaces[string(k)] = struct{}{}
return nil
})
}, func() {})
if err != nil {
return err
}
// Now, iterate through all the namespaces and initialise them.
for ns := range namespaces {
_, err = m.initMissionControl(ns)
if err != nil {
return err
}
}
return nil
}
// initMissionControl creates a new MissionControl instance with the given
// namespace if one does not yet exist.
//
// NOTE: the MissionController's mutex must be held before calling this method.
func (m *MissionController) initMissionControl(namespace string) (
*MissionControl, error) {
// If a mission control with this namespace has already been initialised
// then there is nothing left to do.
if mc, ok := m.mc[namespace]; ok {
return mc, nil
}
cfg := m.defaultMCCfg
store, err := newMissionControlStore(
newNamespacedDB(m.db, namespace), cfg.MaxMcHistory,
cfg.McFlushInterval,
)
if err != nil {
return nil, err
}
mc := &MissionControl{
cfg: m.cfg,
state: newMissionControlState(
cfg.MinFailureRelaxInterval,
),
store: store,
estimator: cfg.Estimator,
log: log.WithPrefix(fmt.Sprintf("[%s]:", namespace)),
onConfigUpdate: cfg.OnConfigUpdate,
}
m.mc[namespace] = mc
return mc, nil
}
// RunStoreTickers runs the store tickers of all mission control instances.
func (m *MissionController) RunStoreTickers() {
m.mu.Lock()
defer m.mu.Unlock()
for _, mc := range m.mc {
mc.store.run()
}
}
// StopStoreTickers stops the store tickers of all mission control instances.
func (m *MissionController) StopStoreTickers() {
log.Debug("Stopping mission control store ticker")
defer log.Debug("Mission control store ticker stopped")
m.mu.Lock()
defer m.mu.Unlock()
for _, mc := range m.mc {
mc.store.stop()
}
}
// init initializes mission control with historical data.
func (m *MissionControl) init() error {
m.log.Debugf("Mission control state reconstruction started")
m.mu.Lock()
defer m.mu.Unlock()
start := time.Now()
results, err := m.store.fetchAll()
if err != nil {
return err
}
for _, result := range results {
_ = m.applyPaymentResult(result)
}
m.log.Debugf("Mission control state reconstruction finished: "+
"n=%v, time=%v", len(results), time.Since(start))
return nil
}
// GetConfig returns the config that mission control is currently configured
// with. All fields are copied by value, so we do not need to worry about
// mutation.
func (m *MissionControl) GetConfig() *MissionControlConfig {
m.mu.Lock()
defer m.mu.Unlock()
return &MissionControlConfig{
Estimator: m.estimator,
MaxMcHistory: m.store.maxRecords,
McFlushInterval: m.store.flushInterval,
MinFailureRelaxInterval: m.state.minFailureRelaxInterval,
}
}
// SetConfig validates the config provided and updates mission control's config
// if it is valid.
func (m *MissionControl) SetConfig(cfg *MissionControlConfig) error {
if cfg == nil {
return errors.New("nil mission control config")
}
if err := cfg.validate(); err != nil {
return err
}
m.mu.Lock()
defer m.mu.Unlock()
m.log.Infof("Active mission control cfg: %v, estimator: %v", cfg,
cfg.Estimator)
m.store.maxRecords = cfg.MaxMcHistory
m.state.minFailureRelaxInterval = cfg.MinFailureRelaxInterval
m.estimator = cfg.Estimator
// Execute the callback function if it is set.
m.onConfigUpdate.WhenSome(func(f func(cfg *MissionControlConfig)) {
f(cfg)
})
return nil
}
// ResetHistory resets the history of MissionControl, returning it to a state
// as if no payment attempts have been made.
func (m *MissionControl) ResetHistory() error {
m.mu.Lock()
defer m.mu.Unlock()
if err := m.store.clear(); err != nil {
return err
}
m.state.resetHistory()
m.log.Debugf("Mission control history cleared")
return nil
}
// GetProbability is expected to return the success probability of a payment
// from fromNode to toNode.
func (m *MissionControl) GetProbability(fromNode, toNode route.Vertex,
amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 {
m.mu.Lock()
defer m.mu.Unlock()
now := m.cfg.clock.Now()
results, _ := m.state.getLastPairResult(fromNode)
// Use a distinct probability estimation function for local channels.
if fromNode == m.cfg.selfNode {
return m.estimator.LocalPairProbability(now, results, toNode)
}
return m.estimator.PairProbability(
now, results, toNode, amt, capacity,
)
}
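// scorePairExample is a minimal sketch (not part of the original source) of
// how path finding consults mission control: the pair probability is queried
// per candidate hop and multiplied into the running route probability.
func scorePairExample(mc *MissionControl, from, to route.Vertex,
	amt lnwire.MilliSatoshi, capacity btcutil.Amount,
	routeProbability float64) float64 {

	// Probability that this single hop succeeds, given past results.
	p := mc.GetProbability(from, to, amt, capacity)

	// Hops are treated as independent, so probabilities multiply.
	return routeProbability * p
}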
// GetHistorySnapshot takes a snapshot from the current mission control state
// and actual probability estimates.
func (m *MissionControl) GetHistorySnapshot() *MissionControlSnapshot {
m.mu.Lock()
defer m.mu.Unlock()
m.log.Debugf("Requesting history snapshot from mission control")
return m.state.getSnapshot()
}
// ImportHistory imports the set of mission control results provided to our
// in-memory state. These results are not persisted, so will not survive
// restarts.
func (m *MissionControl) ImportHistory(history *MissionControlSnapshot,
force bool) error {
if history == nil {
return errors.New("cannot import nil history")
}
m.mu.Lock()
defer m.mu.Unlock()
m.log.Infof("Importing history snapshot with %v pairs to mission "+
"control", len(history.Pairs))
imported := m.state.importSnapshot(history, force)
m.log.Infof("Imported %v results to mission control", imported)
return nil
}
// GetPairHistorySnapshot returns the stored history for a given node pair.
func (m *MissionControl) GetPairHistorySnapshot(
fromNode, toNode route.Vertex) TimedPairResult {
m.mu.Lock()
defer m.mu.Unlock()
results, ok := m.state.getLastPairResult(fromNode)
if !ok {
return TimedPairResult{}
}
result, ok := results[toNode]
if !ok {
return TimedPairResult{}
}
return result
}
// ReportPaymentFail reports a failed payment to mission control as input for
// future probability estimates. The failureSourceIdx argument indicates the
// failure source. If it is nil, the failure source is unknown. This function
// returns a reason if this failure is a final failure. In that case no further
// payment attempts need to be made.
func (m *MissionControl) ReportPaymentFail(paymentID uint64, rt *route.Route,
failureSourceIdx *int, failure lnwire.FailureMessage) (
*channeldb.FailureReason, error) {
timestamp := m.cfg.clock.Now()
result := newPaymentResult(
paymentID, extractMCRoute(rt), timestamp, timestamp,
newPaymentFailure(failureSourceIdx, failure),
)
return m.processPaymentResult(result)
}
// ReportPaymentSuccess reports a successful payment to mission control as input
// for future probability estimates.
func (m *MissionControl) ReportPaymentSuccess(paymentID uint64,
rt *route.Route) error {
timestamp := m.cfg.clock.Now()
result := newPaymentResult(
paymentID, extractMCRoute(rt), timestamp, timestamp, nil,
)
_, err := m.processPaymentResult(result)
return err
}
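// reportAttemptOutcome is a minimal sketch (not part of the original source)
// of how a payment lifecycle might feed attempt outcomes back into mission
// control: successes unconditionally, failures together with the failure
// source index (nil when the source is unknown).
func reportAttemptOutcome(mc *MissionControl, id uint64, rt *route.Route,
	srcIdx *int, failure lnwire.FailureMessage) error {

	if failure == nil {
		return mc.ReportPaymentSuccess(id, rt)
	}

	// A non-nil reason means the failure is final and no further attempts
	// should be made for this payment.
	reason, err := mc.ReportPaymentFail(id, rt, srcIdx, failure)
	if err != nil {
		return err
	}
	if reason != nil {
		return fmt.Errorf("final payment failure: %v", *reason)
	}

	return nil
}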
// processPaymentResult stores a payment result in the mission control store and
// updates mission control's in-memory state.
func (m *MissionControl) processPaymentResult(result *paymentResult) (
*channeldb.FailureReason, error) {
// Store complete result in database.
m.store.AddResult(result)
m.mu.Lock()
defer m.mu.Unlock()
// Apply result to update mission control state.
reason := m.applyPaymentResult(result)
return reason, nil
}
// applyPaymentResult applies a payment result as input for future probability
// estimates. It returns a non-nil failure reason if this failure is final,
// meaning that no further payment attempts need to be made.
func (m *MissionControl) applyPaymentResult(
result *paymentResult) *channeldb.FailureReason {
// Interpret result.
i := interpretResult(&result.route.Val, result.failure.ValOpt())
if i.policyFailure != nil {
if m.state.requestSecondChance(
time.Unix(0, int64(result.timeReply.Val)),
i.policyFailure.From, i.policyFailure.To,
) {
return nil
}
}
// If there is a node-level failure, record a failure for every tried
// connection of that node. A node-level failure can be considered as a
// failure that would have occurred with any of the node's channels.
//
// Ideally we'd also record the failure for the untried connections of
// the node. Unfortunately this would require access to the graph and
// adding this dependency and db calls does not outweigh the benefits.
//
// Untried connections will fall back to the node probability. After the
// call to setAllPairResult below, the node probability will be equal to
// the probability of the tried channels except that the a priori
// probability is mixed in too. This effect is controlled by the
// aprioriWeight parameter. If that parameter isn't set to an extreme
// and there are a few known connections, there shouldn't be much of a
// difference. The largest difference occurs when aprioriWeight is 1. In
// that case, a node-level failure would not be applied to untried
// channels.
if i.nodeFailure != nil {
m.log.Debugf("Reporting node failure to Mission Control: "+
"node=%v", *i.nodeFailure)
m.state.setAllFail(
*i.nodeFailure,
time.Unix(0, int64(result.timeReply.Val)),
)
}
for pair, pairResult := range i.pairResults {
pairResult := pairResult
if pairResult.success {
m.log.Debugf("Reporting pair success to Mission "+
"Control: pair=%v, amt=%v",
pair, pairResult.amt)
} else {
m.log.Debugf("Reporting pair failure to Mission "+
"Control: pair=%v, amt=%v",
pair, pairResult.amt)
}
m.state.setLastPairResult(
pair.From, pair.To,
time.Unix(0, int64(result.timeReply.Val)), &pairResult,
false,
)
}
return i.finalFailureReason
}
// namespacedDB is an implementation of the missionControlDB that gives a user
// of the interface access to a namespaced bucket within the top level mission
// control bucket.
type namespacedDB struct {
topLevelBucketKey []byte
namespace []byte
db kvdb.Backend
}
// A compile-time check to ensure that namespacedDB implements missionControlDB.
var _ missionControlDB = (*namespacedDB)(nil)
// newDefaultNamespacedStore creates an instance of namespacedDB that uses the
// default namespace.
func newDefaultNamespacedStore(db kvdb.Backend) missionControlDB {
return newNamespacedDB(db, DefaultMissionControlNamespace)
}
// newNamespacedDB creates a new instance of missionControlDB where the DB will
// have access to a namespaced bucket within the top level mission control
// bucket.
func newNamespacedDB(db kvdb.Backend, namespace string) missionControlDB {
return &namespacedDB{
db: db,
namespace: []byte(namespace),
topLevelBucketKey: resultsKey,
}
}
// update can be used to perform reads and writes on the given bucket.
//
// NOTE: this is part of the missionControlDB interface.
func (n *namespacedDB) update(f func(bkt walletdb.ReadWriteBucket) error,
reset func()) error {
return n.db.Update(func(tx kvdb.RwTx) error {
mcStoreBkt, err := tx.CreateTopLevelBucket(n.topLevelBucketKey)
if err != nil {
return fmt.Errorf("cannot create top level mission "+
"control bucket: %w", err)
}
namespacedBkt, err := mcStoreBkt.CreateBucketIfNotExists(
n.namespace,
)
if err != nil {
return fmt.Errorf("cannot create namespaced bucket "+
"(%s) in mission control store: %w",
n.namespace, err)
}
return f(namespacedBkt)
}, reset)
}
// view can be used to perform reads on the given bucket.
//
// NOTE: this is part of the missionControlDB interface.
func (n *namespacedDB) view(f func(bkt walletdb.ReadBucket) error,
reset func()) error {
return n.db.View(func(tx kvdb.RTx) error {
mcStoreBkt := tx.ReadBucket(n.topLevelBucketKey)
if mcStoreBkt == nil {
return fmt.Errorf("top level mission control bucket " +
"not found")
}
namespacedBkt := mcStoreBkt.NestedReadBucket(n.namespace)
if namespacedBkt == nil {
return fmt.Errorf("namespaced bucket (%s) not found "+
"in mission control store", n.namespace)
}
return f(namespacedBkt)
}, reset)
}
// purge will delete all the contents in the namespace.
//
// NOTE: this is part of the missionControlDB interface.
func (n *namespacedDB) purge() error {
return n.db.Update(func(tx kvdb.RwTx) error {
mcStoreBkt := tx.ReadWriteBucket(n.topLevelBucketKey)
if mcStoreBkt == nil {
return nil
}
err := mcStoreBkt.DeleteNestedBucket(n.namespace)
if err != nil {
return err
}
_, err = mcStoreBkt.CreateBucket(n.namespace)
return err
}, func() {})
}
// paymentFailure represents the presence of a payment failure. It may or may
// not include additional information about said failure.
type paymentFailure struct {
info tlv.OptionalRecordT[tlv.TlvType0, paymentFailureInfo]
}
// newPaymentFailure constructs a new paymentFailure struct. If the source
// index is nil, then an empty paymentFailure is returned. This represents a
// failure with unknown details. Otherwise, the index and failure message are
// used to populate the info field of the paymentFailure.
func newPaymentFailure(sourceIdx *int,
failureMsg lnwire.FailureMessage) *paymentFailure {
if sourceIdx == nil {
return &paymentFailure{}
}
info := paymentFailureInfo{
sourceIdx: tlv.NewPrimitiveRecord[tlv.TlvType0](
uint8(*sourceIdx),
),
msg: tlv.NewRecordT[tlv.TlvType1](failureMessage{failureMsg}),
}
return &paymentFailure{
info: tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType0](info)),
}
}
// Record returns a TLV record that can be used to encode/decode a
// paymentFailure to/from a TLV stream.
func (r *paymentFailure) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodePaymentFailure(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodePaymentFailure, decodePaymentFailure,
)
}
func encodePaymentFailure(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*paymentFailure); ok {
var recordProducers []tlv.RecordProducer
v.info.WhenSome(
func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) {
recordProducers = append(recordProducers, &r)
},
)
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(recordProducers...),
)
}
return tlv.NewTypeForEncodingErr(val, "routing.paymentFailure")
}
func decodePaymentFailure(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*paymentFailure); ok {
var h paymentFailure
info := tlv.ZeroRecordT[tlv.TlvType0, paymentFailureInfo]()
typeMap, err := lnwire.DecodeRecords(
r, lnwire.ProduceRecordsSorted(&info)...,
)
if err != nil {
return err
}
if _, ok := typeMap[h.info.TlvType()]; ok {
h.info = tlv.SomeRecordT(info)
}
*v = h
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.paymentFailure", l, l)
}
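// paymentFailureRoundTrip is a minimal sketch (not part of the original
// source) of encoding a paymentFailure to its TLV representation and decoding
// it back, the same path exercised when results are persisted to the store.
func paymentFailureRoundTrip(src *paymentFailure) (*paymentFailure, error) {
	var (
		b   bytes.Buffer
		buf [8]byte
	)
	if err := encodePaymentFailure(&b, src, &buf); err != nil {
		return nil, err
	}

	var dst paymentFailure
	err := decodePaymentFailure(
		bytes.NewReader(b.Bytes()), &dst, &buf, uint64(b.Len()),
	)
	if err != nil {
		return nil, err
	}

	return &dst, nil
}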
// paymentFailureInfo holds additional information about a payment failure.
type paymentFailureInfo struct {
sourceIdx tlv.RecordT[tlv.TlvType0, uint8]
msg tlv.RecordT[tlv.TlvType1, failureMessage]
}
// Record returns a TLV record that can be used to encode/decode a
// paymentFailureInfo to/from a TLV stream.
func (r *paymentFailureInfo) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodePaymentFailureInfo(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodePaymentFailureInfo,
decodePaymentFailureInfo,
)
}
func encodePaymentFailureInfo(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*paymentFailureInfo); ok {
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(
&v.sourceIdx, &v.msg,
),
)
}
return tlv.NewTypeForEncodingErr(val, "routing.paymentFailureInfo")
}
func decodePaymentFailureInfo(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*paymentFailureInfo); ok {
var h paymentFailureInfo
_, err := lnwire.DecodeRecords(
r,
lnwire.ProduceRecordsSorted(&h.sourceIdx, &h.msg)...,
)
if err != nil {
return err
}
*v = h
return nil
}
return tlv.NewTypeForDecodingErr(
val, "routing.paymentFailureInfo", l, l,
)
}
package routing
import (
"time"
"github.com/lightningnetwork/lnd/routing/route"
)
// missionControlState is an object that manages the internal mission control
// state. Note that it isn't thread safe and synchronization needs to be
// enforced externally.
type missionControlState struct {
// lastPairResult tracks the last payment result (on a pair basis) for
// each transited node. This is a multi-layer map that allows us to look
// up the failure history of all connected channels (node pairs) for a
// particular node.
lastPairResult map[route.Vertex]NodeResults
// lastSecondChance tracks the last time a second chance was granted for
// a directed node pair.
lastSecondChance map[DirectedNodePair]time.Time
// minFailureRelaxInterval is the minimum time that must have passed
// since the previously recorded failure before the failure amount may
// be raised.
minFailureRelaxInterval time.Duration
}
// newMissionControlState instantiates a new mission control state object.
func newMissionControlState(
minFailureRelaxInterval time.Duration) *missionControlState {
return &missionControlState{
lastPairResult: make(map[route.Vertex]NodeResults),
lastSecondChance: make(map[DirectedNodePair]time.Time),
minFailureRelaxInterval: minFailureRelaxInterval,
}
}
// getLastPairResult returns the current state for connections to the given
// node.
func (m *missionControlState) getLastPairResult(node route.Vertex) (NodeResults,
bool) {
result, ok := m.lastPairResult[node]
return result, ok
}
// resetHistory resets the history of missionControlState, returning it to a
// state as if no payment attempts have been made.
func (m *missionControlState) resetHistory() {
m.lastPairResult = make(map[route.Vertex]NodeResults)
m.lastSecondChance = make(map[DirectedNodePair]time.Time)
}
// setLastPairResult stores a result for a node pair.
func (m *missionControlState) setLastPairResult(fromNode, toNode route.Vertex,
timestamp time.Time, result *pairResult, force bool) {
nodePairs, ok := m.lastPairResult[fromNode]
if !ok {
nodePairs = make(NodeResults)
m.lastPairResult[fromNode] = nodePairs
}
current := nodePairs[toNode]
// Apply the new result to the existing data for this pair. If there is
// no existing data, apply it to the default values for TimedPairResult.
if result.success {
successAmt := result.amt
current.SuccessTime = timestamp
// Only update the success amount if this amount is higher. This
// prevents the success range from shrinking when there is no
// reason to do so. For example: small amount probes shouldn't
// affect a previous success for a much larger amount.
if force || successAmt > current.SuccessAmt {
current.SuccessAmt = successAmt
}
// If the success amount goes into the failure range, move the
// failure range up. Future attempts up to the success amount
// are likely to succeed. We don't want to clear the failure
// completely, because we haven't learnt much for amounts above
// the current success amount.
if force || (!current.FailTime.IsZero() &&
successAmt >= current.FailAmt) {
current.FailAmt = successAmt + 1
}
} else {
// For failures we always want to update both the amount and the
// time. Those need to relate to the same result, because the
// time is used to gradually diminish the penalty for that
// specific result. Updating the timestamp but not the amount
// could cause a failure for a lower amount (a more severe
// condition) to be revived as if it just happened.
failAmt := result.amt
// Drop result if it would increase the failure amount too soon
// after a previous failure. This can happen if htlc results
// come in out of order. This check makes it easier for payment
// processes to converge to a final state.
failInterval := timestamp.Sub(current.FailTime)
if !force && failAmt > current.FailAmt &&
failInterval < m.minFailureRelaxInterval {
log.Debugf("Ignoring higher amount failure within min "+
"failure relaxation interval: prev_fail_amt=%v, "+
"fail_amt=%v, interval=%v",
current.FailAmt, failAmt, failInterval)
return
}
current.FailTime = timestamp
current.FailAmt = failAmt
switch {
// The failure amount is set to zero when the failure is
// amount-independent, meaning that the attempt would have
// failed regardless of the amount. This should also reset the
// success amount to zero.
case failAmt == 0:
current.SuccessAmt = 0
// If the failure range goes into the success range, move the
// success range down.
case failAmt <= current.SuccessAmt:
current.SuccessAmt = failAmt - 1
}
}
log.Debugf("Setting %v->%v range to [%v-%v]",
fromNode, toNode, current.SuccessAmt, current.FailAmt)
nodePairs[toNode] = current
}
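// pairRangeExample is a minimal sketch (not part of the original source)
// walking through the range bookkeeping above, using the successPairResult
// and failPairResult constructors referenced in importSnapshot below. A
// success at 100k msat followed by a failure at 50k msat leaves the pair with
// FailAmt=50_000 and SuccessAmt pushed down to 49_999, because a failure
// inside the success range bounds future successes.
func pairRangeExample(state *missionControlState, from, to route.Vertex,
	now time.Time) {

	// Success at 100k msat: SuccessAmt becomes 100_000.
	success := successPairResult(100_000)
	state.setLastPairResult(from, to, now, &success, false)

	// Failure at 50k msat one minute later: FailAmt becomes 50_000 and,
	// since it falls inside the success range, SuccessAmt drops to 49_999.
	fail := failPairResult(50_000)
	state.setLastPairResult(from, to, now.Add(time.Minute), &fail, false)
}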
// setAllFail stores a fail result for all known connections to and from the
// given node.
func (m *missionControlState) setAllFail(node route.Vertex,
timestamp time.Time) {
for fromNode, nodePairs := range m.lastPairResult {
for toNode := range nodePairs {
if fromNode == node || toNode == node {
nodePairs[toNode] = TimedPairResult{
FailTime: timestamp,
}
}
}
}
}
// requestSecondChance checks whether the node fromNode can have a second chance
// at providing a channel update for its channel with toNode.
func (m *missionControlState) requestSecondChance(timestamp time.Time,
fromNode, toNode route.Vertex) bool {
// Look up previous second chance time.
pair := DirectedNodePair{
From: fromNode,
To: toNode,
}
lastSecondChance, ok := m.lastSecondChance[pair]
// If the channel hasn't already been given a second chance or its last
// second chance was long ago, we give it another chance.
if !ok || timestamp.Sub(lastSecondChance) > minSecondChanceInterval {
m.lastSecondChance[pair] = timestamp
log.Debugf("Second chance granted for %v->%v", fromNode, toNode)
return true
}
// Otherwise penalize the channel, because we don't allow channel
// updates that are that frequent. This is to prevent nodes from keeping
// us busy by continuously sending new channel updates.
log.Debugf("Second chance denied for %v->%v, remaining interval: %v",
fromNode, toNode, timestamp.Sub(lastSecondChance))
return false
}
// getSnapshot takes a snapshot from the current mission control state
// and actual probability estimates.
func (m *missionControlState) getSnapshot() *MissionControlSnapshot {
log.Debugf("Requesting history snapshot from mission control: "+
"pair_result_count=%v", len(m.lastPairResult))
pairs := make([]MissionControlPairSnapshot, 0, len(m.lastPairResult))
for fromNode, fromPairs := range m.lastPairResult {
for toNode, result := range fromPairs {
pair := NewDirectedNodePair(fromNode, toNode)
pairSnapshot := MissionControlPairSnapshot{
Pair: pair,
TimedPairResult: result,
}
pairs = append(pairs, pairSnapshot)
}
}
snapshot := MissionControlSnapshot{
Pairs: pairs,
}
return &snapshot
}
// importSnapshot takes an existing snapshot and merges it with our current
// state if the results provided are fresher than our current results. It
// returns the number of pairs that were used.
func (m *missionControlState) importSnapshot(snapshot *MissionControlSnapshot,
force bool) int {
var imported int
for _, pair := range snapshot.Pairs {
fromNode := pair.Pair.From
toNode := pair.Pair.To
results, found := m.getLastPairResult(fromNode)
if !found {
results = make(map[route.Vertex]TimedPairResult)
}
lastResult := results[toNode]
failResult := failPairResult(pair.FailAmt)
imported += m.importResult(
lastResult.FailTime, pair.FailTime, failResult,
fromNode, toNode, force,
)
successResult := successPairResult(pair.SuccessAmt)
imported += m.importResult(
lastResult.SuccessTime, pair.SuccessTime, successResult,
fromNode, toNode, force,
)
}
return imported
}
func (m *missionControlState) importResult(currentTs, importedTs time.Time,
importedResult pairResult, fromNode, toNode route.Vertex,
force bool) int {
if !force && currentTs.After(importedTs) {
log.Debugf("Not setting pair result for %v->%v (%v) "+
"success=%v, timestamp %v older than last result %v",
fromNode, toNode, importedResult.amt,
importedResult.success, importedTs, currentTs)
return 0
}
m.setLastPairResult(
fromNode, toNode, importedTs, &importedResult, force,
)
return 1
}
package routing
import (
"bytes"
"container/list"
"encoding/binary"
"fmt"
"io"
"sync"
"time"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// resultsKey is the fixed key under which the attempt results are
// stored.
resultsKey = []byte("missioncontrol-results")
// Big endian is the preferred byte order, due to cursor scans over
// integer keys iterating in order.
byteOrder = binary.BigEndian
)
// missionControlDB is an interface that defines the database methods that a
// single missionControlStore has access to. It allows the missionControlStore
// to be unaware of the overall DB structure and restricts its access to the DB
// by only providing it the bucket that it needs to care about.
type missionControlDB interface {
// update can be used to perform reads and writes on the given bucket.
update(f func(bkt kvdb.RwBucket) error, reset func()) error
// view can be used to perform reads on the given bucket.
view(f func(bkt kvdb.RBucket) error, reset func()) error
// purge will delete all the contents in this store.
purge() error
}
// missionControlStore is a bolt db based implementation of a mission control
// store. It stores the raw payment attempt data from which the internal mission
// control's state can be re-derived on startup. This allows the mission control
// internal data structure to be changed without requiring a database migration.
// Also changes to mission control parameters can be applied to historical data.
// Finally, it enables importing raw data from an external source.
type missionControlStore struct {
done chan struct{}
wg sync.WaitGroup
db missionControlDB
// TODO(yy): Remove the usage of sync.Cond - we are better off using
// channels than a Cond as suggested in the official godoc.
//
// queueCond is signalled when items are put into the queue.
queueCond *sync.Cond
// queue stores all pending payment results not yet added to the store.
// Access is protected by the queueCond.L mutex.
queue *list.List
// keys holds the stored MC store item keys in the order of storage.
// We use this list when adding/deleting items from the database to
// avoid cursor use which may be slow in the remote DB case.
keys *list.List
// keysMap holds the stored MC store item keys. We use this map to check
// if a new payment result has already been stored.
keysMap map[string]struct{}
// maxRecords is the maximum number of records we will store in the db.
maxRecords int
// flushInterval is the configured interval we use to store new results
// and delete outdated ones from the db.
flushInterval time.Duration
}
func newMissionControlStore(db missionControlDB, maxRecords int,
flushInterval time.Duration) (*missionControlStore, error) {
var (
keys *list.List
keysMap map[string]struct{}
)
// Create buckets if not yet existing.
err := db.update(func(resultsBucket kvdb.RwBucket) error {
// Collect all keys to be able to quickly calculate the
// difference when updating the DB state.
c := resultsBucket.ReadCursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
keys.PushBack(string(k))
keysMap[string(k)] = struct{}{}
}
return nil
}, func() {
keys = list.New()
keysMap = make(map[string]struct{})
})
if err != nil {
return nil, err
}
log.Infof("Loaded %d mission control entries", len(keysMap))
return &missionControlStore{
done: make(chan struct{}),
db: db,
queueCond: sync.NewCond(&sync.Mutex{}),
queue: list.New(),
keys: keys,
keysMap: keysMap,
maxRecords: maxRecords,
flushInterval: flushInterval,
}, nil
}
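// A minimal construction sketch (db is any missionControlDB implementation;
// the values are illustrative):
//
//	store, err := newMissionControlStore(db, 1000, 10*time.Second)
//	if err != nil {
//		return nil, err
//	}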
// clear removes all results from the db.
func (b *missionControlStore) clear() error {
b.queueCond.L.Lock()
defer b.queueCond.L.Unlock()
if err := b.db.purge(); err != nil {
return err
}
b.queue = list.New()
return nil
}
// fetchAll returns all results currently stored in the database.
func (b *missionControlStore) fetchAll() ([]*paymentResult, error) {
var results []*paymentResult
err := b.db.view(func(resultBucket kvdb.RBucket) error {
results = make([]*paymentResult, 0)
return resultBucket.ForEach(func(k, v []byte) error {
result, err := deserializeResult(k, v)
if err != nil {
return err
}
results = append(results, result)
return nil
})
}, func() {
results = nil
})
if err != nil {
return nil, err
}
return results, nil
}
// serializeResult serializes a payment result and returns a key and value byte
// slice to insert into the bucket.
func serializeResult(rp *paymentResult) ([]byte, []byte, error) {
recordProducers := []tlv.RecordProducer{
&rp.timeFwd,
&rp.timeReply,
&rp.route,
}
rp.failure.WhenSome(
func(failure tlv.RecordT[tlv.TlvType3, paymentFailure]) {
recordProducers = append(recordProducers, &failure)
},
)
// Compose key that identifies this result.
key := getResultKey(rp)
var buff bytes.Buffer
err := lnwire.EncodeRecordsTo(
&buff, lnwire.ProduceRecordsSorted(recordProducers...),
)
if err != nil {
return nil, nil, err
}
return key, buff.Bytes(), nil
}
// deserializeResult deserializes a payment result.
func deserializeResult(k, v []byte) (*paymentResult, error) {
// Parse payment id.
result := paymentResult{
id: byteOrder.Uint64(k[8:]),
}
failure := tlv.ZeroRecordT[tlv.TlvType3, paymentFailure]()
recordProducers := []tlv.RecordProducer{
&result.timeFwd,
&result.timeReply,
&result.route,
&failure,
}
r := bytes.NewReader(v)
typeMap, err := lnwire.DecodeRecords(
r, lnwire.ProduceRecordsSorted(recordProducers...)...,
)
if err != nil {
return nil, err
}
if _, ok := typeMap[result.failure.TlvType()]; ok {
result.failure = tlv.SomeRecordT(failure)
}
return &result, nil
}
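// Round-trip sketch (pr is assumed to be a fully populated paymentResult):
//
//	k, v, err := serializeResult(pr)
//	if err != nil {
//		return err
//	}
//	pr2, err := deserializeResult(k, v)
//
// Note that the payment id travels in the key rather than the TLV value, so
// both key and value are needed to reconstruct the result.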
// serializeRoute serializes a mcRoute and writes the resulting bytes to the
// given io.Writer.
func serializeRoute(w io.Writer, r *mcRoute) error {
records := lnwire.ProduceRecordsSorted(
&r.sourcePubKey,
&r.totalAmount,
&r.hops,
)
return lnwire.EncodeRecordsTo(w, records)
}
// deserializeRoute deserializes the mcRoute from the given io.Reader.
func deserializeRoute(r io.Reader) (*mcRoute, error) {
var rt mcRoute
records := lnwire.ProduceRecordsSorted(
&rt.sourcePubKey,
&rt.totalAmount,
&rt.hops,
)
_, err := lnwire.DecodeRecords(r, records...)
if err != nil {
return nil, err
}
return &rt, nil
}
// deserializeHop deserializes the mcHop from the given io.Reader.
func deserializeHop(r io.Reader) (*mcHop, error) {
var (
h mcHop
blinding = tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean]()
custom = tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean]()
)
records := lnwire.ProduceRecordsSorted(
&h.channelID,
&h.pubKeyBytes,
&h.amtToFwd,
&blinding,
&custom,
)
typeMap, err := lnwire.DecodeRecords(r, records...)
if err != nil {
return nil, err
}
if _, ok := typeMap[h.hasBlindingPoint.TlvType()]; ok {
h.hasBlindingPoint = tlv.SomeRecordT(blinding)
}
if _, ok := typeMap[h.hasCustomRecords.TlvType()]; ok {
h.hasCustomRecords = tlv.SomeRecordT(custom)
}
return &h, nil
}
// serializeHop serializes a mcHop and writes the resulting bytes to the given
// io.Writer.
func serializeHop(w io.Writer, h *mcHop) error {
recordProducers := []tlv.RecordProducer{
&h.channelID,
&h.pubKeyBytes,
&h.amtToFwd,
}
h.hasBlindingPoint.WhenSome(func(
hasBlinding tlv.RecordT[tlv.TlvType3, lnwire.TrueBoolean]) {
recordProducers = append(recordProducers, &hasBlinding)
})
h.hasCustomRecords.WhenSome(func(
hasCustom tlv.RecordT[tlv.TlvType4, lnwire.TrueBoolean]) {
recordProducers = append(recordProducers, &hasCustom)
})
return lnwire.EncodeRecordsTo(
w, lnwire.ProduceRecordsSorted(recordProducers...),
)
}
// AddResult adds a new result to the db.
func (b *missionControlStore) AddResult(rp *paymentResult) {
b.queueCond.L.Lock()
b.queue.PushBack(rp)
b.queueCond.L.Unlock()
b.queueCond.Signal()
}
// stop stops the store ticker goroutine.
func (b *missionControlStore) stop() {
close(b.done)
b.queueCond.Signal()
b.wg.Wait()
}
// run runs the MC store ticker goroutine.
func (b *missionControlStore) run() {
b.wg.Add(1)
go func() {
defer b.wg.Done()
timer := time.NewTimer(b.flushInterval)
// Immediately stop the timer. It will be started once new
// items are added to the store. As the doc for time.Timer
// states, if Stop returns false the timer has already fired
// and its channel needs to be drained appropriately. This
// could happen if the flushInterval is very small (e.g. 1
// nanosecond).
if !timer.Stop() {
select {
case <-timer.C:
case <-b.done:
log.Debugf("Stopping mission control store")
}
}
for {
// Wait for the queue to not be empty.
b.queueCond.L.Lock()
for b.queue.Front() == nil {
// To make sure we can properly stop, we must
// read the `done` channel first before
// attempting to call `Wait()`. This is due to
// the fact that when `Signal` is called before
// the `Wait` call, the `Wait` call will block
// indefinitely.
//
// TODO(yy): replace this with channels.
select {
case <-b.done:
b.queueCond.L.Unlock()
return
default:
}
b.queueCond.Wait()
}
b.queueCond.L.Unlock()
// Restart the timer.
timer.Reset(b.flushInterval)
select {
case <-timer.C:
if err := b.storeResults(); err != nil {
log.Errorf("Failed to update mission "+
"control store: %v", err)
}
case <-b.done:
// Release the timer's resources.
if !timer.Stop() {
select {
case <-timer.C:
case <-b.done:
log.Debugf("Mission control " +
"store stopped")
}
}
return
}
}
}()
}
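// A hedged sketch of the producer side: results queued via AddResult return
// immediately and are persisted in batches, at most once per flushInterval:
//
//	store.run()              // start the flush goroutine
//	store.AddResult(result)  // queued in memory, signals the goroutine
//	// ...up to flushInterval later, storeResults() writes the batch.
//	store.stop()             // note: results still queued may be lost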
// storeResults stores all accumulated results.
func (b *missionControlStore) storeResults() error {
// We copy a reference to the queue and clear the original queue to be
// able to release the lock.
b.queueCond.L.Lock()
l := b.queue
if l.Len() == 0 {
b.queueCond.L.Unlock()
return nil
}
b.queue = list.New()
b.queueCond.L.Unlock()
var (
newKeys map[string]struct{}
delKeys []string
storeCount int
pruneCount int
)
// Create a deduped list of new entries.
newKeys = make(map[string]struct{}, l.Len())
for e := l.Front(); e != nil; e = e.Next() {
pr, ok := e.Value.(*paymentResult)
if !ok {
return fmt.Errorf("wrong type %T (not *paymentResult)",
e.Value)
}
key := string(getResultKey(pr))
if _, ok := b.keysMap[key]; ok {
l.Remove(e)
continue
}
if _, ok := newKeys[key]; ok {
l.Remove(e)
continue
}
newKeys[key] = struct{}{}
}
// Create a list of entries to delete.
toDelete := b.keys.Len() + len(newKeys) - b.maxRecords
if b.maxRecords > 0 && toDelete > 0 {
delKeys = make([]string, 0, toDelete)
// Delete as many as needed from old keys.
for e := b.keys.Front(); len(delKeys) < toDelete && e != nil; {
key, ok := e.Value.(string)
if !ok {
return fmt.Errorf("wrong type %T (not string)",
e.Value)
}
delKeys = append(delKeys, key)
e = e.Next()
}
// If more deletions are needed, simply do not add from the
// list of new keys.
for e := l.Front(); len(delKeys) < toDelete && e != nil; {
toDelete--
pr, ok := e.Value.(*paymentResult)
if !ok {
return fmt.Errorf("wrong type %T (not "+
"*paymentResult )", e.Value)
}
key := string(getResultKey(pr))
delete(newKeys, key)
l.Remove(e)
e = l.Front()
}
}
err := b.db.update(func(bucket kvdb.RwBucket) error {
for e := l.Front(); e != nil; e = e.Next() {
pr, ok := e.Value.(*paymentResult)
if !ok {
return fmt.Errorf("wrong type %T (not "+
"*paymentResult)", e.Value)
}
// Serialize result into key and value byte slices.
k, v, err := serializeResult(pr)
if err != nil {
return err
}
// Put into results bucket.
if err := bucket.Put(k, v); err != nil {
return err
}
storeCount++
}
// Prune oldest entries.
for _, key := range delKeys {
if err := bucket.Delete([]byte(key)); err != nil {
return err
}
pruneCount++
}
return nil
}, func() {
storeCount, pruneCount = 0, 0
})
if err != nil {
return err
}
log.Debugf("Stored mission control results: %d added, %d deleted",
storeCount, pruneCount)
// DB Update was successful, update the in-memory cache.
for _, key := range delKeys {
delete(b.keysMap, key)
b.keys.Remove(b.keys.Front())
}
for e := l.Front(); e != nil; e = e.Next() {
pr, ok := e.Value.(*paymentResult)
if !ok {
return fmt.Errorf("wrong type %T (not *paymentResult)",
e.Value)
}
key := string(getResultKey(pr))
b.keys.PushBack(key)
}
return nil
}
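// Pruning example (illustrative): with maxRecords=1000, 990 keys already in
// the db and 25 deduped new results, toDelete = 990 + 25 - 1000 = 15, so the
// 15 oldest stored keys are deleted in the same transaction that writes the
// new batch.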
// getResultKey returns a byte slice representing a unique key for this payment
// result.
func getResultKey(rp *paymentResult) []byte {
var keyBytes [8 + 8 + 33]byte
// Identify records by a combination of time, payment id and sender pub
// key. This allows importing mission control data from an external
// source without key collisions and keeps the records sorted
// chronologically.
byteOrder.PutUint64(keyBytes[:], rp.timeReply.Val)
byteOrder.PutUint64(keyBytes[8:], rp.id)
copy(keyBytes[16:], rp.route.Val.sourcePubKey.Val[:])
return keyBytes[:]
}
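// The resulting 49-byte key layout is:
//
//	[0:8]   timeReply, big endian (keeps records chronologically sorted)
//	[8:16]  payment id, big endian
//	[16:49] sender public key (33 bytes)
//
// deserializeResult relies on this layout when it parses the id from k[8:].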
// failureMessage wraps the lnwire.FailureMessage interface such that we can
// apply a Record method and use the failureMessage in a TLV encoded type.
type failureMessage struct {
lnwire.FailureMessage
}
// Record returns a TLV record that can be used to encode/decode a list of
// failureMessage to/from a TLV stream.
func (r *failureMessage) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeFailureMessage(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodeFailureMessage, decodeFailureMessage,
)
}
func encodeFailureMessage(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*failureMessage); ok {
var b bytes.Buffer
err := lnwire.EncodeFailureMessage(&b, v.FailureMessage, 0)
if err != nil {
return err
}
_, err = w.Write(b.Bytes())
return err
}
return tlv.NewTypeForEncodingErr(val, "routing.failureMessage")
}
func decodeFailureMessage(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*failureMessage); ok {
msg, err := lnwire.DecodeFailureMessage(r, 0)
if err != nil {
return err
}
*v = failureMessage{
FailureMessage: msg,
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.failureMessage", l, l)
}
package routing
import (
"fmt"
"github.com/lightningnetwork/lnd/routing/route"
)
// DirectedNodePair stores a directed pair of nodes.
type DirectedNodePair struct {
From, To route.Vertex
}
// NewDirectedNodePair instantiates a new DirectedNodePair struct.
func NewDirectedNodePair(from, to route.Vertex) DirectedNodePair {
return DirectedNodePair{
From: from,
To: to,
}
}
// String converts a node pair to its human readable representation.
func (d DirectedNodePair) String() string {
return fmt.Sprintf("%v -> %v", d.From, d.To)
}
// Reverse returns a reversed copy of the pair.
func (d DirectedNodePair) Reverse() DirectedNodePair {
return DirectedNodePair{From: d.To, To: d.From}
}
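// Usage sketch (from and to are assumed route.Vertex values):
//
//	pair := NewDirectedNodePair(from, to)
//	fmt.Println(pair)     // "<from> -> <to>"
//	rev := pair.Reverse() // direction flipped: to -> from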
package routing
import (
"bytes"
"container/heap"
"errors"
"fmt"
"math"
"sort"
"time"
"github.com/btcsuite/btcd/btcutil"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/feature"
"github.com/lightningnetwork/lnd/fn/v2"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route"
)
const (
// infinity is used as a starting distance in our shortest path search.
infinity = math.MaxInt64
// RiskFactorBillionths controls the influence of time lock delta
// of a channel on route selection. It is expressed as billionths
// of msat per msat sent through the channel per time lock delta
// block. See edgeWeight function below for more details.
// The chosen value is based on the previous incorrect weight function
// 1 + timelock + fee * fee. In this function, the fee penalty
// diminishes the time lock penalty for all but the smallest amounts.
// To not change the behaviour of path finding too drastically, a
// relatively small value is chosen which is still big enough to give
// some effect with smaller time lock values. The value may need
// tweaking and/or be made configurable in the future.
RiskFactorBillionths = 15
// estimatedNodeCount is used to preallocate the path finding structures
// to avoid resizing and copies. It should be a number on the same order
// as the number of active nodes in the network.
estimatedNodeCount = 10000
// fakeHopHintCapacity is the capacity we assume for hop hint channels.
// This is a high number, which expresses that a hop hint channel should
// be able to route payments.
fakeHopHintCapacity = btcutil.Amount(10 * btcutil.SatoshiPerBitcoin)
)
// pathFinder defines the interface of a path finding algorithm.
type pathFinder = func(g *graphParams, r *RestrictParams,
cfg *PathFindingConfig, self, source, target route.Vertex,
amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) (
[]*unifiedEdge, float64, error)
var (
// DefaultEstimator is the default estimator used for computing
// probabilities in pathfinding.
DefaultEstimator = AprioriEstimatorName
// DefaultAttemptCost is the default fixed virtual cost in path finding
// of a failed payment attempt. It is used to trade off potentially
// better routes against their probability of succeeding.
DefaultAttemptCost = lnwire.NewMSatFromSatoshis(100)
// DefaultAttemptCostPPM is the default proportional virtual cost in
// path finding weight units of executing a payment attempt that fails.
// It is used to trade off potentially better routes against their
// probability of succeeding. This parameter is expressed in parts per
// million of the payment amount.
//
// It is impossible to pick a perfect default value. The current value
// of 0.1% is based on the idea that a transaction fee of 1% is within
// reasonable territory and that a payment shouldn't need more than 10
// attempts.
DefaultAttemptCostPPM = int64(1000)
// DefaultMinRouteProbability is the default minimum probability for routes
// returned from findPath.
DefaultMinRouteProbability = float64(0.01)
// DefaultAprioriHopProbability is the default a priori probability for
// a hop.
DefaultAprioriHopProbability = float64(0.6)
)
// edgePolicyWithSource is a helper struct to keep track of the source node
// of a channel edge. ChannelEdgePolicy only contains the destination node
// of the edge.
type edgePolicyWithSource struct {
sourceNode route.Vertex
edge AdditionalEdge
}
// finalHopParams encapsulates various parameters for route construction that
// apply to the final hop in a route. These features include basic payment data
// such as amounts and cltvs, as well as more complex features like destination
// custom records and payment address.
type finalHopParams struct {
amt lnwire.MilliSatoshi
totalAmt lnwire.MilliSatoshi
// cltvDelta is the final hop's minimum CLTV expiry delta.
//
// NOTE that in the case of paying to a blinded path, this value will
// be set to a duplicate of the blinded path's accumulated CLTV value.
// We would then only need to use this value in the case where the
// introduction node of the path is also the destination node.
cltvDelta uint16
records record.CustomSet
paymentAddr fn.Option[[32]byte]
// metadata is additional data that is sent along with the payment to
// the payee.
metadata []byte
}
// newRoute constructs a route using the provided path and final hop constraints.
// Any destination specific fields from the final hop params will be attached
// assuming the destination's feature vector signals support, otherwise this
// method will fail. If the route is too long, or the selected path cannot
// support the full payment including fees, then a non-nil error is returned.
// If the route is to a blinded path, the blindedPath parameter is used to
// back fill additional fields that are required for a blinded payment. This is
// done in a separate pass to keep our route construction simple, as blinded
// paths require zero expiry and amount values for intermediate hops (which
// makes calculating the totals during route construction difficult if we
// include blinded paths on the first pass).
//
// NOTE: The passed slice of unified edges MUST be sorted in forward order: from
// the source to the target node of the path finding attempt. It is assumed that
// any feature vectors on all hops have been validated for transitive
// dependencies.
// NOTE: If a non-nil blinded path is provided it is assumed to have been
// validated by the caller.
func newRoute(sourceVertex route.Vertex,
pathEdges []*unifiedEdge, currentHeight uint32, finalHop finalHopParams,
blindedPathSet *BlindedPaymentPathSet) (*route.Route, error) {
var (
hops []*route.Hop
// totalTimeLock will accumulate the cumulative time lock
// across the entire route. This value represents how long the
// sender will need to wait in the *worst* case.
totalTimeLock = currentHeight
// nextIncomingAmount is the amount that will need to flow into
// the *next* hop. Since we're going to be walking the route
// backwards below, this next hop gets closer and closer to the
// sender of the payment.
nextIncomingAmount lnwire.MilliSatoshi
blindedPayment *BlindedPayment
)
pathLength := len(pathEdges)
// When paying to a blinded route we might have appended a dummy hop at
// the end to make MPP payments possible via all paths of the blinded
// route set. We always append a dummy hop when the internal pathfinder
// looks for a route to a blinded path which is at least one hop long
// (excluding the introduction point). We add this dummy hop so that
// we search for a universal target but also respect potential mc
// entries which might already be present for a particular blinded path.
// However, when constructing the Sphinx packet, we need to remove this
// dummy hop again, which we do here.
//
// NOTE: The path length is always at least 1 because there must be one
// edge from the source to the destination. However we check for > 0
// just for robustness here.
if blindedPathSet != nil && pathLength > 0 {
finalBlindedPubKey := pathEdges[pathLength-1].policy.
ToNodePubKey()
if IsBlindedRouteNUMSTargetKey(finalBlindedPubKey[:]) {
// If the last hop is the NUMS key for blinded paths, we
// remove the dummy hop from the route.
pathEdges = pathEdges[:pathLength-1]
pathLength--
}
}
for i := pathLength - 1; i >= 0; i-- {
// Now we'll start to calculate the items within the per-hop
// payload for the hop this edge is leading to.
edge := pathEdges[i].policy
// If this is an edge from a blinded path and the
// blindedPayment variable has not been set yet, then set it now
// by extracting the corresponding blinded payment from the
// edge.
isBlindedEdge := pathEdges[i].blindedPayment != nil
if isBlindedEdge && blindedPayment == nil {
blindedPayment = pathEdges[i].blindedPayment
}
// We'll calculate the amounts, timelocks, and fees for each hop
// in the route. The base case is the final hop which includes
// their amount and timelocks. These values will accumulate
// contributions from the preceding hops back to the sender as
// we compute the route in reverse.
var (
amtToForward lnwire.MilliSatoshi
fee int64
totalAmtMsatBlinded lnwire.MilliSatoshi
outgoingTimeLock uint32
customRecords record.CustomSet
mpp *record.MPP
metadata []byte
)
// Define a helper function that checks this edge's feature
// vector for support for a given feature. We assume at this
// point that the feature vectors transitive dependencies have
// been validated.
supports := func(feature lnwire.FeatureBit) bool {
// If this edge comes from route hints, the features
// could be nil.
if edge.ToNodeFeatures == nil {
return false
}
return edge.ToNodeFeatures.HasFeature(feature)
}
if i == len(pathEdges)-1 {
// If this is the last hop, then the hop payload will
// contain the exact amount. This is detailed in BOLT #4:
// Onion Routing Protocol, "Payload for the Last Node".
amtToForward = finalHop.amt
// Fee is not part of the hop payload, but only used for
// reporting through RPC. Set to zero for the final hop.
fee = 0
if blindedPathSet == nil {
totalTimeLock += uint32(finalHop.cltvDelta)
} else {
totalTimeLock += uint32(
blindedPathSet.FinalCLTVDelta(),
)
}
outgoingTimeLock = totalTimeLock
// Attach any custom records to the final hop.
customRecords = finalHop.records
// If we're attaching a payment addr but the receiver
// doesn't support both TLV and payment addrs, fail.
payAddr := supports(lnwire.PaymentAddrOptional)
if !payAddr && finalHop.paymentAddr.IsSome() {
return nil, errors.New("cannot attach " +
"payment addr")
}
// Otherwise attach the mpp record if it exists.
// TODO(halseth): move this to payment life cycle,
// where AMP options are set.
finalHop.paymentAddr.WhenSome(func(addr [32]byte) {
mpp = record.NewMPP(finalHop.totalAmt, addr)
})
metadata = finalHop.metadata
if blindedPathSet != nil {
totalAmtMsatBlinded = finalHop.totalAmt
}
} else {
// The amount that the current hop needs to forward is
// equal to the incoming amount of the next hop.
amtToForward = nextIncomingAmount
// The fee that needs to be paid to the current hop is
// based on the amount that this hop needs to forward
// and its policy for the outgoing channel. This policy
// is stored as part of the incoming channel of
// the next hop.
outboundFee := pathEdges[i+1].policy.ComputeFee(
amtToForward,
)
inboundFee := pathEdges[i].inboundFees.CalcFee(
amtToForward + outboundFee,
)
fee = int64(outboundFee) + inboundFee
if fee < 0 {
fee = 0
}
// We'll take the total timelock of the preceding hop as
// the outgoing timelock of this hop. Then we'll
// increment the total timelock incurred by this hop.
outgoingTimeLock = totalTimeLock
totalTimeLock += uint32(
pathEdges[i+1].policy.TimeLockDelta,
)
}
// Since we're traversing the path backwards at this point, we
// prepend each new hop such that the final slice of hops will
// be in forward order.
currentHop := &route.Hop{
PubKeyBytes: edge.ToNodePubKey(),
ChannelID: edge.ChannelID,
AmtToForward: amtToForward,
OutgoingTimeLock: outgoingTimeLock,
CustomRecords: customRecords,
MPP: mpp,
Metadata: metadata,
TotalAmtMsat: totalAmtMsatBlinded,
}
hops = append([]*route.Hop{currentHop}, hops...)
// Finally, we update the amount that needs to flow into the
// *next* hop, which is the amount this hop needs to forward,
// accounting for the fee that it takes.
nextIncomingAmount = amtToForward + lnwire.MilliSatoshi(fee)
}
// If we are creating a route to a blinded path, we need to add some
// additional data to the route that is required for blinded forwarding.
// We do another pass on our edges to append this data.
if blindedPathSet != nil {
// If the passed in BlindedPaymentPathSet is non-nil but no
// edge had a BlindedPayment attached, it means that the path
// chosen was an introduction-node-only path. So in this case,
// we can assume the relevant payment is the only one in the
// payment set.
if blindedPayment == nil {
var err error
blindedPayment, err = blindedPathSet.IntroNodeOnlyPath()
if err != nil {
return nil, err
}
}
var (
inBlindedRoute bool
dataIndex = 0
blindedPath = blindedPayment.BlindedPath
introVertex = route.NewVertex(
blindedPath.IntroductionPoint,
)
)
for i, hop := range hops {
// Once we locate our introduction node, we know that
// every hop after this is part of the blinded route.
if bytes.Equal(hop.PubKeyBytes[:], introVertex[:]) {
inBlindedRoute = true
hop.BlindingPoint = blindedPath.BlindingPoint
}
// We don't need to modify edges outside of our blinded
// route.
if !inBlindedRoute {
continue
}
payload := blindedPath.BlindedHops[dataIndex].CipherText
hop.EncryptedData = payload
// All of the hops in a blinded route *except* the
// final hop should have zero amounts / time locks.
if i != len(hops)-1 {
hop.AmtToForward = 0
hop.OutgoingTimeLock = 0
}
dataIndex++
}
}
// With the base routing data expressed as hops, build the full route.
newRoute, err := route.NewRouteFromHops(
nextIncomingAmount, totalTimeLock, route.Vertex(sourceVertex),
hops,
)
if err != nil {
return nil, err
}
return newRoute, nil
}
// edgeWeight computes the weight of an edge. This value is used when searching
// for the shortest path within the channel graph between two nodes. Weight is
// the fee itself plus a time lock penalty added to it. This benefits channels
// with shorter time lock deltas and shorter routes (fewer hops) in general.
// RiskFactor controls the influence of time lock on route selection. This is
// currently a fixed value, but might be configurable in the future.
func edgeWeight(lockedAmt lnwire.MilliSatoshi, fee lnwire.MilliSatoshi,
timeLockDelta uint16) int64 {
// timeLockPenalty is the penalty for the time lock delta of this channel.
// It is controlled by RiskFactorBillionths and scales proportionally to
// the amount that will pass through the channel. The rationale is that if
// a twice as large amount gets locked up, it is twice as bad.
timeLockPenalty := int64(lockedAmt) * int64(timeLockDelta) *
RiskFactorBillionths / 1000000000
return int64(fee) + timeLockPenalty
}
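// Worked example (illustrative numbers): forwarding 100_000_000 msat over a
// channel charging a 1_000 msat fee with a time lock delta of 144 blocks:
//
//	timeLockPenalty = 100_000_000 * 144 * 15 / 1_000_000_000 = 216
//	weight          = 1_000 + 216 = 1_216
//
// Note that for small amounts the integer division rounds the penalty down
// to zero, leaving the fee as the entire weight.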
// graphParams wraps the set of graph parameters passed to findPath.
type graphParams struct {
// graph is the ChannelGraph to be used during path finding.
graph Graph
// additionalEdges is an optional set of edges that should be
// considered during path finding but are not already found in the
// channel graph. These can either be private edges for bolt 11 invoices
// or blinded edges when a payment to a blinded path is made.
additionalEdges map[route.Vertex][]AdditionalEdge
// bandwidthHints is an interface that provides bandwidth hints that
// can provide a better estimate of the current channel bandwidth than
// what is found in the graph. It will override the capacities and
// disabled flags found in the graph for local channels when doing
// path finding if it has updated values for that channel. In
// particular, it should be set to the current available sending
// bandwidth for active local channels, and 0 for inactive channels.
bandwidthHints bandwidthHints
}
// RestrictParams wraps the set of restrictions passed to findPath that the
// found path must adhere to.
type RestrictParams struct {
// ProbabilitySource is a callback that is expected to return the
// success probability of traversing the channel from the node.
ProbabilitySource func(route.Vertex, route.Vertex,
lnwire.MilliSatoshi, btcutil.Amount) float64
// FeeLimit is a maximum fee amount allowed to be used on the path from
// the source to the target.
FeeLimit lnwire.MilliSatoshi
// OutgoingChannelIDs is the list of channels that are allowed for the
// first hop. If nil, any channel may be used.
OutgoingChannelIDs []uint64
// LastHop is the pubkey of the last node before the final destination
// is reached. If nil, any node may be used.
LastHop *route.Vertex
// CltvLimit is the maximum time lock of the route excluding the final
// cltv. After path finding is complete, the caller needs to increase
// all cltv expiry heights with the required final cltv delta.
CltvLimit uint32
// DestCustomRecords contains the custom records to drop off at the
// final hop, if any.
DestCustomRecords record.CustomSet
// DestFeatures is a feature vector describing what the final hop
// supports. If none are provided, pathfinding will try to inspect any
// features on the node announcement instead.
DestFeatures *lnwire.FeatureVector
// PaymentAddr is a random 32-byte value generated by the receiver to
// mitigate probing vectors and payment sniping attacks on overpaid
// invoices.
PaymentAddr fn.Option[[32]byte]
// Amp signals to the pathfinder that this payment is an AMP payment
// and therefore it needs to account for additional AMP data in the
// final hop payload size calculation.
Amp *AMPOptions
// Metadata is additional data that is sent along with the payment to
// the payee.
Metadata []byte
// BlindedPaymentPathSet is necessary to determine the hop size of the
// last/exit hop.
BlindedPaymentPathSet *BlindedPaymentPathSet
// FirstHopCustomRecords includes any records that should be included in
// the update_add_htlc message towards our peer.
FirstHopCustomRecords lnwire.CustomRecords
}
// PathFindingConfig defines global parameters that control the trade-off in
// path finding between fees and probability.
type PathFindingConfig struct {
// AttemptCost is the fixed virtual cost in path finding of a failed
// payment attempt. It is used to trade off potentially better routes
// against their probability of succeeding.
AttemptCost lnwire.MilliSatoshi
// AttemptCostPPM is the proportional virtual cost in path finding of a
// failed payment attempt. It is used to trade off potentially better
// routes against their probability of succeeding. This parameter is
// expressed in parts per million of the total payment amount.
AttemptCostPPM int64
// MinProbability defines the minimum success probability of the
// returned route.
MinProbability float64
}
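// For reference, a hypothetical config mirroring the package defaults
// defined above (values are illustrative of how the knobs fit together):
//
//	cfg := &PathFindingConfig{
//		AttemptCost:    DefaultAttemptCost,         // 100 sat fixed cost
//		AttemptCostPPM: DefaultAttemptCostPPM,      // 0.1% of the amount
//		MinProbability: DefaultMinRouteProbability, // 1% route floor
//	}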
// getOutgoingBalance returns the maximum available balance in any of the
// channels of the given node. The second return parameter is the total
// available balance.
func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
bandwidthHints bandwidthHints,
g Graph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
var max, total lnwire.MilliSatoshi
cb := func(channel *graphdb.DirectedChannel) error {
if !channel.OutPolicySet {
return nil
}
chanID := channel.ChannelID
// Enforce outgoing channel restriction.
if outgoingChans != nil {
if _, ok := outgoingChans[chanID]; !ok {
return nil
}
}
bandwidth, ok := bandwidthHints.availableChanBandwidth(
chanID, 0,
)
// If the bandwidth is not available, use the channel capacity.
// This can happen when a channel is added to the graph after
// we've already queried the bandwidth hints.
if !ok {
bandwidth = lnwire.NewMSatFromSatoshis(channel.Capacity)
}
if bandwidth > max {
max = bandwidth
}
var overflow bool
total, overflow = overflowSafeAdd(total, bandwidth)
if overflow {
// If the current total and the bandwidth would
// overflow the maximum value, we set the total to the
// maximum value, which is more milli-satoshis than are
// in existence anyway, so the actual value is
// irrelevant.
total = lnwire.MilliSatoshi(math.MaxUint64)
}
return nil
}
// Iterate over all channels of the given node.
err := g.ForEachNodeDirectedChannel(node, cb)
if err != nil {
return 0, 0, err
}
return max, total, err
}
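// Example (illustrative): with three eligible local channels reporting
// bandwidths of 2_000, 5_000 and 3_000 msat, this returns max=5_000 and
// total=10_000. findPath rejects the payment outright when total < amt and
// returns errNoPathFound (leaving splitting to the caller) when max < amt.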
// findPath attempts to find a path from the source node within the ChannelGraph
// to the target node that's capable of supporting a payment of `amt` value. The
// current approach implemented is a modified version of Dijkstra's algorithm to
// find a single shortest path between the source node and the destination. The
// distance metric used for edges is related to the time-lock+fee costs along a
// particular edge. If a path is found, this function returns a slice of
// ChannelHop structs which encode the chosen path from the target to the
// source. The search is performed backwards from the destination node back to
// the source. This is to properly accumulate fees that need to be paid along the
// path and accurately check the amount to forward at every node against the
// available bandwidth.
func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
self, source, target route.Vertex, amt lnwire.MilliSatoshi,
timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64,
error) {
// Pathfinding can be a significant portion of the total payment
// latency, especially on low-powered devices. Log several metrics to
// aid in analyzing performance problems in this area.
start := time.Now()
nodesVisited := 0
edgesExpanded := 0
defer func() {
timeElapsed := time.Since(start)
log.Debugf("Pathfinding perf metrics: nodes=%v, edges=%v, "+
"time=%v", nodesVisited, edgesExpanded, timeElapsed)
}()
// If no destination features are provided, we will load what features
// we have for the target node from our graph.
features := r.DestFeatures
if features == nil {
var err error
features, err = g.graph.FetchNodeFeatures(target)
if err != nil {
return nil, 0, err
}
}
// Ensure that the destination's features don't include unknown
// required features.
err := feature.ValidateRequired(features)
if err != nil {
log.Warnf("Pathfinding destination node features: %v", err)
return nil, 0, errUnknownRequiredFeature
}
// Ensure that all transitive dependencies are set.
err = feature.ValidateDeps(features)
if err != nil {
log.Warnf("Pathfinding destination node features: %v", err)
return nil, 0, errMissingDependentFeature
}
// Now that we know the feature vector is well-formed, we'll proceed in
// checking that it supports the features we need. If the caller has a
// payment address to attach, check that our destination feature vector
// supports them.
if r.PaymentAddr.IsSome() &&
!features.HasFeature(lnwire.PaymentAddrOptional) {
return nil, 0, errNoPaymentAddr
}
// Set up outgoing channel map for quicker access.
var outgoingChanMap map[uint64]struct{}
if len(r.OutgoingChannelIDs) > 0 {
outgoingChanMap = make(map[uint64]struct{})
for _, outChan := range r.OutgoingChannelIDs {
outgoingChanMap[outChan] = struct{}{}
}
}
// If we are routing from ourselves, check that we have enough local
// balance available.
if source == self {
max, total, err := getOutgoingBalance(
self, outgoingChanMap, g.bandwidthHints, g.graph,
)
if err != nil {
return nil, 0, err
}
// If the total outgoing balance isn't sufficient, it will be
// impossible to complete the payment.
if total < amt {
log.Warnf("Not enough outbound balance to send "+
"htlc of amount: %v, only have local "+
"balance: %v", amt, total)
return nil, 0, errInsufficientBalance
}
// If no single channel has enough capacity for the full amount,
// it may still be possible to complete the payment by splitting.
if max < amt {
return nil, 0, errNoPathFound
}
}
// First we'll initialize an empty heap which'll help us to quickly
// locate the next edge we should visit during our graph traversal.
nodeHeap := newDistanceHeap(estimatedNodeCount)
// Holds the current best distance for a given node.
distance := make(map[route.Vertex]*nodeWithDist, estimatedNodeCount)
additionalEdgesWithSrc := make(map[route.Vertex][]*edgePolicyWithSource)
for vertex, additionalEdges := range g.additionalEdges {
// Edges connected to self are always included in the graph,
// therefore can be skipped. This prevents us from trying
// routes to malformed hop hints.
if vertex == self {
continue
}
// Build reverse lookup to find incoming edges. Needed because the
// search takes place from target to source.
for _, additionalEdge := range additionalEdges {
outgoingEdgePolicy := additionalEdge.EdgePolicy()
toVertex := outgoingEdgePolicy.ToNodePubKey()
incomingEdgePolicy := &edgePolicyWithSource{
sourceNode: vertex,
edge: additionalEdge,
}
additionalEdgesWithSrc[toVertex] =
append(additionalEdgesWithSrc[toVertex],
incomingEdgePolicy)
}
}
// The payload size of the final hop differs from that of intermediate
// hops and depends on whether the destination is blinded or not.
lastHopPayloadSize, err := lastHopPayloadSize(r, finalHtlcExpiry, amt)
if err != nil {
return nil, 0, err
}
// We can't always assume that the end destination is publicly
// advertised to the network so we'll manually include the target node.
// The target node charges no fee. Distance is set to 0, because this is
// the starting point of the graph traversal. We are searching backwards
// to get the fees first time right and correctly match channel
// bandwidth.
//
// Don't record the initial partial path in the distance map; reserve
// that key for the source node in case we route to ourselves.
partialPath := &nodeWithDist{
dist: 0,
weight: 0,
node: target,
netAmountReceived: amt,
incomingCltv: finalHtlcExpiry,
probability: 1,
routingInfoSize: lastHopPayloadSize,
}
// Calculate the absolute cltv limit. Use uint64 to prevent an overflow
// if the cltv limit is MaxUint32.
absoluteCltvLimit := uint64(r.CltvLimit) + uint64(finalHtlcExpiry)
// Calculate the default attempt cost as configured globally.
defaultAttemptCost := float64(
cfg.AttemptCost +
amt*lnwire.MilliSatoshi(cfg.AttemptCostPPM)/1000000,
)
// Validate time preference value.
if math.Abs(timePref) > 1 {
return nil, 0, fmt.Errorf("time preference %v out of range "+
"[-1, 1]", timePref)
}
// Scale to avoid the extremes -1 and 1 which run into infinity issues.
timePref *= 0.9
// Apply time preference. At 0, the default attempt cost will
// be used.
absoluteAttemptCost := defaultAttemptCost * (1/(0.5-timePref/2) - 1)
log.Debugf("Pathfinding absolute attempt cost: %v sats",
absoluteAttemptCost/1000)
// processEdge is a helper closure that will be used to make sure edges
// satisfy our specific requirements.
processEdge := func(fromVertex route.Vertex,
edge *unifiedEdge, toNodeDist *nodeWithDist) {
edgesExpanded++
// Calculate inbound fee charged by "to" node. The exit hop
// doesn't charge inbound fees. If the "to" node is the exit
// hop, its inbound fees have already been set to zero by
// nodeEdgeUnifier.
inboundFee := edge.inboundFees.CalcFee(
toNodeDist.netAmountReceived,
)
// Make sure that the node total fee is never negative.
// Routing nodes treat a total fee that turns out
// negative as a zero fee and pathfinding should do the
// same.
minInboundFee := -int64(toNodeDist.outboundFee)
if inboundFee < minInboundFee {
inboundFee = minInboundFee
}
// Calculate amount that the candidate node would have to send
// out.
amountToSend := toNodeDist.netAmountReceived +
lnwire.MilliSatoshi(inboundFee)
// Check if accumulated fees would exceed fee limit when this
// node would be added to the path.
totalFee := int64(amountToSend) - int64(amt)
log.Trace(lnutils.NewLogClosure(func() string {
return fmt.Sprintf(
"Checking fromVertex (%v) with "+
"minInboundFee=%v, inboundFee=%v, "+
"amountToSend=%v, amt=%v, totalFee=%v",
fromVertex, minInboundFee, inboundFee,
amountToSend, amt, totalFee,
)
}))
if totalFee > 0 && lnwire.MilliSatoshi(totalFee) > r.FeeLimit {
return
}
// Request the success probability for this edge.
edgeProbability := r.ProbabilitySource(
fromVertex, toNodeDist.node, amountToSend,
edge.capacity,
)
log.Trace(lnutils.NewLogClosure(func() string {
return fmt.Sprintf("path finding probability: fromnode=%v,"+
" tonode=%v, amt=%v, cap=%v, probability=%v",
fromVertex, toNodeDist.node, amountToSend,
edge.capacity, edgeProbability)
}))
// If the probability is zero, there is no point in trying.
if edgeProbability == 0 {
return
}
// Compute fee that fromVertex is charging. It is based on the
// amount that needs to be sent to the next node in the route.
//
// Source node has no predecessor to pay a fee. Therefore set
// fee to zero, because it should not be included in the fee
// limit check and edge weight.
//
// Also determine the time lock delta that will be added to the
// route if fromVertex is selected. If fromVertex is the source
// node, no additional timelock is required.
var (
timeLockDelta uint16
outboundFee int64
)
if fromVertex != source {
outboundFee = int64(
edge.policy.ComputeFee(amountToSend),
)
timeLockDelta = edge.policy.TimeLockDelta
}
incomingCltv := toNodeDist.incomingCltv + int32(timeLockDelta)
// Check that we are within our CLTV limit.
if uint64(incomingCltv) > absoluteCltvLimit {
return
}
// netAmountToReceive is the amount that the node that is added
// to the distance map needs to receive from a (to be found)
// previous node in the route. The inbound fee of the receiving
// node is already subtracted from this value. The previous node
// will need to pay the amount that this node forwards plus the
// fee it charges plus this node's inbound fee.
netAmountToReceive := amountToSend +
lnwire.MilliSatoshi(outboundFee)
// Calculate total probability of successfully reaching target
// by multiplying the probabilities. Both this edge and the rest
// of the route must succeed.
probability := toNodeDist.probability * edgeProbability
// If the probability is below the specified lower bound, we can
// abandon this direction. Adding further nodes can only lower
// the probability more.
if probability < cfg.MinProbability {
return
}
// Calculate the combined fee for this edge. Dijkstra does not
// support negative edge weights. Because this fee feeds into
// the edge weight calculation, we don't allow it to be
// negative.
signedFee := inboundFee + outboundFee
fee := lnwire.MilliSatoshi(0)
if signedFee > 0 {
fee = lnwire.MilliSatoshi(signedFee)
}
// By adding fromVertex to the route, there will be an extra
// weight composed of the fee that this node will charge and
// the amount that will be locked for timeLockDelta blocks in
// the HTLC that is handed out to fromVertex.
weight := edgeWeight(amountToSend, fee, timeLockDelta)
// Compute the tentative weight to this new channel/edge
// which is the weight from our toNode to the target node
// plus the weight of this edge.
tempWeight := toNodeDist.weight + weight
// Add an extra factor to the weight to take into account the
// probability. Another reason why we rounded the fee up to zero
// is to prevent a highly negative fee from cancelling out the
// extra factor. We don't want an always-failing node to attract
// traffic using a highly negative fee and escape penalization.
tempDist := getProbabilityBasedDist(
tempWeight, probability,
absoluteAttemptCost,
)
// If there is already a best route stored, compare this
// candidate route with the best route so far.
current, ok := distance[fromVertex]
if ok {
// If this route is worse than what we already found,
// skip this route.
if tempDist > current.dist {
return
}
// If the route is equally good and the probability
// isn't better, skip this route. It is important to
// also return if both cost and probability are equal,
// because otherwise the algorithm could run into an
// endless loop.
probNotBetter := probability <= current.probability
if tempDist == current.dist && probNotBetter {
return
}
}
// Calculate the total routing info size if this hop were to be
// included. If we are coming from the source hop, the payload
// size is zero, because the original htlc isn't in the onion
// blob.
//
// NOTE: For blinded paths with the NUMS key as the last hop,
// the payload size accounts for this dummy hop, which is of
// the same size as the real last hop. So we account for a
// bigger size than the route actually has, but we accept this
// small inaccuracy because we are only overestimating by one
// hop.
var payloadSize uint64
if fromVertex != source {
// In case the unifiedEdge does not have a payload size
// function supplied we request a graceful shutdown
// because this should never happen.
if edge.hopPayloadSizeFn == nil {
log.Criticalf("No payload size function "+
"available for edge=%v unable to "+
"determine payload size: %v", edge,
ErrNoPayLoadSizeFunc)
return
}
payloadSize = edge.hopPayloadSizeFn(
amountToSend,
uint32(toNodeDist.incomingCltv),
edge.policy.ChannelID,
)
}
routingInfoSize := toNodeDist.routingInfoSize + payloadSize
// Skip paths that would exceed the maximum routing info size.
if routingInfoSize > sphinx.MaxPayloadSize {
return
}
// All conditions are met and this new tentative distance is
// better than the current best known distance to this node.
// The new better distance is recorded, and also our "next hop"
// map is populated with this edge.
withDist := &nodeWithDist{
dist: tempDist,
weight: tempWeight,
node: fromVertex,
netAmountReceived: netAmountToReceive,
outboundFee: lnwire.MilliSatoshi(outboundFee),
incomingCltv: incomingCltv,
probability: probability,
nextHop: edge,
routingInfoSize: routingInfoSize,
}
distance[fromVertex] = withDist
// Either push withDist onto the heap if the node
// represented by fromVertex is not already on the heap OR adjust
// its position within the heap via heap.Fix.
nodeHeap.PushOrFix(withDist)
}
// TODO(roasbeef): also add path caching
// * similar to route caching, but doesn't factor in the amount
// Cache features because we visit nodes multiple times.
featureCache := make(map[route.Vertex]*lnwire.FeatureVector)
// getGraphFeatures returns (cached) node features from the graph.
getGraphFeatures := func(node route.Vertex) (*lnwire.FeatureVector,
error) {
// Check cache for features of the fromNode.
fromFeatures, ok := featureCache[node]
if ok {
return fromFeatures, nil
}
// Fetch node features fresh from the graph.
fromFeatures, err := g.graph.FetchNodeFeatures(node)
if err != nil {
return nil, err
}
// Don't route through nodes that contain unknown required
// features and mark as nil in the cache.
err = feature.ValidateRequired(fromFeatures)
if err != nil {
featureCache[node] = nil
return nil, nil
}
// Don't route through nodes that don't properly set all
// transitive feature dependencies and mark as nil in the cache.
err = feature.ValidateDeps(fromFeatures)
if err != nil {
featureCache[node] = nil
return nil, nil
}
// Update cache.
featureCache[node] = fromFeatures
return fromFeatures, nil
}
routeToSelf := source == target
for {
nodesVisited++
pivot := partialPath.node
isExitHop := partialPath.nextHop == nil
// Create unified policies for all incoming connections. Don't
// use inbound fees for the exit hop.
u := newNodeEdgeUnifier(
self, pivot, !isExitHop, outgoingChanMap,
)
err := u.addGraphPolicies(g.graph)
if err != nil {
return nil, 0, err
}
// We add hop hints that were supplied externally.
for _, reverseEdge := range additionalEdgesWithSrc[pivot] {
// Assume zero inbound fees for route hints. If inbound
// fees would apply, they couldn't be communicated in
// bolt11 invoices currently.
inboundFee := models.InboundFee{}
// Hop hints don't contain a capacity. We set one here,
// since a capacity is needed for probability
// calculations. We set a high capacity to act as if
// there is enough liquidity, otherwise the hint would
// not have been added by a wallet.
// We also pass the payload size function to the
// graph data so that we calculate the exact payload
// size when evaluating this hop for a route.
u.addPolicy(
reverseEdge.sourceNode,
reverseEdge.edge.EdgePolicy(),
inboundFee,
fakeHopHintCapacity,
reverseEdge.edge.IntermediatePayloadSize,
reverseEdge.edge.BlindedPayment(),
)
}
netAmountReceived := partialPath.netAmountReceived
// Expand all connections using the optimal policy for each
// connection.
for fromNode, edgeUnifier := range u.edgeUnifiers {
// The target node is not recorded in the distance map.
// Therefore we need to have this check to prevent
// creating a cycle. Only when we intend to route to
// self, we allow this cycle to form. In that case we'll
// also break out of the search loop below.
if !routeToSelf && fromNode == target {
continue
}
// Apply last hop restriction if set.
if r.LastHop != nil &&
pivot == target && fromNode != *r.LastHop {
continue
}
edge := edgeUnifier.getEdge(
netAmountReceived, g.bandwidthHints,
partialPath.outboundFee,
)
if edge == nil {
continue
}
// Get feature vector for fromNode.
fromFeatures, err := getGraphFeatures(fromNode)
if err != nil {
return nil, 0, err
}
// If there are no valid features, skip this node.
if fromFeatures == nil {
continue
}
// Check if this candidate node is better than what we
// already have.
processEdge(fromNode, edge, partialPath)
}
if nodeHeap.Len() == 0 {
break
}
// Fetch the node with the smallest distance from our source
// from the heap.
partialPath = heap.Pop(&nodeHeap).(*nodeWithDist)
// If we've reached our source (or we don't have any incoming
// edges), then we're done here and can exit the graph
// traversal early.
if partialPath.node == source {
break
}
}
// Use the distance map to unravel the forward path from source to
// target.
var pathEdges []*unifiedEdge
currentNode := source
for {
// Determine the next hop forward using the next map.
currentNodeWithDist, ok := distance[currentNode]
if !ok {
// If the node doesn't have a next hop it means we
// didn't find a path.
return nil, 0, errNoPathFound
}
// Add the next hop to the list of path edges.
pathEdges = append(pathEdges, currentNodeWithDist.nextHop)
// Advance current node.
currentNode = currentNodeWithDist.nextHop.policy.ToNodePubKey()
// Check stop condition at the end of this loop. This prevents
// breaking out too soon for self-payments that have target set
// to source.
if currentNode == target {
break
}
}
// For the final hop, we'll set the node features to those determined
// above. These are either taken from the destination features, e.g.
// virtual or invoice features, or loaded as a fallback from the graph.
// The transitive dependencies were already validated above, so no need
// to do so now.
//
// NOTE: This may overwrite features loaded from the graph if
// destination features were provided. This is fine though, since our
// route construction does not care where the features are actually
// taken from. In the future we may wish to do route construction within
// findPath, and avoid using ChannelEdgePolicy altogether.
pathEdges[len(pathEdges)-1].policy.ToNodeFeatures = features
log.Debugf("Found route: probability=%v, hops=%v, fee=%v",
distance[source].probability, len(pathEdges),
distance[source].netAmountReceived-amt)
return pathEdges, distance[source].probability, nil
}
// blindedPathRestrictions are a set of constraints to adhere to when
// choosing a set of blinded paths to this node.
type blindedPathRestrictions struct {
// minNumHops is the minimum number of hops to include in a blinded
// path. This doesn't include our node, so if the minimum is 1, then
// the path will contain at minimum our node along with an introduction
// node hop. A minimum of 0 will include paths where this node is the
// introduction node and so should be used with caution.
minNumHops uint8
// maxNumHops is the maximum number of hops to include in a blinded
// path. This doesn't include our node, so if the maximum is 1, then
// the path will contain our node along with an introduction node hop.
maxNumHops uint8
// nodeOmissionSet holds a set of node IDs of nodes that we should
// ignore during blinded path selection.
nodeOmissionSet fn.Set[route.Vertex]
}
// blindedHop holds the information about a hop we have selected for a blinded
// path.
type blindedHop struct {
vertex route.Vertex
channelID uint64
edgeCapacity btcutil.Amount
}
// findBlindedPaths does a depth first search from the target node to find a set
// of blinded paths to the target node given the set of restrictions. This
// function will select and return any candidate path. A candidate path is a
// path to the target node with a size determined by the given hop number
// constraints where all the nodes on the path signal the route blinding feature
// _and_ the introduction node for the path has more than one public channel.
// Any filtering of paths based on payment value or success probabilities is
// left to the caller.
func findBlindedPaths(g Graph, target route.Vertex,
restrictions *blindedPathRestrictions) ([][]blindedHop, error) {
// Sanity check the restrictions.
if restrictions.minNumHops > restrictions.maxNumHops {
return nil, fmt.Errorf("maximum number of blinded path hops "+
"(%d) must be greater than or equal to the minimum "+
"number of hops (%d)", restrictions.maxNumHops,
restrictions.minNumHops)
}
// If the node is not the destination node, then it is required that the
// node advertise the route blinding feature-bit in order for it to be
// chosen as a node on the blinded path.
supportsRouteBlinding := func(node route.Vertex) (bool, error) {
if node == target {
return true, nil
}
features, err := g.FetchNodeFeatures(node)
if err != nil {
return false, err
}
return features.HasFeature(lnwire.RouteBlindingOptional), nil
}
// This function will have some recursion. We will spin out from the
// target node & append edges to the paths until we reach various exit
// conditions, such as the maxHops number being reached, or reaching a
// node that doesn't have any other edges - in that final case, the
// whole path should be ignored.
paths, _, err := processNodeForBlindedPath(
g, target, supportsRouteBlinding, nil, restrictions,
)
if err != nil {
return nil, err
}
// Reverse each path so that the order is correct (from introduction
// node to last hop node) and then append this node on as the
// destination of each path.
orderedPaths := make([][]blindedHop, len(paths))
for i, path := range paths {
sort.Slice(path, func(i, j int) bool {
return j < i
})
orderedPaths[i] = append(path, blindedHop{vertex: target})
}
// Handle the special case that allows a blinded path with the
// introduction node as the destination node.
if restrictions.minNumHops == 0 {
singleHopPath := [][]blindedHop{{{vertex: target}}}
//nolint:makezero
orderedPaths = append(
orderedPaths, singleHopPath...,
)
}
return orderedPaths, err
}
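// A hypothetical invocation (ourVertex and badNode are placeholders; the
// fn.NewSet constructor is assumed from lnd's fn package):
//
//	restrictions := &blindedPathRestrictions{
//		minNumHops:      1,
//		maxNumHops:      3,
//		nodeOmissionSet: fn.NewSet(badNode),
//	}
//	paths, err := findBlindedPaths(graph, ourVertex, restrictions)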
// processNodeForBlindedPath is a recursive function that traverses the graph
// in a depth first manner searching for a set of blinded paths to the given
// node.
func processNodeForBlindedPath(g Graph, node route.Vertex,
supportsRouteBlinding func(vertex route.Vertex) (bool, error),
alreadyVisited map[route.Vertex]bool,
restrictions *blindedPathRestrictions) ([][]blindedHop, bool, error) {
// If we have already visited the maximum number of hops, then this path
// is complete and we can exit now.
if len(alreadyVisited) > int(restrictions.maxNumHops) {
return nil, false, nil
}
// If we have already visited this peer on this path, then we skip
// processing it again.
if alreadyVisited[node] {
return nil, false, nil
}
// If we have explicitly been told to ignore this node for blinded paths
// then we skip it too.
if restrictions.nodeOmissionSet.Contains(node) {
return nil, false, nil
}
supports, err := supportsRouteBlinding(node)
if err != nil {
return nil, false, err
}
if !supports {
return nil, false, nil
}
// At this point, copy the alreadyVisited map.
visited := make(map[route.Vertex]bool, len(alreadyVisited))
for r := range alreadyVisited {
visited[r] = true
}
// Add this node to the visited set.
visited[node] = true
var (
hopSets [][]blindedHop
chanCount int
)
// Now, iterate over the node's channels in search of paths to this
// node that can be used for blinded paths.
err = g.ForEachNodeDirectedChannel(node,
func(channel *graphdb.DirectedChannel) error {
// Keep track of how many incoming channels this node
// has. We only use a node as an introduction node if it
// has channels other than the one that led us to it.
chanCount++
// Process each channel peer to gather any paths that
// lead to the peer.
nextPaths, hasMoreChans, err := processNodeForBlindedPath( //nolint:ll
g, channel.OtherNode, supportsRouteBlinding,
visited, restrictions,
)
if err != nil {
return err
}
hop := blindedHop{
vertex: channel.OtherNode,
channelID: channel.ChannelID,
edgeCapacity: channel.Capacity,
}
// For each of the paths returned, unwrap them and
// append this hop to them.
for _, path := range nextPaths {
hopSets = append(
hopSets,
append([]blindedHop{hop}, path...),
)
}
// If this node does have channels other than the one
// that led to it, and if the hop count up to this node
// meets the minNumHops requirement, then we also add a
// path that starts at this node.
if hasMoreChans &&
len(visited) >= int(restrictions.minNumHops) {
hopSets = append(hopSets, []blindedHop{hop})
}
return nil
},
)
if err != nil {
return nil, false, err
}
return hopSets, chanCount > 1, nil
}
// getProbabilityBasedDist converts a weight into a distance that takes into
// account the success probability and the (virtual) cost of a failed payment
// attempt.
//
// Derivation:
//
// Suppose there are two routes A and B with fees Fa and Fb and success
// probabilities Pa and Pb.
//
// Is the expected cost of trying route A first and then B lower than trying the
// other way around?
//
// The expected cost of A-then-B is: Pa*Fa + (1-Pa)*Pb*(c+Fb)
//
// The expected cost of B-then-A is: Pb*Fb + (1-Pb)*Pa*(c+Fa)
//
// In these equations, the term representing the case where both A and B fail is
// left out because its value would be the same in both cases.
//
// Pa*Fa + (1-Pa)*Pb*(c+Fb) < Pb*Fb + (1-Pb)*Pa*(c+Fa)
//
// Pa*Fa + Pb*c + Pb*Fb - Pa*Pb*c - Pa*Pb*Fb < Pb*Fb + Pa*c + Pa*Fa - Pa*Pb*c - Pa*Pb*Fa
//
// Removing terms that cancel out:
// Pb*c - Pa*Pb*Fb < Pa*c - Pa*Pb*Fa
//
// Divide by Pa*Pb:
// c/Pa - Fb < c/Pb - Fa
//
// Move terms around:
// Fa + c/Pa < Fb + c/Pb
//
// So the value of F + c/P can be used to compare routes.
func getProbabilityBasedDist(weight int64, probability float64,
penalty float64) int64 {
// Prevent divide by zero by returning early.
if probability == 0 {
return infinity
}
// Calculate distance.
dist := float64(weight) + penalty/probability
// Avoid cast if an overflow would occur. The maxFloat constant is
// chosen to stay well below the maximum float64 value that is still
// convertible to int64.
const maxFloat = 9000000000000000000
if dist > maxFloat {
return infinity
}
return int64(dist)
}
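// The following is an illustrative sketch, not part of the original code: it
// plugs hypothetical numbers into getProbabilityBasedDist to show how the
// F + c/P metric derived above ranks two candidate routes.
func exampleProbabilityBasedDist() {
	// Assume a virtual failed-attempt cost c of 100,000 msat.
	const penalty = 100_000.0

	// Route A: fee (weight) 1,000 msat, success probability 0.5.
	// Distance: 1,000 + 100,000/0.5 = 201,000.
	distA := getProbabilityBasedDist(1_000, 0.5, penalty)

	// Route B: fee (weight) 2,000 msat, success probability 0.8.
	// Distance: 2,000 + 100,000/0.8 = 127,000.
	distB := getProbabilityBasedDist(2_000, 0.8, penalty)

	// distB < distA, so route B is tried first despite its higher fee:
	// its better success probability lowers the expected total cost.
	_ = distA
	_ = distB
}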
// lastHopPayloadSize calculates the payload size of the final hop in a route.
// It depends on which TLV types are present and on whether the hop is part of
// a blinded route or not.
func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32,
amount lnwire.MilliSatoshi) (uint64, error) {
if r.BlindedPaymentPathSet != nil {
paymentPath, err := r.BlindedPaymentPathSet.
LargestLastHopPayloadPath()
if err != nil {
return 0, err
}
blindedPath := paymentPath.BlindedPath.BlindedHops
blindedPoint := paymentPath.BlindedPath.BlindingPoint
encryptedData := blindedPath[len(blindedPath)-1].CipherText
finalHop := route.Hop{
AmtToForward: amount,
OutgoingTimeLock: uint32(finalHtlcExpiry),
EncryptedData: encryptedData,
}
if len(blindedPath) == 1 {
finalHop.BlindingPoint = blindedPoint
}
// The final hop does not have a short chanID set.
return finalHop.PayloadSize(0), nil
}
var mpp *record.MPP
r.PaymentAddr.WhenSome(func(addr [32]byte) {
mpp = record.NewMPP(amount, addr)
})
var amp *record.AMP
if r.Amp != nil {
// The AMP payload is not easily accessible at this point, but we
// are only interested in the size of the payload, so we just use
// the dummy AMP record.
amp = &record.MaxAmpPayLoadSize
}
finalHop := route.Hop{
AmtToForward: amount,
OutgoingTimeLock: uint32(finalHtlcExpiry),
CustomRecords: r.DestCustomRecords,
MPP: mpp,
AMP: amp,
Metadata: r.Metadata,
}
// The final hop does not have a short chanID set.
return finalHop.PayloadSize(0), nil
}
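// The following is an illustrative sketch, not part of the original code: it
// computes the onion payload size of a plain (non-blinded) final hop that
// carries only an amount and a timelock. All values are hypothetical.
func exampleLastHopPayloadSize() uint64 {
	finalHop := route.Hop{
		AmtToForward:     lnwire.MilliSatoshi(100_000),
		OutgoingTimeLock: 800_000,
	}

	// As above, the final hop does not have a short chanID set.
	return finalHop.PayloadSize(0)
}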
// overflowSafeAdd adds two MilliSatoshi values and returns the result. If an
// overflow could occur, zero is returned instead and the boolean is set to
// true.
func overflowSafeAdd(x, y lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool) {
if y > math.MaxUint64-x {
// Overflow would occur, return 0 and set overflow flag.
return 0, true
}
return x + y, false
}
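// The following is an illustrative sketch, not part of the original code:
// adding one msat at the top of the uint64 range trips the overflow flag,
// signalling the caller to discard the sum.
func exampleOverflowSafeAdd() bool {
	sum, overflow := overflowSafeAdd(lnwire.MilliSatoshi(math.MaxUint64), 1)

	// sum is 0 and overflow is true here; the caller must not use sum.
	_ = sum
	return overflow
}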
package routing
import (
"context"
"errors"
"fmt"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/davecgh/go-spew/spew"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards"
"github.com/lightningnetwork/lnd/tlv"
)
// ErrPaymentLifecycleExiting is returned when we are waiting for an HTLC
// attempt result, but the payment lifecycle is exiting.
var ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting")
// switchResult is the result sent back from the switch after processing the
// HTLC.
type switchResult struct {
// attempt is the HTLC sent to the switch.
attempt *channeldb.HTLCAttempt
// result is sent from the switch and contains either a preimage if
// the HTLC was settled or an error if it failed.
}
// paymentLifecycle holds all information about the current state of a payment
// needed to resume it from any point.
type paymentLifecycle struct {
router *ChannelRouter
feeLimit lnwire.MilliSatoshi
identifier lntypes.Hash
paySession PaymentSession
shardTracker shards.ShardTracker
currentHeight int32
firstHopCustomRecords lnwire.CustomRecords
// quit is closed to signal the sub goroutines of the payment lifecycle
// to stop.
quit chan struct{}
// resultCollected is used to send the result returned from the switch
// for a given HTLC attempt.
resultCollected chan *switchResult
// resultCollector is a function that is used to collect the result of
// an HTLC attempt, which is always mounted to `p.collectResultAsync`
// except in unit tests, where we use a much simpler resultCollector to
// decouple the test flow from the payment lifecycle.
}
// newPaymentLifecycle initiates a new payment lifecycle and returns it.
func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, paySession PaymentSession,
shardTracker shards.ShardTracker, currentHeight int32,
firstHopCustomRecords lnwire.CustomRecords) *paymentLifecycle {
p := &paymentLifecycle{
router: r,
feeLimit: feeLimit,
identifier: identifier,
paySession: paySession,
shardTracker: shardTracker,
currentHeight: currentHeight,
quit: make(chan struct{}),
resultCollected: make(chan *switchResult, 1),
firstHopCustomRecords: firstHopCustomRecords,
}
// Mount the result collector.
p.resultCollector = p.collectResultAsync
return p
}
// calcFeeBudget returns the available fee to be used for sending HTLC
// attempts.
func (p *paymentLifecycle) calcFeeBudget(
feesPaid lnwire.MilliSatoshi) lnwire.MilliSatoshi {
budget := p.feeLimit
// We'll subtract the used fee from our fee budget. In case of
// overflow, we need to check whether feesPaid exceeds our budget
// already.
if feesPaid <= budget {
budget -= feesPaid
} else {
budget = 0
}
return budget
}
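// The following is an illustrative sketch, not part of the original code:
// with a fee limit of 1,000 msat, paying 400 msat in fees leaves a budget of
// 600 msat, while paying 1,200 msat clamps the budget to zero.
func exampleCalcFeeBudget() {
	p := &paymentLifecycle{feeLimit: 1_000}

	remaining := p.calcFeeBudget(400)   // 600 msat left.
	exhausted := p.calcFeeBudget(1_200) // 0 msat left.

	_ = remaining
	_ = exhausted
}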
// stateStep defines an action to be taken in our payment lifecycle. We either
// skip, proceed, or exit the lifecycle; see details below.
type stateStep uint8
const (
// stepSkip is used when we need to skip the current lifecycle and jump
// to the next one.
stepSkip stateStep = iota
// stepProceed is used when we can proceed with the current lifecycle.
stepProceed
// stepExit is used when we need to quit the current lifecycle.
stepExit
)
// decideNextStep is used to determine the next step in the payment lifecycle.
// It first checks whether the current state of the payment allows more HTLC
// attempts to be made. If allowed, it will return so the lifecycle can continue
// making new attempts. Otherwise, it checks whether we need to wait for the
// results of already sent attempts. If needed, it will block until one of the
// results is sent back, then process it here. When there's no need to
// wait for results, the method will exit with `stepExit` such that the payment
// lifecycle loop will terminate.
func (p *paymentLifecycle) decideNextStep(
payment DBMPPayment) (stateStep, error) {
// Check whether we could make new HTLC attempts.
allow, err := payment.AllowMoreAttempts()
if err != nil {
return stepExit, err
}
// Exit early if we are allowed to make more attempts.
if allow {
return stepProceed, nil
}
// We cannot make more attempts, so we now check whether we need to
// wait for results.
wait, err := payment.NeedWaitAttempts()
if err != nil {
return stepExit, err
}
// If we are not allowed to make new HTLC attempts and there's no need
// to wait, the lifecycle is done and we can exit.
if !wait {
return stepExit, nil
}
log.Tracef("Waiting for attempt results for payment %v", p.identifier)
// Otherwise we wait for the result of one HTLC attempt, then continue
// the lifecycle.
select {
case r := <-p.resultCollected:
log.Tracef("Received attempt result for payment %v",
p.identifier)
// Handle the result here. If there's no error, we will return
// stepSkip and move to the next lifecycle iteration, which will
// refresh the payment and wait for the next attempt result, if
// any.
_, err := p.handleAttemptResult(r.attempt, r.result)
// We would only get a DB-related error here, which will cause
// us to abort the payment flow.
if err != nil {
return stepExit, err
}
case <-p.quit:
return stepExit, ErrPaymentLifecycleExiting
case <-p.router.quit:
return stepExit, ErrRouterShuttingDown
}
return stepSkip, nil
}
// resumePayment resumes the paymentLifecycle from the current state.
func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte,
*route.Route, error) {
// When the payment lifecycle loop exits, we make sure to signal any
// sub goroutine of the HTLC attempt to exit, then wait for them to
// return.
defer p.stop()
// If we had any existing attempts outstanding, we'll start by spinning
// up goroutines that'll collect their results and deliver them to the
// lifecycle loop below.
payment, err := p.reloadInflightAttempts()
if err != nil {
return [32]byte{}, nil, err
}
// Get the payment status.
status := payment.GetStatus()
// exitWithErr is a helper closure that logs and returns an error.
exitWithErr := func(err error) ([32]byte, *route.Route, error) {
// Log an error with the latest payment status.
//
// NOTE: this `status` variable is reassigned in the loop
// below. We could also call `payment.GetStatus` here, but in a
// rare case when the critical log is triggered when using
// postgres as db backend, the `payment` could be nil, causing
// the payment fetching to return an error.
log.Errorf("Payment %v with status=%v failed: %v", p.identifier,
status, err)
return [32]byte{}, nil, err
}
// We'll continue until either our payment succeeds, or we encounter a
// critical error during path finding.
lifecycle:
for {
// We update the payment state on every iteration.
currentPayment, ps, err := p.reloadPayment()
if err != nil {
return exitWithErr(err)
}
// Reassign status so it can be read in `exitWithErr`.
status = currentPayment.GetStatus()
// Reassign payment such that when the lifecycle exits, the
// latest payment can be read when we access its terminal info.
payment = currentPayment
// We now proceed with our lifecycle, performing the following tasks
// in order:
// 1. check context.
// 2. request route.
// 3. create HTLC attempt.
// 4. send HTLC attempt.
// 5. collect HTLC attempt result.
//
// Before we attempt any new shard, we'll check to see if we've
// gone past the payment attempt timeout, or if the context was
// cancelled, or the router is exiting. In any of these cases,
// we'll stop this payment attempt short.
if err := p.checkContext(ctx); err != nil {
return exitWithErr(err)
}
// Now decide the next step of the current lifecycle.
step, err := p.decideNextStep(payment)
if err != nil {
return exitWithErr(err)
}
switch step {
// Exit the for loop and return below.
case stepExit:
break lifecycle
// Continue the for loop and skip the rest.
case stepSkip:
continue lifecycle
// Continue the for loop and proceed the rest.
case stepProceed:
// Unknown step received, exit with an error.
default:
err = fmt.Errorf("unknown step: %v", step)
return exitWithErr(err)
}
// Now request a route to be used to create our HTLC attempt.
rt, err := p.requestRoute(ps)
if err != nil {
return exitWithErr(err)
}
// We may not be able to find a route for the current attempt. In
// that case, we continue the loop and move straight to the
// next iteration in case there are results for inflight HTLCs
// that still need to be collected.
if rt == nil {
log.Errorf("No route found for payment %v",
p.identifier)
continue lifecycle
}
log.Tracef("Found route: %s", spew.Sdump(rt.Hops))
// We found a route to try, create a new HTLC attempt to try.
attempt, err := p.registerAttempt(rt, ps.RemainingAmt)
if err != nil {
return exitWithErr(err)
}
// Once the attempt is created, send it to the htlcswitch.
result, err := p.sendAttempt(attempt)
if err != nil {
return exitWithErr(err)
}
// Now that the shard was successfully sent, launch a
// goroutine that will handle its result when it's back.
if result.err == nil {
p.resultCollector(attempt)
}
}
// Once we are out of the lifecycle loop, it means we've reached a
// terminal condition. We either return the settled preimage or the
// payment's failure reason.
//
// Optionally delete the failed attempts from the database.
err = p.router.cfg.Control.DeleteFailedAttempts(p.identifier)
if err != nil {
log.Errorf("Error deleting failed htlc attempts for payment "+
"%v: %v", p.identifier, err)
}
htlc, failure := payment.TerminalInfo()
if htlc != nil {
return htlc.Settle.Preimage, &htlc.Route, nil
}
// Otherwise return the payment failure reason.
return [32]byte{}, nil, *failure
}
// checkContext checks whether the payment context has been canceled.
// Cancellation occurs manually or if the context times out.
func (p *paymentLifecycle) checkContext(ctx context.Context) error {
select {
case <-ctx.Done():
// If the context was canceled, we'll mark the payment as
// failed. There are two cases to distinguish here: either a
// user-provided timeout was reached, or the context was
// canceled, either due to a manual cancellation or due to an
// unknown error.
var reason channeldb.FailureReason
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
reason = channeldb.FailureReasonTimeout
log.Warnf("Payment attempt not completed before "+
"context timeout, id=%s", p.identifier.String())
} else {
reason = channeldb.FailureReasonCanceled
log.Warnf("Payment attempt context canceled, id=%s",
p.identifier.String())
}
// By marking the payment failed, depending on whether it has
// inflight HTLCs or not, its status will now either be
// `StatusInflight` or `StatusFailed`. In either case, no more
// HTLCs will be attempted.
err := p.router.cfg.Control.FailPayment(p.identifier, reason)
if err != nil {
return fmt.Errorf("FailPayment got %w", err)
}
case <-p.router.quit:
return fmt.Errorf("check payment timeout got: %w",
ErrRouterShuttingDown)
// Fall through if we haven't hit our time limit.
default:
}
return nil
}
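// The following is an illustrative, self-contained sketch, not part of the
// original code: it shows the non-blocking select pattern used in
// checkContext, where a default case lets the caller fall through when
// neither the context nor the quit channel has fired.
func exampleNonBlockingCheck(ctx context.Context, quit chan struct{}) error {
	select {
	case <-ctx.Done():
		// Canceled or timed out: surface the context error.
		return ctx.Err()

	case <-quit:
		return errors.New("shutting down")

	default:
		// Neither channel has fired, so the caller's loop continues.
		return nil
	}
}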
// requestRoute is responsible for finding a route to be used to create an HTLC
// attempt.
func (p *paymentLifecycle) requestRoute(
ps *channeldb.MPPaymentState) (*route.Route, error) {
remainingFees := p.calcFeeBudget(ps.FeesPaid)
// Query our payment session to construct a route.
rt, err := p.paySession.RequestRoute(
ps.RemainingAmt, remainingFees,
uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight),
p.firstHopCustomRecords,
)
// Exit early if there's no error.
if err == nil {
// Allow the traffic shaper to add custom records to the
// outgoing HTLC and also adjust the amount if needed.
err = p.amendFirstHopData(rt)
if err != nil {
return nil, err
}
return rt, nil
}
// Otherwise we need to handle the error.
log.Warnf("Failed to find route for payment %v: %v", p.identifier, err)
// If the error belongs to `noRouteError` set, it means a non-critical
// error has happened during path finding, and we will mark the payment
// failed with this reason. Otherwise, we'll return the critical error
// found to abort the lifecycle.
var routeErr noRouteError
if !errors.As(err, &routeErr) {
return nil, fmt.Errorf("requestRoute got: %w", err)
}
// It's the `paymentSession`'s responsibility to find a route for us
// with best effort. When it cannot find a path, we need to treat it
// as a terminal condition and fail the payment no matter whether it
// has inflight HTLCs or not.
failureCode := routeErr.FailureReason()
log.Warnf("Marking payment %v permanently failed with no route: %v",
p.identifier, failureCode)
err = p.router.cfg.Control.FailPayment(p.identifier, failureCode)
if err != nil {
return nil, fmt.Errorf("FailPayment got: %w", err)
}
// NOTE: we decide not to return the non-critical noRouteError here to
// avoid terminating the payment lifecycle, as there might be other
// inflight HTLCs whose results we must wait for.
return nil, nil
}
// stop signals any active shard goroutine to exit.
func (p *paymentLifecycle) stop() {
close(p.quit)
}
// attemptResult holds the HTLC attempt and a possible error returned from
// sending it.
type attemptResult struct {
// err is non-nil if a non-critical error was encountered when trying
// to send the attempt, and we successfully updated the control tower
// to reflect this error. This can be errors like not enough local
// balance for the given route etc.
err error
// attempt is the attempt structure as recorded in the database.
attempt *channeldb.HTLCAttempt
}
// collectResultAsync launches a goroutine that will wait for the result of the
// given HTLC attempt to be available. Once received, it will send the result
// returned from the switch over the channel `resultCollected`.
func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) {
log.Debugf("Collecting result for attempt %v in payment %v",
attempt.AttemptID, p.identifier)
go func() {
result, err := p.collectResult(attempt)
if err != nil {
log.Errorf("Error collecting result for attempt %v in "+
"payment %v: %v", attempt.AttemptID,
p.identifier, err)
return
}
log.Debugf("Result collected for attempt %v in payment %v",
attempt.AttemptID, p.identifier)
// Create a switch result and send it to the resultCollected
// chan, which gets processed when the lifecycle is waiting for
// a result to be received in decideNextStep.
r := &switchResult{
attempt: attempt,
result: result,
}
// Signal that a result has been collected.
select {
// Send the result so decideNextStep can proceed.
case p.resultCollected <- r:
case <-p.quit:
log.Debugf("Lifecycle exiting while collecting "+
"result for payment %v", p.identifier)
case <-p.router.quit:
}
}()
}
// collectResult waits for the result of the given HTLC attempt to be sent by
// the switch and returns it.
func (p *paymentLifecycle) collectResult(
attempt *channeldb.HTLCAttempt) (*htlcswitch.PaymentResult, error) {
log.Tracef("Collecting result for attempt %v",
lnutils.SpewLogClosure(attempt))
result := &htlcswitch.PaymentResult{}
// Regenerate the circuit for this attempt.
circuit, err := attempt.Circuit()
// TODO(yy): We generate this circuit to create the error decryptor,
// which is then used in htlcswitch as the deobfuscator to decode the
// error from `UpdateFailHTLC`. However, if it's an `UpdateFulfillHTLC`
// message and the sphinx packet fails to be generated for some reason,
// we'd miss settling it. This means we should give it a second chance
// to try the settlement path in case `GetAttemptResult` gives us back
// the preimage. And move the circuit creation into htlcswitch so it's
// only constructed when there's a failure message we need to decode.
if err != nil {
log.Debugf("Unable to generate circuit for attempt %v: %v",
attempt.AttemptID, err)
return nil, err
}
// Using the created circuit, initialize the error decrypter, so we can
// parse+decode any failures incurred by this payment within the
// switch.
errorDecryptor := &htlcswitch.SphinxErrorDecrypter{
OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit),
}
// Now ask the switch to return the result of the payment when
// available.
//
// TODO(yy): consider using htlcswitch to create the `errorDecryptor`
// since the htlc is already in db. This will also make the interface
// `PaymentAttemptDispatcher` deeper and easier to use. Moreover, we'd
// only create the decryptor when we receive a failure, further saving
// us a few CPU cycles.
resultChan, err := p.router.cfg.Payer.GetAttemptResult(
attempt.AttemptID, p.identifier, errorDecryptor,
)
// Handle the switch error.
if err != nil {
log.Errorf("Failed getting result for attemptID %d "+
"from switch: %v", attempt.AttemptID, err)
result.Error = err
return result, nil
}
// The switch knows about this payment, we'll wait for a result to be
// available.
select {
case r, ok := <-resultChan:
if !ok {
return nil, htlcswitch.ErrSwitchExiting
}
result = r
case <-p.quit:
return nil, ErrPaymentLifecycleExiting
case <-p.router.quit:
return nil, ErrRouterShuttingDown
}
return result, nil
}
// registerAttempt is responsible for creating and saving an HTLC attempt in db
// by using the route info provided. The `remainingAmt` is used to decide
// whether this is the last attempt.
func (p *paymentLifecycle) registerAttempt(rt *route.Route,
remainingAmt lnwire.MilliSatoshi) (*channeldb.HTLCAttempt, error) {
// If this route will consume the last remaining amount to send
// to the receiver, this will be our last shard (for now).
isLastAttempt := rt.ReceiverAmt() == remainingAmt
// Using the route received from the payment session, create a new
// shard to send.
attempt, err := p.createNewPaymentAttempt(rt, isLastAttempt)
if err != nil {
return nil, err
}
// Before sending this HTLC to the switch, we checkpoint the fresh
// paymentID and route to the DB. This lets us know on startup the ID
// of the payment that we attempted to send, such that we can query the
// Switch for its whereabouts. The route is needed to handle the result
// when it eventually comes back.
err = p.router.cfg.Control.RegisterAttempt(
p.identifier, &attempt.HTLCAttemptInfo,
)
return attempt, err
}
// createNewPaymentAttempt creates a new payment attempt from the given route.
func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route,
lastShard bool) (*channeldb.HTLCAttempt, error) {
// Generate a new key to be used for this attempt.
sessionKey, err := generateNewSessionKey()
if err != nil {
return nil, err
}
// We generate a new, unique payment ID that we will use for
// this HTLC.
attemptID, err := p.router.cfg.NextPaymentID()
if err != nil {
return nil, err
}
// Request a new shard from the ShardTracker. If this is an AMP
// payment, and this is the last shard, the outstanding shards together
// with this one will be enough for the receiver to derive all HTLC
// preimages. If this is a non-AMP payment, the ShardTracker will return a
// simple shard with the payment's static payment hash.
shard, err := p.shardTracker.NewShard(attemptID, lastShard)
if err != nil {
return nil, err
}
// If this shard carries MPP or AMP options, add them to the last hop
// on the route.
hop := rt.Hops[len(rt.Hops)-1]
if shard.MPP() != nil {
hop.MPP = shard.MPP()
}
if shard.AMP() != nil {
hop.AMP = shard.AMP()
}
hash := shard.Hash()
// We now have all the information needed to populate the current
// attempt information.
return channeldb.NewHtlcAttempt(
attemptID, sessionKey, *rt, p.router.cfg.Clock.Now(), &hash,
)
}
// sendAttempt attempts to send the current attempt to the switch to complete
// the payment. If this attempt fails, then we'll continue on to the next
// available route.
func (p *paymentLifecycle) sendAttempt(
attempt *channeldb.HTLCAttempt) (*attemptResult, error) {
log.Debugf("Sending HTLC attempt(id=%v, total_amt=%v, first_hop_amt=%d"+
") for payment %v", attempt.AttemptID,
attempt.Route.TotalAmount, attempt.Route.FirstHopAmount.Val,
p.identifier)
rt := attempt.Route
// Construct the first hop.
firstHop := lnwire.NewShortChanIDFromInt(rt.Hops[0].ChannelID)
// Craft an HTLC packet to send to the htlcswitch. The metadata within
// this packet will be used to route the payment through the network,
// starting with the first-hop.
htlcAdd := &lnwire.UpdateAddHTLC{
Amount: rt.FirstHopAmount.Val.Int(),
Expiry: rt.TotalTimeLock,
PaymentHash: *attempt.Hash,
CustomRecords: rt.FirstHopWireCustomRecords,
}
// Generate the raw encoded sphinx packet to be included along
// with the htlcAdd message that we send directly to the
// switch.
onionBlob, err := attempt.OnionBlob()
if err != nil {
log.Errorf("Failed to create onion blob: attempt=%d in "+
"payment=%v, err:%v", attempt.AttemptID,
p.identifier, err)
return p.failAttempt(attempt.AttemptID, err)
}
htlcAdd.OnionBlob = onionBlob
// Send it to the Switch. When this method returns, we assume
// the Switch has successfully persisted the payment attempt,
// such that we can resume waiting for the result after a
// restart.
err = p.router.cfg.Payer.SendHTLC(firstHop, attempt.AttemptID, htlcAdd)
if err != nil {
log.Errorf("Failed sending attempt %d for payment %v to "+
"switch: %v", attempt.AttemptID, p.identifier, err)
return p.handleSwitchErr(attempt, err)
}
log.Debugf("Attempt %v for payment %v successfully sent to switch, "+
"route: %v", attempt.AttemptID, p.identifier, &attempt.Route)
return &attemptResult{
attempt: attempt,
}, nil
}
// amendFirstHopData is a function that calls the traffic shaper to allow it to
// add custom records to the outgoing HTLC and also adjust the amount if
// needed.
func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error {
// The first hop amount on the route is the full route amount if not
// overwritten by the traffic shaper. So we set the initial value now
// and potentially overwrite it later.
rt.FirstHopAmount = tlv.NewRecordT[tlv.TlvType0](
tlv.NewBigSizeT(rt.TotalAmount),
)
// By default, we set the first hop custom records to the initial
// value requested by the RPC. The traffic shaper may overwrite this
// value.
rt.FirstHopWireCustomRecords = p.firstHopCustomRecords
// extraDataRequest is a helper struct to pass the custom records and
// amount back from the traffic shaper.
type extraDataRequest struct {
customRecords fn.Option[lnwire.CustomRecords]
amount fn.Option[lnwire.MilliSatoshi]
}
// If a hook exists that may affect our outgoing message, we call it now
// and apply its side effects to the UpdateAddHTLC message.
result, err := fn.MapOptionZ(
p.router.cfg.TrafficShaper,
//nolint:ll
func(ts htlcswitch.AuxTrafficShaper) fn.Result[extraDataRequest] {
newAmt, newRecords, err := ts.ProduceHtlcExtraData(
rt.TotalAmount, p.firstHopCustomRecords,
)
if err != nil {
return fn.Err[extraDataRequest](err)
}
// Make sure we only received valid records.
if err := newRecords.Validate(); err != nil {
return fn.Err[extraDataRequest](err)
}
log.Debugf("Aux traffic shaper returned custom "+
"records %v and amount %d msat for HTLC",
spew.Sdump(newRecords), newAmt)
return fn.Ok(extraDataRequest{
customRecords: fn.Some(newRecords),
amount: fn.Some(newAmt),
})
},
).Unpack()
if err != nil {
return fmt.Errorf("traffic shaper failed to produce extra "+
"data: %w", err)
}
// Apply the side effects to the UpdateAddHTLC message.
result.customRecords.WhenSome(func(records lnwire.CustomRecords) {
rt.FirstHopWireCustomRecords = records
})
result.amount.WhenSome(func(amount lnwire.MilliSatoshi) {
rt.FirstHopAmount = tlv.NewRecordT[tlv.TlvType0](
tlv.NewBigSizeT(amount),
)
})
return nil
}
// failPaymentAndAttempt fails both the payment and its attempt via the
// router's control tower, which marks the payment as failed in the db.
func (p *paymentLifecycle) failPaymentAndAttempt(
attemptID uint64, reason *channeldb.FailureReason,
sendErr error) (*attemptResult, error) {
log.Errorf("Payment %v failed: final_outcome=%v, raw_err=%v",
p.identifier, *reason, sendErr)
// Fail the payment via control tower.
//
// NOTE: we must fail the payment first before failing the attempt.
// Otherwise, once the attempt is marked as failed, another goroutine
// might make another attempt while we are failing the payment.
err := p.router.cfg.Control.FailPayment(p.identifier, *reason)
if err != nil {
log.Errorf("Unable to fail payment: %v", err)
return nil, err
}
// Fail the attempt.
return p.failAttempt(attemptID, sendErr)
}
// handleSwitchErr inspects the given error from the Switch and determines
// whether we should make another payment attempt, or if it should be
// considered a terminal error. Terminal errors will be recorded with the
// control tower. It analyzes the sendErr for the payment attempt received from
// the switch and updates mission control and/or channel policies. Depending on
// the error type, the error is either the final outcome of the payment or we
// need to continue with an alternative route. A final outcome is indicated by
// a non-nil reason value.
func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt,
sendErr error) (*attemptResult, error) {
internalErrorReason := channeldb.FailureReasonError
attemptID := attempt.AttemptID
// reportAndFail is a helper closure that reports the failure to
// mission control, which helps us decide whether we want to retry
// the payment or not. If a non-nil reason is returned from mission
// control, it will further fail the payment via the control tower.
reportAndFail := func(srcIdx *int,
msg lnwire.FailureMessage) (*attemptResult, error) {
// Report outcome to mission control.
reason, err := p.router.cfg.MissionControl.ReportPaymentFail(
attemptID, &attempt.Route, srcIdx, msg,
)
if err != nil {
log.Errorf("Error reporting payment result to mc: %v",
err)
reason = &internalErrorReason
}
// If mission control returned no failure reason, only fail the
// attempt.
if reason == nil {
return p.failAttempt(attemptID, sendErr)
}
// Otherwise fail both the payment and the attempt.
return p.failPaymentAndAttempt(attemptID, reason, sendErr)
}
// If this attempt ID is unknown to the Switch, it means it was never
// checkpointed and forwarded by the switch before a restart. In this
// case we can safely send a new payment attempt, and wait for its
// result to be available.
if errors.Is(sendErr, htlcswitch.ErrPaymentIDNotFound) {
log.Warnf("Failing attempt=%v for payment=%v as it's not "+
"found in the Switch", attempt.AttemptID, p.identifier)
return p.failAttempt(attemptID, sendErr)
}
if errors.Is(sendErr, htlcswitch.ErrUnreadableFailureMessage) {
log.Warnf("Unreadable failure when sending htlc: id=%v, hash=%v",
attempt.AttemptID, attempt.Hash)
// Since this error message cannot be decrypted, we will send a
// nil error message to our mission controller and fail the
// payment.
return reportAndFail(nil, nil)
}
// If the error is a ClearTextError, we have received a valid wire
// failure message, either from our own outgoing link or from a node
// down the route. If the error is not related to the propagation of
// our payment, we can stop trying because an internal error has
// occurred.
var rtErr htlcswitch.ClearTextError
ok := errors.As(sendErr, &rtErr)
if !ok {
return p.failPaymentAndAttempt(
attemptID, &internalErrorReason, sendErr,
)
}
// failureSourceIdx is the index of the node that the failure occurred
// at. If the ClearTextError received is not a ForwardingError the
// payment error occurred at our node, so we leave this value as 0
// to indicate that the failure occurred locally. If the error is a
// ForwardingError, it did not originate at our node, so we set
// failureSourceIdx to the index of the node where the failure occurred.
failureSourceIdx := 0
var source *htlcswitch.ForwardingError
ok = errors.As(rtErr, &source)
if ok {
failureSourceIdx = source.FailureSourceIdx
}
// Extract the wire failure and apply channel update if it contains one.
// If we received an unknown failure message from a node along the
// route, the failure message will be nil.
failureMessage := rtErr.WireMessage()
err := p.handleFailureMessage(
&attempt.Route, failureSourceIdx, failureMessage,
)
if err != nil {
return p.failPaymentAndAttempt(
attemptID, &internalErrorReason, sendErr,
)
}
log.Tracef("Node=%v reported failure when sending htlc",
failureSourceIdx)
return reportAndFail(&failureSourceIdx, failureMessage)
}
// handleFailureMessage tries to apply a channel update present in the failure
// message if any.
func (p *paymentLifecycle) handleFailureMessage(rt *route.Route,
errorSourceIdx int, failure lnwire.FailureMessage) error {
if failure == nil {
return nil
}
// It makes no sense to apply our own channel updates.
if errorSourceIdx == 0 {
log.Errorf("Channel update of ourselves received")
return nil
}
// Extract channel update if the error contains one.
update := p.router.extractChannelUpdate(failure)
if update == nil {
return nil
}
// Parse pubkey to allow validation of the channel update. This should
// always succeed, otherwise there is something wrong in our
// implementation. Therefore, return an error.
errVertex := rt.Hops[errorSourceIdx-1].PubKeyBytes
errSource, err := btcec.ParsePubKey(errVertex[:])
if err != nil {
log.Errorf("Cannot parse pubkey: idx=%v, pubkey=%v",
errorSourceIdx, errVertex)
return err
}
var (
isAdditionalEdge bool
policy *models.CachedEdgePolicy
)
// Before we apply the channel update, we need to decide whether the
// update is for an additional (ephemeral) edge or a normal edge
// stored in the db.
//
// Note: the p.paySession might be nil here if it's called inside
// SendToRoute where there's no payment lifecycle.
if p.paySession != nil {
policy = p.paySession.GetAdditionalEdgePolicy(
errSource, update.ShortChannelID.ToUint64(),
)
if policy != nil {
isAdditionalEdge = true
}
}
// Apply channel update to additional edge policy.
if isAdditionalEdge {
if !p.paySession.UpdateAdditionalEdge(
update, errSource, policy) {
log.Debugf("Invalid channel update received: node=%v",
errVertex)
}
return nil
}
// Apply channel update to the channel edge policy in our db.
if !p.router.cfg.ApplyChannelUpdate(update) {
log.Debugf("Invalid channel update received: node=%v",
errVertex)
}
return nil
}
// failAttempt calls control tower to fail the current payment attempt.
func (p *paymentLifecycle) failAttempt(attemptID uint64,
sendError error) (*attemptResult, error) {
log.Warnf("Attempt %v for payment %v failed: %v", attemptID,
p.identifier, sendError)
failInfo := marshallError(
sendError,
p.router.cfg.Clock.Now(),
)
// Now that we are failing this payment attempt, cancel the shard with
// the ShardTracker such that it can derive the correct hash for the
// next attempt.
if err := p.shardTracker.CancelShard(attemptID); err != nil {
return nil, err
}
attempt, err := p.router.cfg.Control.FailAttempt(
p.identifier, attemptID, failInfo,
)
if err != nil {
return nil, err
}
return &attemptResult{
attempt: attempt,
err: sendError,
}, nil
}
// marshallError marshals an error as received from the switch into a
// structure that is suitable for database storage.
func marshallError(sendError error, time time.Time) *channeldb.HTLCFailInfo {
response := &channeldb.HTLCFailInfo{
FailTime: time,
}
switch {
case errors.Is(sendError, htlcswitch.ErrPaymentIDNotFound):
response.Reason = channeldb.HTLCFailInternal
return response
case errors.Is(sendError, htlcswitch.ErrUnreadableFailureMessage):
response.Reason = channeldb.HTLCFailUnreadable
return response
}
var rtErr htlcswitch.ClearTextError
ok := errors.As(sendError, &rtErr)
if !ok {
response.Reason = channeldb.HTLCFailInternal
return response
}
message := rtErr.WireMessage()
if message != nil {
response.Reason = channeldb.HTLCFailMessage
response.Message = message
} else {
response.Reason = channeldb.HTLCFailUnknown
}
// If the ClearTextError received is a ForwardingError, the error
// originated from a node along the route, not locally on our outgoing
// link. We set failureSourceIdx to the index of the node where the
// failure occurred. If the error is not a ForwardingError, the failure
// occurred at our node, so we leave the index as 0 to indicate that
// we failed locally.
var fErr *htlcswitch.ForwardingError
ok = errors.As(rtErr, &fErr)
if ok {
response.FailureSourceIndex = uint32(fErr.FailureSourceIdx)
}
return response
}
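// The following is an illustrative sketch, not part of the original code: a
// switch error that carries no wire failure, such as an unknown attempt ID,
// is marshalled into an internal failure record.
func exampleMarshallError() *channeldb.HTLCFailInfo {
	info := marshallError(htlcswitch.ErrPaymentIDNotFound, time.Now())

	// info.Reason is channeldb.HTLCFailInternal at this point.
	return info
}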
// patchLegacyPaymentHash makes a copy of the passed attempt and sets its
// Hash field to the payment hash if it's nil.
//
// NOTE: For legacy payments, which were created before the AMP feature was
// enabled, the `Hash` field in their HTLC attempts is nil. In that case, we use
// the payment hash as the `attempt.Hash` as they are identical.
func (p *paymentLifecycle) patchLegacyPaymentHash(
a channeldb.HTLCAttempt) channeldb.HTLCAttempt {
// Exit early if this is not a legacy attempt.
if a.Hash != nil {
return a
}
// Log a warning if the user is still using legacy payments, which
// have weaker support.
log.Warnf("Found legacy htlc attempt %v in payment %v", a.AttemptID,
p.identifier)
// Set the attempt's hash to be the payment hash, which is the
// payment's `PaymentHash` in the `PaymentCreationInfo`. For legacy
// payments created before the AMP feature, the `Hash` field was not
// set, so we use the payment hash instead.
//
// NOTE: During the router's startup, we have a similar logic in
// `resumePayments`, in which we will use the payment hash instead if
// the attempt's hash is nil.
a.Hash = &p.identifier
return a
}
// reloadInflightAttempts is called when the payment lifecycle is resumed after
// a restart. It reloads all inflight attempts from the control tower and
// collects the results of the attempts that have been sent before.
func (p *paymentLifecycle) reloadInflightAttempts() (DBMPPayment, error) {
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
if err != nil {
return nil, err
}
for _, a := range payment.InFlightHTLCs() {
a := a
log.Infof("Resuming HTLC attempt %v for payment %v",
a.AttemptID, p.identifier)
// Potentially attach the payment hash to the `Hash` field if
// it's a legacy payment.
a = p.patchLegacyPaymentHash(a)
p.resultCollector(&a)
}
return payment, nil
}
// reloadPayment returns the latest payment found in the db (control tower).
func (p *paymentLifecycle) reloadPayment() (DBMPPayment,
*channeldb.MPPaymentState, error) {
// Read the db to get the latest state of the payment.
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
if err != nil {
return nil, nil, err
}
ps := payment.GetState()
remainingFees := p.calcFeeBudget(ps.FeesPaid)
log.Debugf("Payment %v: status=%v, active_shards=%v, rem_value=%v, "+
"fee_limit=%v", p.identifier, payment.GetStatus(),
ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees)
return payment, ps, nil
}
// handleAttemptResult processes the result of an HTLC attempt returned from
// the htlcswitch.
func (p *paymentLifecycle) handleAttemptResult(attempt *channeldb.HTLCAttempt,
result *htlcswitch.PaymentResult) (*attemptResult, error) {
// If the result has an error, we need to further process it by failing
// the attempt and maybe fail the payment.
if result.Error != nil {
return p.handleSwitchErr(attempt, result.Error)
}
// We got an attempt settled result back from the switch.
log.Debugf("Payment(%v): attempt(%v) succeeded", p.identifier,
attempt.AttemptID)
// Report success to mission control.
err := p.router.cfg.MissionControl.ReportPaymentSuccess(
attempt.AttemptID, &attempt.Route,
)
if err != nil {
log.Errorf("Error reporting payment success to mc: %v", err)
}
// In case of success, we atomically store the settle result in the DB
// and move the shard to the settled state.
htlcAttempt, err := p.router.cfg.Control.SettleAttempt(
p.identifier, attempt.AttemptID,
&channeldb.HTLCSettleInfo{
Preimage: result.Preimage,
SettleTime: p.router.cfg.Clock.Now(),
},
)
if err != nil {
log.Errorf("Error settling attempt %v for payment %v with "+
"preimage %v: %v", attempt.AttemptID, p.identifier,
result.Preimage, err)
// We won't mark the attempt as failed since we already have
// the preimage.
return nil, err
}
return &attemptResult{
attempt: htlcAttempt,
}, nil
}
// collectAndHandleResult waits for the result for the given attempt to be
// available from the Switch, then records the attempt outcome with the control
// tower. An attemptResult is returned, indicating the final outcome of this
// HTLC attempt.
func (p *paymentLifecycle) collectAndHandleResult(
attempt *channeldb.HTLCAttempt) (*attemptResult, error) {
result, err := p.collectResult(attempt)
if err != nil {
return nil, err
}
return p.handleAttemptResult(attempt, result)
}
package routing
import (
"fmt"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/channeldb"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/netann"
"github.com/lightningnetwork/lnd/routing/route"
)
// BlockPadding is used to increment the finalCltvDelta value for the last hop
// to prevent an HTLC being failed if some blocks are mined while it's in-flight.
const BlockPadding uint16 = 3
// ValidateCLTVLimit is a helper function that validates that the cltv limit is
// greater than the final cltv delta parameter, optionally including the
// BlockPadding in this calculation.
func ValidateCLTVLimit(limit uint32, delta uint16, includePad bool) error {
if includePad {
delta += BlockPadding
}
if limit <= uint32(delta) {
return fmt.Errorf("cltv limit %v should be greater than %v",
limit, delta)
}
return nil
}
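// The following is an illustrative sketch, not part of the original code: it
// shows the boundary behavior of ValidateCLTVLimit with hypothetical values.
func exampleValidateCLTVLimit() {
	// Fails: with padding, the delta becomes 40+3 = 43, and a limit of
	// 43 is not strictly greater than that.
	errTooLow := ValidateCLTVLimit(43, 40, true)

	// Passes: 44 > 43, so no error is returned.
	errOK := ValidateCLTVLimit(44, 40, true)

	_ = errTooLow
	_ = errOK
}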
// noRouteError encodes a non-critical error encountered during path finding.
type noRouteError uint8
const (
// errNoTlvPayload is returned when the destination hop does not support
// a tlv payload.
errNoTlvPayload noRouteError = iota
// errNoPaymentAddr is returned when the destination hop does not
// support payment addresses.
errNoPaymentAddr
// errNoPathFound is returned when a path to the target destination does
// not exist in the graph.
errNoPathFound
// errInsufficientBalance is returned when none of the local
// channels have enough balance for the payment.
errInsufficientBalance
// errEmptyPaySession is returned when the empty payment session is
// queried for a route.
errEmptyPaySession
// errUnknownRequiredFeature is returned when the destination node
// requires an unknown feature.
errUnknownRequiredFeature
// errMissingDependentFeature is returned when the destination node
// is missing a feature that a feature we require depends on.
)
var (
// DefaultShardMinAmt is the default amount below which we won't try to
// further split the payment if no route is found. It is the minimum
// amount that we use as the shard size when splitting.
DefaultShardMinAmt = lnwire.NewMSatFromSatoshis(10000)
)
// Error returns the string representation of the noRouteError.
func (e noRouteError) Error() string {
switch e {
case errNoTlvPayload:
return "destination hop doesn't understand new TLV payloads"
case errNoPaymentAddr:
return "destination hop doesn't understand payment addresses"
case errNoPathFound:
return "unable to find a path to destination"
case errEmptyPaySession:
return "empty payment session"
case errInsufficientBalance:
return "insufficient local balance"
case errUnknownRequiredFeature:
return "unknown required feature"
case errMissingDependentFeature:
return "missing dependent feature"
default:
return "unknown no-route error"
}
}
// FailureReason converts a path finding error into a payment-level failure.
func (e noRouteError) FailureReason() channeldb.FailureReason {
switch e {
case
errNoTlvPayload,
errNoPaymentAddr,
errNoPathFound,
errEmptyPaySession,
errUnknownRequiredFeature,
errMissingDependentFeature:
return channeldb.FailureReasonNoRoute
case errInsufficientBalance:
return channeldb.FailureReasonInsufficientBalance
default:
return channeldb.FailureReasonError
}
}
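// The following is an illustrative sketch, not part of the original code: it
// mirrors the check in requestRoute that distinguishes non-critical path
// finding failures from critical errors using the noRouteError set.
func exampleClassifyPathError(err error) (channeldb.FailureReason, bool) {
	// A plain type assertion suffices for this sketch; requestRoute uses
	// errors.As so that wrapped errors also match.
	if routeErr, ok := err.(noRouteError); ok {
		return routeErr.FailureReason(), true
	}

	// Not a path finding failure: the caller treats it as critical.
	return channeldb.FailureReasonError, false
}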
// PaymentSession is used during SendPayment attempts to provide routes to
// attempt. It also defines methods to give the PaymentSession additional
// information learned during the previous attempts.
type PaymentSession interface {
// RequestRoute returns the next route to attempt for routing the
// specified HTLC payment to the target node. The returned route should
// carry at most maxAmt to the target node, and pay at most feeLimit in
// fees. It can carry less if the payment is MPP. The activeShards
// argument should be set to instruct the payment session about the
// number of in-flight HTLCs for the payment, such that it can choose
// a splitting strategy accordingly.
//
// A noRouteError is returned if a non-critical error is encountered
// during path finding.
RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
activeShards, height uint32,
firstHopCustomRecords lnwire.CustomRecords) (*route.Route,
error)
// UpdateAdditionalEdge takes an additional channel edge policy
// (private channels) and applies the update from the message. Returns
// a boolean to indicate whether the update has been applied without
// error.
UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool
// GetAdditionalEdgePolicy uses the public key and channel ID to query
// the ephemeral channel edge policy for additional edges. Returns nil
// if nothing is found.
GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
channelID uint64) *models.CachedEdgePolicy
}
// paymentSession is used during an HTLC routing session to prune the local
// chain view in response to failures, and also report those failures back to
// MissionController. The snapshot copied for this session will only ever grow,
// and will not be pruned after a decay like the main view within mission
// control. We do this as we want to avoid the case where we continually try a
// bad edge or route multiple times in a session. This can lead to an infinite
// loop if payment attempts take long enough. An additional set of edges can
// also be provided to assist in reaching the payment's destination.
type paymentSession struct {
selfNode route.Vertex
additionalEdges map[route.Vertex][]AdditionalEdge
getBandwidthHints func(Graph) (bandwidthHints, error)
payment *LightningPayment
empty bool
pathFinder pathFinder
graphSessFactory GraphSessionFactory
// pathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probability.
pathFindingConfig PathFindingConfig
missionControl MissionControlQuerier
// minShardAmt is the amount below which we won't try to further split
// the payment if no route is found. If the maximum number of htlcs
// specified in the payment is one, under no circumstances will
// splitting happen and this value remains unused.
minShardAmt lnwire.MilliSatoshi
// log is a payment session-specific logger.
log btclog.Logger
}
// newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment, selfNode route.Vertex,
getBandwidthHints func(Graph) (bandwidthHints, error),
graphSessFactory GraphSessionFactory,
missionControl MissionControlQuerier,
pathFindingConfig PathFindingConfig) (*paymentSession, error) {
edges, err := RouteHintsToEdges(p.RouteHints, p.Target)
if err != nil {
return nil, err
}
if p.BlindedPathSet != nil {
if len(edges) != 0 {
return nil, fmt.Errorf("cannot have both route hints " +
"and blinded path")
}
edges, err = p.BlindedPathSet.ToRouteHints()
if err != nil {
return nil, err
}
}
logPrefix := fmt.Sprintf("PaymentSession(%x):", p.Identifier())
return &paymentSession{
selfNode: selfNode,
additionalEdges: edges,
getBandwidthHints: getBandwidthHints,
payment: p,
pathFinder: findPath,
graphSessFactory: graphSessFactory,
pathFindingConfig: pathFindingConfig,
missionControl: missionControl,
minShardAmt: DefaultShardMinAmt,
log: log.WithPrefix(logPrefix),
}, nil
}
// pathFindingError is a wrapper error type that is used to distinguish path
// finding errors from other errors in path finding loop.
type pathFindingError struct {
error
}
// Unwrap returns the underlying error.
func (e *pathFindingError) Unwrap() error {
return e.error
}
// RequestRoute returns a route which is likely to be capable of successfully
// routing the specified HTLC payment to the target node. Initially the first
// set of paths returned from this method may encounter routing failure along
// the way, however as more payments are sent, mission control will start to
// build an up to date view of the network itself. With each payment a new area
// will be explored, which feeds into the recommendations made for routing.
//
// NOTE: This function is safe for concurrent access.
// NOTE: Part of the PaymentSession interface.
func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
activeShards, height uint32,
firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) {
if p.empty {
return nil, errEmptyPaySession
}
// Add BlockPadding to the finalCltvDelta so that the receiving node
// does not reject the HTLC if some blocks are mined while it's in-flight.
finalCltvDelta := p.payment.FinalCLTVDelta
finalCltvDelta += BlockPadding
// We need to subtract the final delta before passing it into path
// finding. The optimal path is independent of the final cltv delta and
// the path finding algorithm is unaware of this value.
cltvLimit := p.payment.CltvLimit - uint32(finalCltvDelta)
// TODO(roasbeef): sync logic amongst dist sys
// Taking into account this prune view, we'll attempt to locate a path
// to our destination, respecting the recommendations from
// MissionController.
restrictions := &RestrictParams{
ProbabilitySource: p.missionControl.GetProbability,
FeeLimit: feeLimit,
OutgoingChannelIDs: p.payment.OutgoingChannelIDs,
LastHop: p.payment.LastHop,
CltvLimit: cltvLimit,
DestCustomRecords: p.payment.DestCustomRecords,
DestFeatures: p.payment.DestFeatures,
PaymentAddr: p.payment.PaymentAddr,
Amp: p.payment.amp,
Metadata: p.payment.Metadata,
FirstHopCustomRecords: firstHopCustomRecords,
}
finalHtlcExpiry := int32(height) + int32(finalCltvDelta)
// Before we enter the loop below, we'll make sure to respect the max
// payment shard size (if it's set), which is effectively our
// client-side MTU that we'll attempt to respect at all times.
maxShardActive := p.payment.MaxShardAmt != nil
if maxShardActive && maxAmt > *p.payment.MaxShardAmt {
p.log.Debugf("Clamping payment attempt from %v to %v due to "+
"max shard size of %v", maxAmt, *p.payment.MaxShardAmt,
*p.payment.MaxShardAmt)
maxAmt = *p.payment.MaxShardAmt
}
var path []*unifiedEdge
findPath := func(graph graphdb.NodeTraverser) error {
// We'll also obtain a set of bandwidthHints from the lower
// layer for each of our outbound channels. This will allow the
// path finding to skip any links that aren't active or just
// don't have enough bandwidth to carry the payment. New
// bandwidth hints are queried for every new path finding
// attempt, because concurrent payments may change balances.
bandwidthHints, err := p.getBandwidthHints(graph)
if err != nil {
return err
}
p.log.Debugf("pathfinding for amt=%v", maxAmt)
// Find a route for the current amount.
path, _, err = p.pathFinder(
&graphParams{
additionalEdges: p.additionalEdges,
bandwidthHints: bandwidthHints,
graph: graph,
},
restrictions, &p.pathFindingConfig,
p.selfNode, p.selfNode, p.payment.Target,
maxAmt, p.payment.TimePref, finalHtlcExpiry,
)
if err != nil {
// Wrap the error to distinguish path finding errors
// from other errors in this closure.
return &pathFindingError{err}
}
return nil
}
for {
err := p.graphSessFactory.GraphSession(findPath)
// If there is an error, and it is not a path finding error, we
// return it immediately.
if err != nil && !lnutils.ErrorAs[*pathFindingError](err) {
return nil, err
} else if err != nil {
// If the error is a path finding error, we'll unwrap it
// to check the underlying error.
//
//nolint:errorlint
pErr, _ := err.(*pathFindingError)
err = pErr.Unwrap()
}
// Otherwise, we'll switch on the path finding error.
switch {
case err == errNoPathFound:
// Don't split if this is a legacy payment without mpp
// record. If it has a blinded path though, then we
// can split. Split payments to blinded paths won't have
// MPP records.
if p.payment.PaymentAddr.IsNone() &&
p.payment.BlindedPathSet == nil {
p.log.Debugf("not splitting because payment " +
"address is unspecified")
return nil, errNoPathFound
}
if p.payment.DestFeatures == nil {
p.log.Debug("Not splitting because " +
"destination DestFeatures is nil")
return nil, errNoPathFound
}
destFeatures := p.payment.DestFeatures
if !destFeatures.HasFeature(lnwire.MPPOptional) &&
!destFeatures.HasFeature(lnwire.AMPOptional) {
p.log.Debug("not splitting because " +
"destination doesn't declare MPP or " +
"AMP")
return nil, errNoPathFound
}
// No splitting if this is the last shard.
isLastShard := activeShards+1 >= p.payment.MaxParts
if isLastShard {
p.log.Debugf("not splitting because shard "+
"limit %v has been reached",
p.payment.MaxParts)
return nil, errNoPathFound
}
// This is where the magic happens. If we can't find a
// route, try it for half the amount.
maxAmt /= 2
// Put a lower bound on the minimum shard size.
if maxAmt < p.minShardAmt {
p.log.Debugf("not splitting because minimum "+
"shard amount %v has been reached",
p.minShardAmt)
return nil, errNoPathFound
}
// Go pathfinding.
continue
// If there isn't enough local bandwidth, there is no point in
// splitting. It won't be possible to create a complete set in
// any case, but the sent out partial payments would be held by
// the receiver until the mpp timeout.
case err == errInsufficientBalance:
p.log.Debug("not splitting because local balance " +
"is insufficient")
return nil, err
case err != nil:
return nil, err
}
// With the next candidate path found, we'll attempt to turn
// this into a route by applying the time-lock and fee
// requirements.
route, err := newRoute(
p.selfNode, path, height,
finalHopParams{
amt: maxAmt,
totalAmt: p.payment.Amount,
cltvDelta: finalCltvDelta,
records: p.payment.DestCustomRecords,
paymentAddr: p.payment.PaymentAddr,
metadata: p.payment.Metadata,
}, p.payment.BlindedPathSet,
)
if err != nil {
return nil, err
}
return route, err
}
}
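// The following is an illustrative, self-contained sketch, not part of the
// original code: it shows the halving strategy RequestRoute applies when no
// route is found, cutting the attempted amount in half until it falls below
// the minimum shard size.
func exampleHalvingSchedule(maxAmt,
	minShardAmt lnwire.MilliSatoshi) []lnwire.MilliSatoshi {

	var schedule []lnwire.MilliSatoshi
	for maxAmt >= minShardAmt {
		schedule = append(schedule, maxAmt)
		maxAmt /= 2
	}

	// For maxAmt=100,000 msat and minShardAmt=30,000 msat this yields
	// [100,000 50,000], since 25,000 is below the minimum shard size.
	return schedule
}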
// UpdateAdditionalEdge updates the channel edge policy for a private edge. It
// validates the message signature, then applies the updates to the supplied
// policy. It returns a boolean indicating whether the update was applied
// without error.
func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate1,
pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool {
// Validate the message signature.
if err := netann.VerifyChannelUpdateSignature(msg, pubKey); err != nil {
log.Errorf(
"Unable to validate channel update signature: %v", err,
)
return false
}
// Update channel policy for the additional edge.
policy.TimeLockDelta = msg.TimeLockDelta
policy.FeeBaseMSat = lnwire.MilliSatoshi(msg.BaseFee)
policy.FeeProportionalMillionths = lnwire.MilliSatoshi(msg.FeeRate)
log.Debugf("New private channel update applied: %v",
lnutils.SpewLogClosure(msg))
return true
}
// GetAdditionalEdgePolicy uses the public key and channel ID to query the
// ephemeral channel edge policy for additional edges. Returns nil if nothing
// is found.
func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
channelID uint64) *models.CachedEdgePolicy {
target := route.NewVertex(pubKey)
edges, ok := p.additionalEdges[target]
if !ok {
return nil
}
for _, edge := range edges {
policy := edge.EdgePolicy()
if policy.ChannelID != channelID {
continue
}
return policy
}
return nil
}
package routing
import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/zpay32"
)
// A compile time assertion to ensure SessionSource meets the
// PaymentSessionSource interface.
var _ PaymentSessionSource = (*SessionSource)(nil)
// SessionSource defines a source for the router to retrieve new payment
// sessions.
type SessionSource struct {
// GraphSessionFactory can be used to gain access to a Graph session.
// If the backing DB allows it, this will mean that a read transaction
// is being held during the use of the session.
GraphSessionFactory GraphSessionFactory
// SourceNode is the graph's source node.
SourceNode *models.LightningNode
// GetLink is a method that allows querying the lower link layer
// to determine the up to date available bandwidth at a prospective link
// to be traversed. If the link isn't available, then a value of zero
// should be returned. Otherwise, the current up to date knowledge of
// the available bandwidth of the link should be returned.
GetLink getLinkQuery
// MissionControl is a shared memory of sorts that executions of payment
// path finding use in order to remember which vertexes/edges were
// pruned from prior attempts. During payment execution, errors sent by
// nodes are mapped into a vertex or edge to be pruned. Each run will
// then take into account this set of pruned vertexes/edges to reduce
// route failure and pass on graph information gained to the next
// execution.
MissionControl MissionControlQuerier
// PathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probability.
PathFindingConfig PathFindingConfig
}
// NewPaymentSession creates a new payment session backed by the latest prune
// view from Mission Control. An optional set of routing hints can be provided
// in order to populate additional edges to explore when finding a path to the
// payment's destination.
func (m *SessionSource) NewPaymentSession(p *LightningPayment,
firstHopBlob fn.Option[tlv.Blob],
trafficShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession,
error) {
getBandwidthHints := func(graph Graph) (bandwidthHints, error) {
return newBandwidthManager(
graph, m.SourceNode.PubKeyBytes, m.GetLink,
firstHopBlob, trafficShaper,
)
}
session, err := newPaymentSession(
p, m.SourceNode.PubKeyBytes, getBandwidthHints,
m.GraphSessionFactory, m.MissionControl, m.PathFindingConfig,
)
if err != nil {
return nil, err
}
return session, nil
}
// NewPaymentSessionEmpty creates a new paymentSession instance that is empty,
// and will be exhausted immediately. It is used for failure reporting to
// mission control for resumed payments that we don't want to make more
// attempts for.
func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession {
return &paymentSession{
empty: true,
}
}
// RouteHintsToEdges converts a list of invoice route hints to an edge map that
// can be passed into pathfinding.
func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) (
map[route.Vertex][]AdditionalEdge, error) {
edges := make(map[route.Vertex][]AdditionalEdge)
// Traverse through all of the available hop hints and include them in
// our edges map, indexed by the public key of the channel's starting
// node.
for _, routeHint := range routeHints {
// If multiple hop hints are provided within a single route
// hint, we'll assume they are chained together and sorted in
// forward order, so that following them in sequence reaches
// the target successfully.
for i, hopHint := range routeHint {
// In order to determine the end node of this hint,
// we'll need to look at the next hint's start node. If
// we've reached the end of the hints list, we can
// assume we've reached the destination.
endNode := &models.LightningNode{}
if i != len(routeHint)-1 {
endNode.AddPubKey(routeHint[i+1].NodeID)
} else {
targetPubKey, err := btcec.ParsePubKey(
target[:],
)
if err != nil {
return nil, err
}
endNode.AddPubKey(targetPubKey)
}
// Finally, create the channel edge from the hop hint
// and add it to the list of edges corresponding to the
// node at the start of the channel.
edgePolicy := &models.CachedEdgePolicy{
ToNodePubKey: func() route.Vertex {
return endNode.PubKeyBytes
},
ToNodeFeatures: lnwire.EmptyFeatureVector(),
ChannelID: hopHint.ChannelID,
FeeBaseMSat: lnwire.MilliSatoshi(
hopHint.FeeBaseMSat,
),
FeeProportionalMillionths: lnwire.MilliSatoshi(
hopHint.FeeProportionalMillionths,
),
TimeLockDelta: hopHint.CLTVExpiryDelta,
}
edge := &PrivateEdge{
policy: edgePolicy,
}
v := route.NewVertex(hopHint.NodeID)
edges[v] = append(edges[v], edge)
}
}
return edges, nil
}
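// exampleRouteHintsToEdges is an illustrative sketch added by the editor (not
// part of the original package): it shows how the route hints of a decoded
// zpay32 invoice would typically be converted into the edge map consumed by
// pathfinding. The function name is hypothetical.
func exampleRouteHintsToEdges(invoice *zpay32.Invoice) (
	map[route.Vertex][]AdditionalEdge, error) {

	// The hints must ultimately lead to the invoice's destination node.
	target := route.NewVertex(invoice.Destination)

	// Every hop hint becomes a private edge, keyed by the public key of
	// the node at the start of the hinted channel.
	return RouteHintsToEdges(invoice.RouteHints, target)
}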
package routing
import (
"errors"
"fmt"
"math"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
const (
// CapacityFraction and capacitySmearingFraction define how
// capacity-related probability reweighting works. CapacityFraction
// defines the fraction of the channel capacity at which the effect
// roughly sets in and capacitySmearingFraction defines over which range
// the factor changes from 1 to minCapacityFactor.
// DefaultCapacityFraction is the default value for CapacityFraction.
// It is chosen such that the capacity factor is active but with a small
// effect. This value together with capacitySmearingFraction leads to a
// noticeable reduction in probability if the amount starts to come
// close to 90% of a channel's capacity.
DefaultCapacityFraction = 0.9999
// capacitySmearingFraction defines how quickly the capacity factor
// drops from 1 to minCapacityFactor. This value results in the factor
// varying over a range of roughly 20% of the capacity.
capacitySmearingFraction = 0.025
// minCapacityFactor is the minimal value the capacityFactor can take.
// Setting this value too low can lead to paths being discarded due to
// the enforced minimal probability, or to excessively high pathfinding
// weights.
minCapacityFactor = 0.5
// minCapacityFraction is the minimum allowed value for
// CapacityFraction. The success probability in the random balance model
// (which may not be an accurate description of the liquidity
// distribution in the network) can be approximated with P(a) = 1 - a/c,
// for amount a and capacity c. If we require a probability P(a) = 0.25,
// this translates into a value of 0.75 for a/c. We limit this value in
// order to not discard too many channels.
minCapacityFraction = 0.75
// AprioriEstimatorName is used to identify the apriori probability
// estimator.
AprioriEstimatorName = "apriori"
)
var (
// ErrInvalidHalflife is returned when we get an invalid half life.
ErrInvalidHalflife = errors.New("penalty half life must be >= 0")
// ErrInvalidHopProbability is returned when we get an invalid hop
// probability.
ErrInvalidHopProbability = errors.New("hop probability must be in " +
"[0, 1]")
// ErrInvalidAprioriWeight is returned when we get an apriori weight
// that is out of range.
ErrInvalidAprioriWeight = errors.New("apriori weight must be in [0, 1]")
// ErrInvalidCapacityFraction is returned when we get a capacity
// fraction that is out of range.
ErrInvalidCapacityFraction = fmt.Errorf("capacity fraction must be in "+
"[%v, 1]", minCapacityFraction)
)
// AprioriConfig contains configuration for our probability estimator.
type AprioriConfig struct {
// PenaltyHalfLife defines after how much time a penalized node or
// channel is back at 50% probability.
PenaltyHalfLife time.Duration
// AprioriHopProbability is the assumed success probability of a hop in
// a route when no other information is available.
AprioriHopProbability float64
// AprioriWeight is a value in the range [0, 1] that defines to what
// extent historical results should be extrapolated to untried
// connections. Setting it to one will completely ignore historical
// results and always assume the configured a priori probability for
// untried connections. A value of zero will ignore the a priori
// probability completely and only base the probability on historical
// results, unless there are none available.
AprioriWeight float64
// CapacityFraction is the fraction of a channel's capacity that we
// consider to have liquidity. For amounts that come close to or exceed
// the fraction, an additional penalty is applied. A value of 1.0
// disables the capacityFactor.
CapacityFraction float64
}
// validate checks the configuration of the estimator for allowed values.
func (p AprioriConfig) validate() error {
if p.PenaltyHalfLife < 0 {
return ErrInvalidHalflife
}
if p.AprioriHopProbability < 0 || p.AprioriHopProbability > 1 {
return ErrInvalidHopProbability
}
if p.AprioriWeight < 0 || p.AprioriWeight > 1 {
return ErrInvalidAprioriWeight
}
if p.CapacityFraction < minCapacityFraction || p.CapacityFraction > 1 {
return ErrInvalidCapacityFraction
}
return nil
}
// DefaultAprioriConfig returns the default configuration for the estimator.
func DefaultAprioriConfig() AprioriConfig {
return AprioriConfig{
PenaltyHalfLife: DefaultPenaltyHalfLife,
AprioriHopProbability: DefaultAprioriHopProbability,
AprioriWeight: DefaultAprioriWeight,
CapacityFraction: DefaultCapacityFraction,
}
}
// AprioriEstimator returns node and pair probabilities based on historical
// payment results. It uses a preconfigured success probability value for
// untried hops (AprioriHopProbability) and returns a high success probability
// for hops that could previously conduct a payment (prevSuccessProbability).
// Successful edges are retried until proven otherwise. Recently failed hops are
// penalized by an exponential time decay (PenaltyHalfLife), after which they
// are reconsidered for routing. If information was learned about a forwarding
// node, the information is taken into account to estimate a per node
// probability that mixes with the a priori probability (AprioriWeight).
type AprioriEstimator struct {
// AprioriConfig contains configuration options for our estimator.
AprioriConfig
// prevSuccessProbability is the assumed probability for node pairs that
// successfully relayed the previous attempt.
prevSuccessProbability float64
}
// NewAprioriEstimator creates a new AprioriEstimator.
func NewAprioriEstimator(cfg AprioriConfig) (*AprioriEstimator, error) {
if err := cfg.validate(); err != nil {
return nil, err
}
return &AprioriEstimator{
AprioriConfig: cfg,
prevSuccessProbability: prevSuccessProbability,
}, nil
}
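// exampleNewAprioriEstimator is an illustrative sketch added by the editor
// (not part of the original package) showing the typical construction flow:
// start from the default configuration, override individual parameters and
// let NewAprioriEstimator validate the result. The chosen weight is an
// arbitrary example value.
func exampleNewAprioriEstimator() (*AprioriEstimator, error) {
	cfg := DefaultAprioriConfig()

	// Lean a bit more on historical results than the default does: a
	// weight of zero would use only observed outcomes, a weight of one
	// only the a priori probability.
	cfg.AprioriWeight = 0.4

	return NewAprioriEstimator(cfg)
}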
// Compile-time checks that interfaces are implemented.
var _ Estimator = (*AprioriEstimator)(nil)
var _ estimatorConfig = (*AprioriConfig)(nil)
// Config returns the estimator's configuration.
func (p *AprioriEstimator) Config() estimatorConfig {
return p.AprioriConfig
}
// String returns the estimator's configuration as a string representation.
func (p *AprioriEstimator) String() string {
return fmt.Sprintf("estimator type: %v, penalty halflife time: %v, "+
"apriori hop probability: %v, apriori weight: %v, previous "+
"success probability: %v, capacity fraction: %v",
AprioriEstimatorName, p.PenaltyHalfLife,
p.AprioriHopProbability, p.AprioriWeight,
p.prevSuccessProbability, p.CapacityFraction)
}
// getNodeProbability calculates the probability for connections from a node
// that have not been tried before. The results parameter is a list of last
// payment results for that node.
func (p *AprioriEstimator) getNodeProbability(now time.Time,
results NodeResults, amt lnwire.MilliSatoshi,
capacity btcutil.Amount) float64 {
// We reduce the apriori hop probability if the amount comes close to
// the capacity.
apriori := p.AprioriHopProbability * capacityFactor(
amt, capacity, p.CapacityFraction,
)
// If the channel history is not to be taken into account, we can return
// early here with the configured a priori probability.
if p.AprioriWeight == 1 {
return apriori
}
// If there is no channel history, our best estimate is still the a
// priori probability.
if len(results) == 0 {
return apriori
}
// The value of the apriori weight is in the range [0, 1]. Convert it to
// a factor that properly expresses the intention of the weight in the
// following weight average calculation. When the apriori weight is 0,
// the apriori factor is also 0. This means it won't have any effect on
// the weighted average calculation below. When the apriori weight
// approaches 1, the apriori factor goes to infinity. It will heavily
// outweigh any observations that have been collected.
aprioriFactor := 1/(1-p.AprioriWeight) - 1
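// As a worked illustration (editor's note, example values): an apriori
// weight of 0.5 gives aprioriFactor = 1/(1-0.5) - 1 = 1, so the a priori
// probability counts as much as a single observation, while a weight of
// 0.75 gives a factor of 3, which dominates small result sets.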
// Calculate a weighted average consisting of the apriori probability
// and historical observations. This is the part that incentivizes nodes
// to make sure that all (not just some) of their channels are in good
// shape. Senders will steer around nodes that have shown a few
// failures, even though there may be many channels still untried.
//
// If there is just a single observation and the apriori weight is 0,
// this single observation will totally determine the node probability.
// The node probability is returned for all other channels of the node.
// This means that one failure will lead to the success probability
// estimates for all other channels being 0 too. The probability for the
// channel that was tried will not even recover, because it is
// recovering to the node probability (which is zero). So one failure
// effectively prunes all channels of the node forever. This is the most
// aggressive way in which we can penalize nodes and unlikely to yield
// good results in a real network.
probabilitiesTotal := apriori * aprioriFactor
totalWeight := aprioriFactor
for _, result := range results {
switch {
// Weigh success with a constant high weight of 1. There is no
// decay. Amt is never zero, so this clause is never executed
// when result.SuccessAmt is zero.
case amt <= result.SuccessAmt:
totalWeight++
probabilitiesTotal += p.prevSuccessProbability
// Weigh failures in accordance with their age. The base
// probability of a failure is considered zero, so nothing needs
// to be added to probabilitiesTotal.
case !result.FailTime.IsZero() && amt >= result.FailAmt:
age := now.Sub(result.FailTime)
totalWeight += p.getWeight(age)
}
}
return probabilitiesTotal / totalWeight
}
// getWeight calculates a weight in the range [0, 1] that should be assigned to
// a payment result. Weight follows an exponential curve that starts at 1 when
// the result is fresh and asymptotically approaches zero over time. The rate at
// which this happens is controlled by the penaltyHalfLife parameter.
func (p *AprioriEstimator) getWeight(age time.Duration) float64 {
exp := -age.Hours() / p.PenaltyHalfLife.Hours()
return math.Pow(2, exp)
}
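// As a worked example of getWeight (editor's illustration, example values):
// with a PenaltyHalfLife of one hour, a failure that just happened is
// weighted 2^0 = 1, a one-hour-old failure 2^-1 = 0.5 and a two-hour-old
// failure 2^-2 = 0.25, so the penalty fades by half with every elapsed
// half-life.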
// capacityFactor is a multiplier that can be used to reduce the probability
// depending on how much of the capacity is sent. In other words, the factor
// sorts out channels that don't provide enough liquidity. Effectively, this
// leads to usage of larger channels in total to increase success probability,
// but it may also increase fees. The limits are 1 for amt == 0 and
// minCapacityFactor for amt far above cutoffMsat. The function drops
// significantly when amt reaches cutoffMsat. smearingMsat determines over
// which scale the reduction takes place.
func capacityFactor(amt lnwire.MilliSatoshi, capacity btcutil.Amount,
capacityCutoffFraction float64) float64 {
// The special value of 1.0 for capacityFactor disables any effect from
// this factor.
if capacityCutoffFraction == 1 {
return 1.0
}
// If we don't have information about the capacity, which can be the
// case for hop hints or local channels, we return unity to not alter
// anything.
if capacity == 0 {
return 1.0
}
capMsat := float64(lnwire.NewMSatFromSatoshis(capacity))
amtMsat := float64(amt)
if amtMsat > capMsat {
return 0
}
cutoffMsat := capacityCutoffFraction * capMsat
smearingMsat := capacitySmearingFraction * capMsat
// We compute a logistic function mirrored around the y axis, centered
// at cutoffMsat, decaying over the smearingMsat scale.
denominator := 1 + math.Exp(-(amtMsat-cutoffMsat)/smearingMsat)
// The numerator decides what the minimal value of this function will
// be. The minimal value is set by minCapacityFactor.
numerator := 1 - minCapacityFactor
return 1 - numerator/denominator
}
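// As a worked example of capacityFactor (editor's illustration, example
// values): with the default CapacityFraction of 0.9999, an amount of half
// the capacity sits far below the cutoff and the factor stays at essentially
// 1. At the cutoff itself the denominator is 2, giving
// 1 - (1-minCapacityFactor)/2 = 0.75, and amounts far beyond the cutoff
// approach minCapacityFactor = 0.5.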
// PairProbability estimates the probability of successfully traversing to
// toNode based on historical payment outcomes for the from node. Those outcomes
// are passed in via the results parameter.
func (p *AprioriEstimator) PairProbability(now time.Time,
results NodeResults, toNode route.Vertex, amt lnwire.MilliSatoshi,
capacity btcutil.Amount) float64 {
nodeProbability := p.getNodeProbability(now, results, amt, capacity)
return p.calculateProbability(
now, results, nodeProbability, toNode, amt,
)
}
// LocalPairProbability estimates the probability of successfully traversing
// our own local channels to toNode.
func (p *AprioriEstimator) LocalPairProbability(
now time.Time, results NodeResults, toNode route.Vertex) float64 {
// For local channels that have never been tried before, we assume them
// to be successful. We have accurate balance and online status
// information on our own channels, so when we select them in a route it
// is close to certain that those channels will work.
nodeProbability := p.prevSuccessProbability
return p.calculateProbability(
now, results, nodeProbability, toNode, lnwire.MaxMilliSatoshi,
)
}
// calculateProbability estimates the probability of successfully traversing to
// toNode based on historical payment outcomes and a fall-back node probability.
func (p *AprioriEstimator) calculateProbability(
now time.Time, results NodeResults,
nodeProbability float64, toNode route.Vertex,
amt lnwire.MilliSatoshi) float64 {
// Retrieve the last pair outcome.
lastPairResult, ok := results[toNode]
// If there is no history for this pair, return the node probability,
// which is a probability estimate for an untried channel.
if !ok {
return nodeProbability
}
// For successes, we have a fixed (high) probability. Those pairs will
// be assumed good until proven otherwise. Amt is never zero, so this
// clause is never executed when lastPairResult.SuccessAmt is zero.
if amt <= lastPairResult.SuccessAmt {
return p.prevSuccessProbability
}
// Take into account a minimum penalize amount. For balance errors, a
// failure may be reported with such a minimum to prevent too aggressive
// penalization. If the current amount is smaller than the amount that
// previously triggered a failure, we act as if this is an untried
// channel.
if lastPairResult.FailTime.IsZero() || amt < lastPairResult.FailAmt {
return nodeProbability
}
timeSinceLastFailure := now.Sub(lastPairResult.FailTime)
// Calculate success probability based on the weight of the last
// failure. When the failure is fresh, its weight is 1 and we'll return
// probability 0. Over time the probability recovers to the node
// probability. It would be as if this channel was never tried before.
weight := p.getWeight(timeSinceLastFailure)
probability := nodeProbability * (1 - weight)
return probability
}
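// examplePairProbability is an illustrative sketch added by the editor (not
// part of the original package): it shows how a caller would query the
// apriori estimator for a single hop. The NodeResults map would normally be
// supplied by mission control's in-memory state; the amounts are arbitrary
// example values.
func examplePairProbability(est *AprioriEstimator, toNode route.Vertex,
	results NodeResults) float64 {

	// Probe the chance of pushing 100k satoshis through a channel with a
	// total capacity of 1M satoshis.
	amt := lnwire.NewMSatFromSatoshis(100_000)
	capacity := btcutil.Amount(1_000_000)

	return est.PairProbability(time.Now(), results, toNode, amt, capacity)
}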
package routing
import (
"fmt"
"math"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
const (
// DefaultBimodalScaleMsat is the default value for BimodalScaleMsat in
// BimodalConfig. It describes the distribution of funds in the LN based
// on empirical findings. We assume an unbalanced network by default.
DefaultBimodalScaleMsat = lnwire.MilliSatoshi(300_000_000)
// DefaultBimodalNodeWeight is the default value for the
// BimodalNodeWeight in BimodalConfig. It is chosen such that past
// forwardings on other channels of a router are only slightly taken
// into account.
DefaultBimodalNodeWeight = 0.2
// DefaultBimodalDecayTime is the default value for BimodalDecayTime.
// We will forget about previous learnings about channel liquidity on
// the timescale of about a week.
DefaultBimodalDecayTime = 7 * 24 * time.Hour
// BimodalScaleMsatMax is the maximum value for BimodalScaleMsat. We
// limit it here to the fakeHopHintCapacity to avoid issues with hop
// hint probability calculations.
BimodalScaleMsatMax = lnwire.MilliSatoshi(
1000 * fakeHopHintCapacity / 4,
)
// BimodalEstimatorName is used to identify the bimodal estimator.
BimodalEstimatorName = "bimodal"
)
var (
// ErrInvalidScale is returned when we get a scale that is zero or
// unreasonably large.
ErrInvalidScale = errors.New("scale must be > 0 and sane")
// ErrInvalidNodeWeight is returned when we get a node weight that is
// out of range.
ErrInvalidNodeWeight = errors.New("node weight must be in [0, 1]")
// ErrInvalidDecayTime is returned when we get a decay time of zero or
// below.
ErrInvalidDecayTime = errors.New("decay time must be larger than zero")
// ErrZeroCapacity is returned when we encounter a channel with zero
// capacity in probability estimation.
ErrZeroCapacity = errors.New("capacity must be larger than zero")
)
// BimodalConfig contains configuration for our probability estimator.
type BimodalConfig struct {
// BimodalNodeWeight defines how strongly other previous forwardings on
// channels of a router should be taken into account when computing a
// channel's probability to route. The allowed values are in the range
// [0, 1], where a value of 0 means that only direct information about a
// channel is taken into account.
BimodalNodeWeight float64
// BimodalScaleMsat describes the scale over which channels
// statistically have some liquidity left. The value determines how
// quickly the bimodal distribution drops off from the edges of a
// channel. A larger value (compared to typical channel capacities)
// means that the drop off is slow and that channel balances are
// distributed more uniformly. A small value leads to the assumption of
// very unbalanced channels.
BimodalScaleMsat lnwire.MilliSatoshi
// BimodalDecayTime is the scale for the exponential information decay
// over time for previous successes or failures.
BimodalDecayTime time.Duration
}
// validate checks the configuration of the estimator for allowed values.
func (p BimodalConfig) validate() error {
if p.BimodalDecayTime <= 0 {
return fmt.Errorf("%v: %w", BimodalEstimatorName,
ErrInvalidDecayTime)
}
if p.BimodalNodeWeight < 0 || p.BimodalNodeWeight > 1 {
return fmt.Errorf("%v: %w", BimodalEstimatorName,
ErrInvalidNodeWeight)
}
if p.BimodalScaleMsat == 0 || p.BimodalScaleMsat > BimodalScaleMsatMax {
return fmt.Errorf("%v: %w", BimodalEstimatorName,
ErrInvalidScale)
}
return nil
}
// DefaultBimodalConfig returns the default configuration for the estimator.
func DefaultBimodalConfig() BimodalConfig {
return BimodalConfig{
BimodalNodeWeight: DefaultBimodalNodeWeight,
BimodalScaleMsat: DefaultBimodalScaleMsat,
BimodalDecayTime: DefaultBimodalDecayTime,
}
}
// BimodalEstimator returns node and pair probabilities based on historical
// payment results based on a liquidity distribution model of the LN. The main
// function is to estimate the direct channel probability based on a depleted
// liquidity distribution model, with additional information decay over time. A
// per-node probability can be mixed with the direct probability, taking into
// account successes/failures on other channels of the forwarder.
type BimodalEstimator struct {
// BimodalConfig contains configuration options for our estimator.
BimodalConfig
}
// NewBimodalEstimator creates a new BimodalEstimator.
func NewBimodalEstimator(cfg BimodalConfig) (*BimodalEstimator, error) {
if err := cfg.validate(); err != nil {
return nil, err
}
return &BimodalEstimator{
BimodalConfig: cfg,
}, nil
}
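// exampleNewBimodalEstimator is an illustrative sketch added by the editor
// (not part of the original package): it builds a bimodal estimator that
// assumes somewhat better balanced channels than the default by raising the
// liquidity scale. The chosen scale is an arbitrary example value.
func exampleNewBimodalEstimator() (*BimodalEstimator, error) {
	cfg := DefaultBimodalConfig()

	// A larger scale means that liquidity is assumed to leak further away
	// from the channel edges, i.e. channels are assumed to be less
	// unbalanced.
	cfg.BimodalScaleMsat = lnwire.MilliSatoshi(1_000_000_000)

	return NewBimodalEstimator(cfg)
}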
// Compile-time checks that interfaces are implemented.
var _ Estimator = (*BimodalEstimator)(nil)
var _ estimatorConfig = (*BimodalConfig)(nil)
// Config returns the estimator's configuration.
func (p *BimodalEstimator) Config() estimatorConfig {
return p.BimodalConfig
}
// String returns the estimator's configuration as a string representation.
func (p *BimodalEstimator) String() string {
return fmt.Sprintf("estimator type: %v, decay time: %v, liquidity "+
"scale: %v, node weight: %v", BimodalEstimatorName,
p.BimodalDecayTime, p.BimodalScaleMsat, p.BimodalNodeWeight)
}
// PairProbability estimates the probability of successfully traversing to
// toNode based on historical payment outcomes for the from node. Those outcomes
// are passed in via the results parameter.
func (p *BimodalEstimator) PairProbability(now time.Time,
results NodeResults, toNode route.Vertex, amt lnwire.MilliSatoshi,
capacity btcutil.Amount) float64 {
// We first compute the probability for the desired hop taking into
// account previous knowledge.
directProbability := p.directProbability(
now, results, toNode, amt, lnwire.NewMSatFromSatoshis(capacity),
)
// The final probability is computed by taking into account other
// channels of the from node.
return p.calculateProbability(directProbability, now, results, toNode)
}
// LocalPairProbability computes the probability to reach toNode given a set of
// previous learnings.
func (p *BimodalEstimator) LocalPairProbability(now time.Time,
results NodeResults, toNode route.Vertex) float64 {
// For direct local probabilities we assume to know exactly how much we
// can send over a channel, which assumes that channels are active and
// have enough liquidity.
directProbability := 1.0
// If we had an unexpected failure for this node, we reduce the
// probability for some time to avoid infinite retries.
result, ok := results[toNode]
if ok && !result.FailTime.IsZero() {
timeAgo := now.Sub(result.FailTime)
// We only expect results in the past to get a probability
// between 0 and 1.
if timeAgo < 0 {
timeAgo = 0
}
exponent := -float64(timeAgo) / float64(p.BimodalDecayTime)
directProbability -= math.Exp(exponent)
}
return directProbability
}
// directProbability computes the probability to reach a node based on the
// liquidity distribution in the LN.
func (p *BimodalEstimator) directProbability(now time.Time,
results NodeResults, toNode route.Vertex, amt lnwire.MilliSatoshi,
capacity lnwire.MilliSatoshi) float64 {
// We first determine the time-adjusted success and failure amounts to
// then compute a probability. We know that we can send a zero amount.
successAmount := lnwire.MilliSatoshi(0)
// We know that we cannot send the full capacity.
failAmount := capacity
// If we have information about past successes or failures, we modify
// them with a time decay.
result, ok := results[toNode]
if ok {
// Apply a time decay for the amount we cannot send.
if !result.FailTime.IsZero() {
failAmount = cannotSend(
result.FailAmt, capacity, now, result.FailTime,
p.BimodalDecayTime,
)
}
// Apply a time decay for the amount we can send.
if !result.SuccessTime.IsZero() {
successAmount = canSend(
result.SuccessAmt, now, result.SuccessTime,
p.BimodalDecayTime,
)
}
}
// Compute the direct channel probability.
probability, err := p.probabilityFormula(
capacity, successAmount, failAmount, amt,
)
if err != nil {
log.Errorf("error computing probability to node: %v "+
"(node: %v, results: %v, amt: %v, capacity: %v)",
err, toNode, results, amt, capacity)
return 0.0
}
return probability
}
// calculateProbability computes the total hop probability combining the channel
// probability and historic forwarding data of other channels of the node we try
// to send from.
//
// Goals:
// * We want to incentivize good routing nodes: the more routable channels a
// node has, the more we want to incentivize (vice versa for failures).
// -> We reduce/increase the direct probability depending on past
// failures/successes for other channels of the node.
//
// * We want to be forgiving/give other nodes a chance as well: we want to
// forget about (non-)routable channels over time.
// -> We weight the successes/failures with a time decay such that they will not
// influence the total probability if a long time went by.
//
// * If we don't have other info, we want to solely rely on the direct
// probability.
//
// * We want to be able to specify how important the other channels are compared
// to the direct channel.
// -> Introduce a node weight factor that weights the direct probability against
// the node-wide average. The larger the node weight, the more important other
// channels of the node are.
//
// How do failures on low fee nodes redirect routing to higher fee nodes?
// Assumptions:
// * attemptCostPPM of 1000 PPM
// * constant direct channel probability of P0 (usually 0.5 for large amounts)
// * node weight w of 0.2
//
// The question we want to answer is:
// How often would a zero-fee node be tried (even if there were failures for its
// other channels) over trying a high-fee node with 2000 PPM and no direct
// knowledge about the channel to send over?
//
// The probability of a route of length l is P(l) = P0^l.
//
// The total probability after n failures (with the implemented method here) is:
// P(l, n) = P(l-1) * P(n)
// = P(l-1) * (P0 + n*0) / (1 + n*w)
// = P(l) / (1 + n*w)
//
// Condition for a high-fee channel to overcome a low fee channel in the
// Dijkstra weight function (only looking at fee and probability PPM terms):
// highFeePPM + attemptCostPPM * 1/P(l) = 0PPM + attemptCostPPM * 1/P(l, n)
// highFeePPM/attemptCostPPM = 1/P(l, n) - 1/P(l) =
// = (1 + n*w)/P(l) - 1/P(l) =
// = n*w/P(l)
//
// Therefore:
// n = (highFeePPM/attemptCostPPM) * (P(l)/w) =
// = (2000/1000) * 0.5^l / w
//
// For a one-hop route we get:
// n = (2000/1000) * 0.5 / 0.2 = 5 tolerated failures
//
// For a three-hop route we get:
// n = (2000/1000) * 0.125 / 0.2 = 1.25 tolerated failures
//
// For more details on the behavior see tests.
func (p *BimodalEstimator) calculateProbability(directProbability float64,
now time.Time, results NodeResults, toNode route.Vertex) float64 {
// If we don't take other channels into account, we can return early.
if p.BimodalNodeWeight == 0.0 {
return directProbability
}
// If we have up-to-date information about the channel we want to use,
// i.e., the info stems from results that are more recent than the decay
// time, we only use the direct probability. This is needed in order to
// avoid that previous results (on all the other channels of the same
// routing node) distort and pin the calculated probability even if we
// have accurate direct information. This helps to dip the probability
// below the min probability in case of failures, to start the splitting
// process.
directResult, ok := results[toNode]
if ok {
latest := directResult.SuccessTime
if directResult.FailTime.After(latest) {
latest = directResult.FailTime
}
// We use BimodalDecayTime to judge the recency of the data. It is
// the time scale on which we assume to have lost information.
if now.Sub(latest) < p.BimodalDecayTime {
log.Tracef("Using direct probability for node %v: %v",
toNode, directResult)
return directProbability
}
}
// w is a parameter which determines how strongly the other channels of
// a node should be incorporated, the higher the stronger.
w := p.BimodalNodeWeight
// dt determines the timeliness of the previous successes/failures
// to be taken into account.
dt := float64(p.BimodalDecayTime)
// The direct channel probability is weighted fully, all other results
// are weighted according to how recent the information is.
totalProbabilities := directProbability
totalWeights := 1.0
for peer, result := range results {
// We don't include the direct hop probability here because it
// is already included in totalProbabilities.
if peer == toNode {
continue
}
// We add probabilities weighted by how recent the info is.
var weight float64
if result.SuccessAmt > 0 {
exponent := -float64(now.Sub(result.SuccessTime)) / dt
weight = math.Exp(exponent)
totalProbabilities += w * weight
totalWeights += w * weight
}
if result.FailAmt > 0 {
exponent := -float64(now.Sub(result.FailTime)) / dt
weight = math.Exp(exponent)
// Failures don't add to total success probability.
totalWeights += w * weight
}
}
return totalProbabilities / totalWeights
}
// canSend returns the sendable amount over the channel, respecting time decay.
// canSend approaches zero, if we wait for a much longer time than the decay
// time.
func canSend(successAmount lnwire.MilliSatoshi, now, successTime time.Time,
decayConstant time.Duration) lnwire.MilliSatoshi {
// The factor approaches 0 for successTime a long time in the past,
// is 1 when the successTime is now.
factor := math.Exp(
-float64(now.Sub(successTime)) / float64(decayConstant),
)
canSend := factor * float64(successAmount)
return lnwire.MilliSatoshi(canSend)
}
// cannotSend returns the not sendable amount over the channel, respecting time
// decay. cannotSend approaches the capacity, if we wait for a much longer time
// than the decay time.
func cannotSend(failAmount, capacity lnwire.MilliSatoshi, now,
failTime time.Time, decayConstant time.Duration) lnwire.MilliSatoshi {
if failAmount > capacity {
failAmount = capacity
}
// The factor approaches 0 for failTime a long time in the past and it
// is 1 when the failTime is now.
factor := math.Exp(
-float64(now.Sub(failTime)) / float64(decayConstant),
)
cannotSend := capacity - lnwire.MilliSatoshi(
factor*float64(capacity-failAmount),
)
return cannotSend
}
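// As a worked example of the two decays (editor's illustration, example
// values): with a decay time of one week, a success of 100k msat observed a
// week ago only guarantees canSend of about 100_000 * e^-1 ~= 36_800 msat
// today, while a week-old failure of 100k msat on a 1M msat channel relaxes
// cannotSend to 1_000_000 - 900_000 * e^-1 ~= 669_000 msat. Both bounds thus
// drift back to the no-information state of (0, capacity).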
// primitive computes the indefinite integral of our assumed (normalized)
// liquidity probability distribution. The distribution of liquidity x here is
// the function P(x) ~ exp(-x/s) + exp((x-c)/s), i.e., two exponentials residing
// at the ends of channels. This means that we expect liquidity to be at either
// side of the channel with capacity c. The s parameter (scale) defines how far
// the liquidity leaks into the channel. A very low scale assumes completely
// unbalanced channels, a very high scale assumes a random distribution. More
// details can be found in
// https://github.com/lightningnetwork/lnd/issues/5988#issuecomment-1131234858.
func (p *BimodalEstimator) primitive(c, x float64) float64 {
s := float64(p.BimodalScaleMsat)
// The indefinite integral of P(x) is given by
// Int P(x) dx = H(x) = s * (-e(-x/s) + e((x-c)/s)),
// and its norm from 0 to c can be computed from it,
// norm = [H(x)]_0^c = s * (-e(-c/s) + 1 - (-1 + e(-c/s)))
// = s * (2 - 2*e(-c/s)).
ecs := math.Exp(-c / s)
exs := math.Exp(-x / s)
// It would be possible to split the next term and reuse the factors
// from before, but this can lead to numerical issues with large
// numbers.
excs := math.Exp((x - c) / s)
// norm can only become zero, if c is zero, which we sorted out before
// calling this method.
norm := -2*ecs + 2
// We end up with the primitive function of the normalized P(x).
return (-exs + excs) / norm
}
// integral computes the integral of our liquidity distribution from the lower
// to the upper value.
func (p *BimodalEstimator) integral(capacity, lower, upper float64) float64 {
if lower < 0 || lower > upper {
log.Errorf("probability integral limits nonsensical: capacity:"+
"%v lower: %v upper: %v", capacity, lower, upper)
return 0.0
}
return p.primitive(capacity, upper) - p.primitive(capacity, lower)
}
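// As a sanity check on the normalization (editor's note): integrating the
// distribution over the full channel gives
// integral(c, 0, c) = primitive(c, c) - primitive(c, 0)
// = ((-e(-c/s) + 1) - (-1 + e(-c/s))) / (2 - 2*e(-c/s)) = 1,
// so the total probability mass is one, as required for a prior.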
// probabilityFormula computes the expected probability for a payment of
// amountMsat given prior learnings for a channel of certain capacity.
// successAmountMsat and failAmountMsat stand for the unsettled success and
// failure amounts, respectively. The formula is derived using the formalism
// presented in Pickhardt et al., https://arxiv.org/abs/2103.08576.
func (p *BimodalEstimator) probabilityFormula(capacityMsat, successAmountMsat,
failAmountMsat, amountMsat lnwire.MilliSatoshi) (float64, error) {
// Convert to positive-valued floats.
capacity := float64(capacityMsat)
successAmount := float64(successAmountMsat)
failAmount := float64(failAmountMsat)
amount := float64(amountMsat)
// In order for this formula to give reasonable results, we need to have
// an estimate of the capacity of a channel (or edge between nodes).
if capacity == 0.0 {
return 0, ErrZeroCapacity
}
// We cannot send more than the capacity.
if amount > capacity {
return 0.0, nil
}
// Mission control may have some outdated values, we correct them here.
// TODO(bitromortac): there may be better decisions to make in these
// cases, e.g., resetting failAmount=cap and successAmount=0.
// failAmount should be capacity at max.
if failAmount > capacity {
log.Debugf("Correcting failAmount %v to capacity %v",
failAmount, capacity)
failAmount = capacity
}
// successAmount should be capacity at max.
if successAmount > capacity {
log.Debugf("Correcting successAmount %v to capacity %v",
successAmount, capacity)
successAmount = capacity
}
// The next statement is a safety check against an illogical condition,
// otherwise the renormalization integral would become zero. This may
// happen if a large channel gets closed and smaller ones remain, but
// it should recover with the time decay.
if failAmount <= successAmount {
log.Tracef("fail amount (%v) is smaller than or equal the "+
"success amount (%v) for capacity (%v)",
failAmountMsat, successAmountMsat, capacityMsat)
return 0.0, nil
}
// We cannot send more than the fail amount.
if amount >= failAmount {
return 0.0, nil
}
// The success probability for payment amount a is the integral over the
// prior distribution P(x), the probability to find liquidity between
// the amount a and channel capacity c (or failAmount a_f):
// P(X >= a | X < a_f) = Integral_{a}^{a_f} P(x) dx
prob := p.integral(capacity, amount, failAmount)
if math.IsNaN(prob) {
return 0.0, fmt.Errorf("non-normalized probability is NaN, "+
"capacity: %v, amount: %v, fail amount: %v",
capacity, amount, failAmount)
}
// If we have payment information, we need to adjust the prior
// distribution P(x) and get the posterior distribution by renormalizing
// the prior distribution in such a way that the probability mass lies
// between a_s and a_f.
reNorm := p.integral(capacity, successAmount, failAmount)
if math.IsNaN(reNorm) {
return 0.0, fmt.Errorf("normalization factor is NaN, "+
"capacity: %v, success amount: %v, fail amount: %v",
capacity, successAmount, failAmount)
}
// The normalization factor can only be zero if the success amount is
// equal or larger than the fail amount. This should not happen as we
// have checked this scenario above.
if reNorm == 0.0 {
return 0.0, fmt.Errorf("normalization factor is zero, "+
"capacity: %v, success amount: %v, fail amount: %v",
capacity, successAmount, failAmount)
}
prob /= reNorm
// Note that for payment amounts smaller than successAmount, we can get
// a value larger than unity, which we cap here to get a proper
// probability.
if prob > 1.0 {
if amount > successAmount {
return 0.0, fmt.Errorf("unexpected large probability "+
"(%v) capacity: %v, amount: %v, success "+
"amount: %v, fail amount: %v", prob, capacity,
amount, successAmount, failAmount)
}
return 1.0, nil
} else if prob < 0.0 {
return 0.0, fmt.Errorf("negative probability "+
"(%v) capacity: %v, amount: %v, success "+
"amount: %v, fail amount: %v", prob, capacity,
amount, successAmount, failAmount)
}
return prob, nil
}
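// exampleProbabilityFormula is an illustrative sketch added by the editor
// (not part of the original package) of a direct call into the bimodal
// formula: with no prior learnings (successAmount = 0, failAmount = capacity)
// the result is the prior probability of finding enough liquidity for the
// amount. The capacity is an arbitrary example value.
func exampleProbabilityFormula(p *BimodalEstimator) (float64, error) {
	capacity := lnwire.MilliSatoshi(1_000_000_000)

	// No prior knowledge: we can send at least zero and we cannot send
	// the full capacity.
	successAmount := lnwire.MilliSatoshi(0)
	failAmount := capacity

	// Ask for the success probability of a payment of half the capacity.
	return p.probabilityFormula(
		capacity, successAmount, failAmount, capacity/2,
	)
}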
package routing
import (
"bytes"
"fmt"
"io"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)
// Instantiate variables to allow taking a reference from the failure reason.
var (
reasonError = channeldb.FailureReasonError
reasonIncorrectDetails = channeldb.FailureReasonPaymentDetails
)
// pairResult contains the result of the interpretation of a payment attempt for
// a specific node pair.
type pairResult struct {
// amt is the amount that was forwarded for this pair. Can be set to
// zero for failures that are amount independent.
amt lnwire.MilliSatoshi
// success indicates whether the payment attempt was successful through
// this pair.
success bool
}
// failPairResult creates a new result struct for a failure.
func failPairResult(minPenalizeAmt lnwire.MilliSatoshi) pairResult {
return pairResult{
amt: minPenalizeAmt,
}
}
// successPairResult creates a new result struct for a success.
func successPairResult(successAmt lnwire.MilliSatoshi) pairResult {
return pairResult{
success: true,
amt: successAmt,
}
}
// String returns the human-readable representation of a pair result.
func (p pairResult) String() string {
var resultType string
if p.success {
resultType = "success"
} else {
resultType = "failed"
}
return fmt.Sprintf("%v (amt=%v)", resultType, p.amt)
}
// interpretedResult contains the result of the interpretation of a payment
// attempt.
type interpretedResult struct {
// nodeFailure points to a node pubkey if all channels of that node are
// responsible for the result.
nodeFailure *route.Vertex
// pairResults contains a map of node pairs for which we have a result.
pairResults map[DirectedNodePair]pairResult
// finalFailureReason is set to a non-nil value if it makes no more
// sense to start another payment attempt. It will contain the reason
// why.
finalFailureReason *channeldb.FailureReason
// policyFailure is set to a node pair if there is a policy failure on
// that connection. This is used to control the second chance logic for
// policy failures.
policyFailure *DirectedNodePair
}
// interpretResult interprets a payment outcome and returns an object that
// contains information required to update mission control.
func interpretResult(rt *mcRoute,
failure fn.Option[paymentFailure]) *interpretedResult {
i := &interpretedResult{
pairResults: make(map[DirectedNodePair]pairResult),
}
return fn.ElimOption(failure, func() *interpretedResult {
i.processSuccess(rt)
return i
}, func(info paymentFailure) *interpretedResult {
i.processFail(rt, info)
return i
})
}
// processSuccess processes a successful payment attempt.
func (i *interpretedResult) processSuccess(route *mcRoute) {
// For successes, all nodes must have acted in the right way.
// Therefore we mark all of them with a success result. However we need
// to handle the blinded route part separately because for intermediate
// blinded nodes the amount field is set to zero so we use the receiver
// amount.
introIdx, isBlinded := introductionPointIndex(route)
if isBlinded {
// Report success for all the pairs until the introduction
// point.
i.successPairRange(route, 0, introIdx-1)
// Handle the blinded route part.
//
// NOTE: The introIdx index here does describe the node after
// the introduction point.
i.markBlindedRouteSuccess(route, introIdx)
return
}
// Mark nodes as successful in the non-blinded case of the payment.
i.successPairRange(route, 0, len(route.hops.Val)-1)
}
// processFail processes a failed payment attempt.
func (i *interpretedResult) processFail(rt *mcRoute, failure paymentFailure) {
if failure.info.IsNone() {
i.processPaymentOutcomeUnknown(rt)
return
}
var (
idx int
failMsg lnwire.FailureMessage
)
failure.info.WhenSome(
func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) {
idx = int(r.Val.sourceIdx.Val)
failMsg = r.Val.msg.Val.FailureMessage
},
)
// If the payment was to a blinded route and we received an error from
// after the introduction point, handle this error separately - there
// has been a protocol violation from the introduction node. This
// penalty applies regardless of the error code that is returned.
introIdx, isBlinded := introductionPointIndex(rt)
if isBlinded && introIdx < idx {
i.processPaymentOutcomeBadIntro(rt, introIdx, idx)
return
}
switch idx {
// We are the source of the failure.
case 0:
i.processPaymentOutcomeSelf(rt, failMsg)
// A failure from the final hop was received.
case len(rt.hops.Val):
i.processPaymentOutcomeFinal(rt, failMsg)
// An intermediate hop failed. Interpret the outcome, update reputation
// and try again.
default:
i.processPaymentOutcomeIntermediate(rt, idx, failMsg)
}
}
// processPaymentOutcomeBadIntro handles the case where we have made payment
// to a blinded route, but received an error from a node after the introduction
// node. This indicates that the introduction node is not obeying the route
// blinding specification, as we expect all errors from the introduction node
// onwards to be sourced from it.
func (i *interpretedResult) processPaymentOutcomeBadIntro(route *mcRoute,
introIdx, errSourceIdx int) {
// We fail the introduction node for not obeying the specification.
i.failNode(route, introIdx)
// Other preceding channels in the route forwarded correctly. Note
// that we do not assign success to the incoming link to the
// introduction node because it has not handled the error correctly.
if introIdx > 1 {
i.successPairRange(route, 0, introIdx-2)
}
// If the source of the failure was from the final node, we also set
// a final failure reason because the recipient can't process the
// payment (independent of the introduction failing to convert the
// error, we can't complete the payment if the last hop fails).
if errSourceIdx == len(route.hops.Val) {
i.finalFailureReason = &reasonError
}
}
// processPaymentOutcomeSelf handles failures sent by ourselves.
func (i *interpretedResult) processPaymentOutcomeSelf(rt *mcRoute,
failure lnwire.FailureMessage) {
switch failure.(type) {
// We receive a malformed htlc failure from our peer. We trust ourselves
// to send the correct htlc, so our peer must be at fault.
case *lnwire.FailInvalidOnionVersion,
*lnwire.FailInvalidOnionHmac,
*lnwire.FailInvalidOnionKey:
i.failNode(rt, 1)
// If this was a payment to a direct peer, we can stop trying.
if len(rt.hops.Val) == 1 {
i.finalFailureReason = &reasonError
}
// Any other failure originating from ourselves should be temporary and
// caused by changing conditions between path finding and execution of
// the payment. We just retry and trust that the information locally
// available in the link has been updated.
default:
log.Warnf("Routing failure for local channel %v occurred",
rt.hops.Val[0].channelID)
}
}
// processPaymentOutcomeFinal handles failures sent by the final hop.
func (i *interpretedResult) processPaymentOutcomeFinal(route *mcRoute,
failure lnwire.FailureMessage) {
n := len(route.hops.Val)
failNode := func() {
i.failNode(route, n)
// Other channels in the route forwarded correctly.
if n > 1 {
i.successPairRange(route, 0, n-2)
}
i.finalFailureReason = &reasonError
}
// If a failure from the final node is received, we will fail the
// payment in almost all cases. Only when the penultimate node sends an
// incorrect htlc do we want to retry via another route. Invalid onion
// failures are not expected, because the final node wouldn't be able to
// encrypt that failure.
switch failure.(type) {
// Expiry or amount of the HTLC doesn't match the onion, try another
// route.
case *lnwire.FailFinalIncorrectCltvExpiry,
*lnwire.FailFinalIncorrectHtlcAmount:
// We trust ourselves. If this is a direct payment, we penalize
// the final node and fail the payment.
if n == 1 {
i.failNode(route, n)
i.finalFailureReason = &reasonError
return
}
// Otherwise penalize the last pair of the route and retry.
// Either the final node is at fault, or it gets sent a bad htlc
// from its predecessor.
i.failPair(route, n-1)
// The other hops relayed correctly, so assign those pairs a
// success result. At this point, n >= 2.
i.successPairRange(route, 0, n-2)
// We are using the wrong payment hash or amount; fail the payment.
case *lnwire.FailIncorrectPaymentAmount,
*lnwire.FailIncorrectDetails:
// Assign all pairs a success result, as the payment reached the
// destination correctly.
i.successPairRange(route, 0, n-1)
i.finalFailureReason = &reasonIncorrectDetails
// The HTLC that was extended to the final hop expires too soon. Fail
// the payment, because we may be using the wrong final cltv delta.
case *lnwire.FailFinalExpiryTooSoon:
// TODO(roasbeef): can happen due to a race condition, try again
// with recent block height
// TODO(joostjager): can also happen because a node delayed
// deliberately. What to penalize?
i.finalFailureReason = &reasonIncorrectDetails
case *lnwire.FailMPPTimeout:
// Assign all pairs a success result, as the payment reached the
// destination correctly. Continue the payment process.
i.successPairRange(route, 0, n-1)
// We do not expect to receive an invalid blinding error from the final
// node in the route. This could erroneously happen in the following
// cases:
// 1. Unblinded node: misuses the error code.
// 2. A receiving introduction node: erroneously sends the error code,
// as the spec indicates that receiving introduction nodes should
// use regular errors.
//
// Note that we expect the case where this error is sent from a node
// after the introduction node to be handled elsewhere as this is part
// of a more general class of errors where the introduction node has
// failed to convert errors for the blinded route.
case *lnwire.FailInvalidBlinding:
failNode()
// All other errors are considered terminal if coming from the
// final hop. They indicate that something is wrong at the
// recipient, so we do apply a penalty.
default:
failNode()
}
}
// processPaymentOutcomeIntermediate handles failures sent by an intermediate
// hop.
//
//nolint:funlen
func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute,
errorSourceIdx int, failure lnwire.FailureMessage) {
reportOutgoing := func() {
i.failPair(
route, errorSourceIdx,
)
}
reportOutgoingBalance := func() {
i.failPairBalance(
route, errorSourceIdx,
)
// All nodes up to the failing pair must have forwarded
// successfully.
i.successPairRange(route, 0, errorSourceIdx-1)
}
reportIncoming := func() {
// We trust ourselves. If the error comes from the first hop, we
// can penalize the whole node. In that case there is no
// uncertainty as to which node to blame.
if errorSourceIdx == 1 {
i.failNode(route, errorSourceIdx)
return
}
// Otherwise report the incoming pair.
i.failPair(
route, errorSourceIdx-1,
)
// All nodes up to the failing pair must have forwarded
// successfully.
if errorSourceIdx > 1 {
i.successPairRange(route, 0, errorSourceIdx-2)
}
}
reportNode := func() {
// Fail only the node that reported the failure.
i.failNode(route, errorSourceIdx)
// Other preceding channels in the route forwarded correctly.
if errorSourceIdx > 1 {
i.successPairRange(route, 0, errorSourceIdx-2)
}
}
reportAll := func() {
// We trust ourselves. If the error comes from the first hop, we
// can penalize the whole node. In that case there is no
// uncertainty as to which node to blame.
if errorSourceIdx == 1 {
i.failNode(route, errorSourceIdx)
return
}
// Otherwise penalize all pairs up to the error source. This
// includes our own outgoing connection.
i.failPairRange(
route, 0, errorSourceIdx-1,
)
}
switch failure.(type) {
// If a node reports onion payload corruption or an invalid version,
// that node may be responsible, but it could also be that it is just
// relaying a malformed htlc failure from its successor. By reporting the
// outgoing channel set, we will surely hit the responsible node. At
// this point, it is not possible that the node's predecessor corrupted
// the onion blob. If the predecessor would have corrupted the payload,
// the error source wouldn't have been able to encrypt this failure
// message for us.
case *lnwire.FailInvalidOnionVersion,
*lnwire.FailInvalidOnionHmac,
*lnwire.FailInvalidOnionKey:
reportOutgoing()
// If InvalidOnionPayload is received, we penalize only the reporting
// node. We know the preceding hop didn't corrupt the onion, since the
// reporting node is able to send the failure. We assume that we
// constructed a valid onion payload and that the failure is most likely
// an unknown required type or a bug in their implementation.
case *lnwire.InvalidOnionPayload:
reportNode()
// If the next hop in the route wasn't known or offline, we'll only
// penalize the channel set which we attempted to route over. This is
// conservative, and it can handle faulty channels between nodes
// properly. Additionally, this guards against routing nodes returning
// errors in order to attempt to black list another node.
case *lnwire.FailUnknownNextPeer:
reportOutgoing()
// Some implementations use this error when the next hop is offline, so we
// do the same as FailUnknownNextPeer and also process the channel update.
case *lnwire.FailChannelDisabled:
// Set the node pair for which a channel update may be out of
// date. The second chance logic uses the policyFailure field.
i.policyFailure = &DirectedNodePair{
From: route.hops.Val[errorSourceIdx-1].pubKeyBytes.Val,
To: route.hops.Val[errorSourceIdx].pubKeyBytes.Val,
}
reportOutgoing()
// All nodes up to the failing pair must have forwarded
// successfully.
i.successPairRange(route, 0, errorSourceIdx-1)
// If we get a permanent channel failure, we'll prune the channel set in
// both directions and continue with the rest of the routes.
case *lnwire.FailPermanentChannelFailure:
reportOutgoing()
// When an HTLC parameter is incorrect, the node sending the error may
// be doing something wrong. But it could also be that its predecessor
// is intentionally modifying the htlc parameters that we instructed it
// via the hop payload. Therefore we penalize the incoming node pair. A
// third cause of this error may be that we have an out of date channel
// update. This is handled by the second chance logic up in mission
// control.
case *lnwire.FailAmountBelowMinimum,
*lnwire.FailFeeInsufficient,
*lnwire.FailIncorrectCltvExpiry:
// Set the node pair for which a channel update may be out of
// date. The second chance logic uses the policyFailure field.
i.policyFailure = &DirectedNodePair{
From: route.hops.Val[errorSourceIdx-1].pubKeyBytes.Val,
To: route.hops.Val[errorSourceIdx].pubKeyBytes.Val,
}
// We report the incoming channel. If a second chance is granted in
// mission control, this report is ignored.
reportIncoming()
// If the outgoing channel doesn't have enough capacity, we penalize.
// But we penalize only in a single direction and only for amounts
// greater than the attempted amount.
case *lnwire.FailTemporaryChannelFailure:
reportOutgoingBalance()
// If FailExpiryTooSoon is received, there must have been some delay
// along the path. We can't know which node is causing the delay, so we
// penalize all of them up to the error source.
//
// Alternatively it could also be that we ourselves have fallen behind
// somehow. We ignore that case for now.
case *lnwire.FailExpiryTooSoon:
reportAll()
// We only expect to get FailInvalidBlinding from an introduction node
// in a blinded route. The introduction node in a blinded route is
// always responsible for reporting errors for the blinded portion of
// the route (to protect the privacy of the members of the route), so
// we need to be careful not to unfairly "shoot the messenger".
//
// The introduction node has no incentive to falsely report errors to
// sabotage the blinded route because:
// 1. Its ability to route this payment is strictly tied to the
// blinded route.
// 2. The pubkeys in the blinded route are ephemeral, so doing so
// will have no impact on the nodes beyond the individual payment.
//
// Here we handle a few cases where we could unexpectedly receive this
// error:
// 1. Outside of a blinded route: erring node is not spec compliant.
// 2. Before the introduction point: erring node is not spec compliant.
//
// Note that we expect the case where this error is sent from a node
// after the introduction node to be handled elsewhere as this is part
// of a more general class of errors where the introduction node has
// failed to convert errors for the blinded route.
case *lnwire.FailInvalidBlinding:
introIdx, isBlinded := introductionPointIndex(route)
// Deal with cases where a node has incorrectly returned a
// blinding error:
// 1. A node before the introduction point returned it.
// 2. A node in a non-blinded route returned it.
if errorSourceIdx < introIdx || !isBlinded {
reportNode()
return
}
// Otherwise, the error was at the introduction node. All
// nodes up until the introduction node forwarded correctly,
// so we award them as successful.
if introIdx >= 1 {
i.successPairRange(route, 0, introIdx-1)
}
// If the hop after the introduction node that sent us an
// error is the final recipient, then we finally fail the
// payment because the receiver has generated a blinded route
// that they're unable to use. We have this special case so
// that we don't penalize the introduction node, and there is
// no point in retrying the payment while LND only supports
// one blinded route per payment.
//
// Note that if LND is extended to support multiple blinded
// routes, this will terminate the payment without re-trying
// the other routes.
if introIdx == len(route.hops.Val)-1 {
i.finalFailureReason = &reasonError
} else {
// We penalize the final hop of the blinded route which
// is sufficient to not reuse this route again and is
// also more memory efficient because the other hops
// of the blinded path are ephemeral and will only be
// used in conjunction with the final hop. Moreover we
// don't want to punish the introduction node because
// the blinded failure does not necessarily mean that
// the introduction node was at fault.
//
// TODO(ziggie): Make sure we only keep mc data for
// blinded paths, in both the success and failure case,
// in memory during the time of the payment and remove
// it afterwards. Blinded paths and their blinded hop
// keys are always changing per blinded route so there
// is no point in persisting this data.
i.failBlindedRoute(route)
}
// In all other cases, we penalize the reporting node. These are all
// failures that should not happen.
default:
i.failNode(route, errorSourceIdx)
}
}
// introductionPointIndex returns the index of an introduction point in a
// route, using the same indexing in the route that we use for errorSourceIdx
// (i.e., that we consider our own node to be at index zero). A boolean is
// returned to indicate whether the route contains a blinded portion at all.
func introductionPointIndex(route *mcRoute) (int, bool) {
for i, hop := range route.hops.Val {
if hop.hasBlindingPoint.IsSome() {
return i + 1, true
}
}
return 0, false
}
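// As an indexing example (editor's illustration): for a route
// self -> A -> B -> C where B is the introduction node carrying the blinding
// point, hops.Val holds [A, B, C] and B sits at slice index 1, so
// introductionPointIndex returns (2, true). This matches the errorSourceIdx
// convention in which our own node is index 0 and B is index 2.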
// processPaymentOutcomeUnknown processes a payment outcome for which no failure
// message or source is available.
func (i *interpretedResult) processPaymentOutcomeUnknown(route *mcRoute) {
n := len(route.hops.Val)
// If this is a direct payment, the destination must be at fault.
if n == 1 {
i.failNode(route, n)
i.finalFailureReason = &reasonError
return
}
// Otherwise penalize all channels in the route to make sure the
// responsible node is at least hit too. We even penalize the connection
// to our own peer, because that peer could also be responsible.
i.failPairRange(route, 0, n-1)
}
// extractMCRoute extracts the fields required by MC from the Route struct to
// create the more minimal mcRoute struct.
func extractMCRoute(r *route.Route) *mcRoute {
return &mcRoute{
sourcePubKey: tlv.NewRecordT[tlv.TlvType0](r.SourcePubKey),
totalAmount: tlv.NewRecordT[tlv.TlvType1](r.TotalAmount),
hops: tlv.NewRecordT[tlv.TlvType2](
extractMCHops(r.Hops),
),
}
}
// extractMCHops extracts the Hop fields that MC actually uses from a slice of
// Hops.
func extractMCHops(hops []*route.Hop) mcHops {
return fn.Map(hops, extractMCHop)
}
// extractMCHop extracts the Hop fields that MC actually uses from a Hop.
func extractMCHop(hop *route.Hop) *mcHop {
h := mcHop{
channelID: tlv.NewPrimitiveRecord[tlv.TlvType0](
hop.ChannelID,
),
pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](hop.PubKeyBytes),
amtToFwd: tlv.NewRecordT[tlv.TlvType2](hop.AmtToForward),
}
if hop.BlindingPoint != nil {
h.hasBlindingPoint = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType3](lnwire.TrueBoolean{}),
)
}
if hop.CustomRecords != nil {
h.hasCustomRecords = tlv.SomeRecordT(
tlv.NewRecordT[tlv.TlvType4](lnwire.TrueBoolean{}),
)
}
return &h
}
// mcRoute holds the bare minimum info about a payment attempt route that MC
// requires.
type mcRoute struct {
sourcePubKey tlv.RecordT[tlv.TlvType0, route.Vertex]
totalAmount tlv.RecordT[tlv.TlvType1, lnwire.MilliSatoshi]
hops tlv.RecordT[tlv.TlvType2, mcHops]
}
// Record returns a TLV record that can be used to encode/decode an mcRoute
// to/from a TLV stream.
func (r *mcRoute) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeMCRoute(&b, r, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, r, recordSize, encodeMCRoute, decodeMCRoute,
)
}
func encodeMCRoute(w io.Writer, val interface{}, _ *[8]byte) error {
if v, ok := val.(*mcRoute); ok {
return serializeRoute(w, v)
}
return tlv.NewTypeForEncodingErr(val, "routing.mcRoute")
}
func decodeMCRoute(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if v, ok := val.(*mcRoute); ok {
route, err := deserializeRoute(io.LimitReader(r, int64(l)))
if err != nil {
return err
}
*v = *route
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.mcRoute", l, l)
}
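// exampleMCRouteRoundTrip is an illustrative sketch (an assumed helper, not
// part of the original source) showing that encodeMCRoute and decodeMCRoute
// are inverses: an mcRoute serialized into a buffer decodes back into an
// equivalent value.
func exampleMCRouteRoundTrip(r *mcRoute) (*mcRoute, error) {
	var (
		b   bytes.Buffer
		buf [8]byte
	)
	if err := encodeMCRoute(&b, r, &buf); err != nil {
		return nil, err
	}

	// Decode from the buffer, bounding the read by the encoded length
	// just as a TLV stream would.
	var decoded mcRoute
	err := decodeMCRoute(&b, &decoded, &buf, uint64(b.Len()))
	if err != nil {
		return nil, err
	}

	return &decoded, nil
}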
// mcHops is a list of mcHop records.
type mcHops []*mcHop
// Record returns a TLV record that can be used to encode/decode a list of
// mcHop to/from a TLV stream.
func (h *mcHops) Record() tlv.Record {
recordSize := func() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
if err := encodeMCHops(&b, h, &buf); err != nil {
panic(err)
}
return uint64(len(b.Bytes()))
}
return tlv.MakeDynamicRecord(
0, h, recordSize, encodeMCHops, decodeMCHops,
)
}
func encodeMCHops(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*mcHops); ok {
// Encode the number of hops as a var int.
if err := tlv.WriteVarInt(w, uint64(len(*v)), buf); err != nil {
return err
}
// With that written out, we'll now encode the entries
// themselves as a sub-TLV record, which includes its _own_
// inner length prefix.
for _, hop := range *v {
var hopBytes bytes.Buffer
if err := serializeHop(&hopBytes, hop); err != nil {
return err
}
// We encode the record with a varint length followed by
// the _raw_ TLV bytes.
tlvLen := uint64(len(hopBytes.Bytes()))
if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
return err
}
if _, err := w.Write(hopBytes.Bytes()); err != nil {
return err
}
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "routing.mcHops")
}
func decodeMCHops(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*mcHops); ok {
// First, we'll decode the varint that encodes how many hops
// are encoded in the stream.
numHops, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Now that we know how many records we'll need to read, we can
// iterate and read them all out in series.
for i := uint64(0); i < numHops; i++ {
// Read out the varint that encodes the size of this
// inner TLV record.
hopSize, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Using this information, we'll create a new limited
// reader that'll return an EOF once the end has been
// reached so the stream stops consuming bytes.
innerTlvReader := &io.LimitedReader{
R: r,
N: int64(hopSize),
}
hop, err := deserializeHop(innerTlvReader)
if err != nil {
return err
}
*v = append(*v, hop)
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "routing.mcHops", l, l)
}
// mcHop holds the bare minimum info about a payment attempt route hop that MC
// requires.
type mcHop struct {
channelID tlv.RecordT[tlv.TlvType0, uint64]
pubKeyBytes tlv.RecordT[tlv.TlvType1, route.Vertex]
amtToFwd tlv.RecordT[tlv.TlvType2, lnwire.MilliSatoshi]
hasBlindingPoint tlv.OptionalRecordT[tlv.TlvType3, lnwire.TrueBoolean]
hasCustomRecords tlv.OptionalRecordT[tlv.TlvType4, lnwire.TrueBoolean]
}
// failNode marks the node indicated by idx in the route as failed. It also
// marks the incoming and outgoing channels of the node as failed. This function
// intentionally panics when the self node is failed.
func (i *interpretedResult) failNode(rt *mcRoute, idx int) {
// Mark the node as failing.
i.nodeFailure = &rt.hops.Val[idx-1].pubKeyBytes.Val
	// Mark the incoming connection as failed for the node. We intend to
	// penalize as much as we can for a node-level failure, including
	// future outgoing traffic for this connection. The pair as it is
	// returned by getPair is penalized in the original and the reversed
	// direction. Note that this will also affect the score of the
	// failing node's peers. This is necessary to prevent future routes
	// from going through the same node again.
incomingChannelIdx := idx - 1
inPair, _ := getPair(rt, incomingChannelIdx)
i.pairResults[inPair] = failPairResult(0)
i.pairResults[inPair.Reverse()] = failPairResult(0)
// If not the ultimate node, mark the outgoing connection as failed for
// the node.
if idx < len(rt.hops.Val) {
outgoingChannelIdx := idx
outPair, _ := getPair(rt, outgoingChannelIdx)
i.pairResults[outPair] = failPairResult(0)
i.pairResults[outPair.Reverse()] = failPairResult(0)
}
}
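// The sketch below (illustrative only, not part of the original source)
// makes the indexing concrete: failing idx 2, where our own node is idx 0,
// marks the second hop as the failing node and penalizes both directions of
// the pair leading into it and, when a third hop exists, the pair leading
// out of it.
func exampleFailNode(i *interpretedResult, rt *mcRoute) {
	// Assumes rt contains at least three hops.
	i.failNode(rt, 2)
}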
// failPairRange marks the node pairs from node fromIdx to node toIdx as failed
// in both directions.
func (i *interpretedResult) failPairRange(rt *mcRoute, fromIdx, toIdx int) {
for idx := fromIdx; idx <= toIdx; idx++ {
i.failPair(rt, idx)
}
}
// failPair marks a pair as failed in both directions.
func (i *interpretedResult) failPair(rt *mcRoute, idx int) {
pair, _ := getPair(rt, idx)
// Report pair in both directions without a minimum penalization amount.
i.pairResults[pair] = failPairResult(0)
i.pairResults[pair.Reverse()] = failPairResult(0)
}
// failPairBalance marks a pair as failed with a minimum penalization amount.
func (i *interpretedResult) failPairBalance(rt *mcRoute, channelIdx int) {
pair, amt := getPair(rt, channelIdx)
i.pairResults[pair] = failPairResult(amt)
}
// successPairRange marks the node pairs from node fromIdx to node toIdx as
// succeeded.
func (i *interpretedResult) successPairRange(rt *mcRoute, fromIdx, toIdx int) {
for idx := fromIdx; idx <= toIdx; idx++ {
pair, amt := getPair(rt, idx)
i.pairResults[pair] = successPairResult(amt)
}
}
// failBlindedRoute marks a blinded route as failed for the specific amount
// sent by punishing only the last pair.
func (i *interpretedResult) failBlindedRoute(rt *mcRoute) {
// We fail the last pair of the route, in order to fail the complete
// blinded route. This is because the combination of ephemeral pubkeys
// is unique to the route. We fail the last pair in order to not punish
// the introduction node, since we don't want to disincentivize them
// from providing that service.
pair, _ := getPair(rt, len(rt.hops.Val)-1)
	// Since the hops along a blinded path don't have an amount set, we
	// derive the minimum amount to punish from the value that we
	// attempted to deliver to the receiver.
amt := rt.hops.Val[len(rt.hops.Val)-1].amtToFwd.Val
i.pairResults[pair] = failPairResult(amt)
}
// markBlindedRouteSuccess marks the hops of the blinded route AFTER the
// introduction node as successful.
//
// NOTE: The introIdx must be the index of the first hop of the blinded route
// AFTER the introduction node.
func (i *interpretedResult) markBlindedRouteSuccess(rt *mcRoute, introIdx int) {
	// For blinded hops we do not have the forwarding amount, so we take
	// the minimum amount which went through the route by looking at the
	// last hop.
successAmt := rt.hops.Val[len(rt.hops.Val)-1].amtToFwd.Val
for idx := introIdx; idx < len(rt.hops.Val); idx++ {
pair, _ := getPair(rt, idx)
i.pairResults[pair] = successPairResult(successAmt)
}
}
// getPair returns a node pair from the route and the amount passed between that
// pair.
func getPair(rt *mcRoute, channelIdx int) (DirectedNodePair,
lnwire.MilliSatoshi) {
nodeTo := rt.hops.Val[channelIdx].pubKeyBytes.Val
var (
nodeFrom route.Vertex
amt lnwire.MilliSatoshi
)
if channelIdx == 0 {
nodeFrom = rt.sourcePubKey.Val
amt = rt.totalAmount.Val
} else {
nodeFrom = rt.hops.Val[channelIdx-1].pubKeyBytes.Val
amt = rt.hops.Val[channelIdx-1].amtToFwd.Val
}
pair := NewDirectedNodePair(nodeFrom, nodeTo)
return pair, amt
}
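// exampleGetPair is an illustrative sketch (an assumed helper, not part of
// the original source): channel index 0 yields a pair originating at the
// route source carrying the total route amount, while later indices use the
// preceding hop's pubkey and its amtToFwd value.
func exampleGetPair(rt *mcRoute) {
	// First channel: source -> first hop, carrying the full amount.
	firstPair, firstAmt := getPair(rt, 0)
	log.Debugf("first pair %v carries %v", firstPair, firstAmt)

	// Last channel: second-to-last hop -> final hop, carrying the amount
	// forwarded by the second-to-last hop.
	lastIdx := len(rt.hops.Val) - 1
	lastPair, lastAmt := getPair(rt, lastIdx)
	log.Debugf("last pair %v carries %v", lastPair, lastAmt)
}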
package route
import (
"bytes"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"strconv"
"strings"
"github.com/btcsuite/btcd/btcec/v2"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
)
// VertexSize is the size of the array to store a vertex.
const VertexSize = 33
var (
// ErrNoRouteHopsProvided is returned when a caller attempts to
// construct a new sphinx packet, but provides an empty set of hops for
// each route.
ErrNoRouteHopsProvided = fmt.Errorf("empty route hops provided")
// ErrMaxRouteHopsExceeded is returned when a caller attempts to
// construct a new sphinx packet, but provides too many hops.
ErrMaxRouteHopsExceeded = fmt.Errorf("route has too many hops")
	// ErrIntermediateMPPHop is returned when a hop tries to deliver an MPP
	// record to an intermediate hop; only final hops can receive MPP
	// records.
ErrIntermediateMPPHop = errors.New("cannot send MPP to intermediate")
// ErrAMPMissingMPP is returned when the caller tries to attach an AMP
// record but no MPP record is presented for the final hop.
ErrAMPMissingMPP = errors.New("cannot send AMP without MPP record")
// ErrMissingField is returned if a required TLV is missing.
ErrMissingField = errors.New("required tlv missing")
// ErrUnexpectedField is returned if a tlv field is included when it
// should not be.
ErrUnexpectedField = errors.New("unexpected tlv included")
)
// Vertex is a simple alias for the serialization of a compressed Bitcoin
// public key.
type Vertex [VertexSize]byte
// NewVertex returns a new Vertex given a public key.
func NewVertex(pub *btcec.PublicKey) Vertex {
var v Vertex
copy(v[:], pub.SerializeCompressed())
return v
}
// NewVertexFromBytes returns a new Vertex based on a serialized pubkey in a
// byte slice.
func NewVertexFromBytes(b []byte) (Vertex, error) {
vertexLen := len(b)
if vertexLen != VertexSize {
return Vertex{}, fmt.Errorf("invalid vertex length of %v, "+
"want %v", vertexLen, VertexSize)
}
var v Vertex
copy(v[:], b)
return v, nil
}
// NewVertexFromStr returns a new Vertex given its hex-encoded string format.
func NewVertexFromStr(v string) (Vertex, error) {
// Return error if hex string is of incorrect length.
if len(v) != VertexSize*2 {
return Vertex{}, fmt.Errorf("invalid vertex string length of "+
"%v, want %v", len(v), VertexSize*2)
}
vertex, err := hex.DecodeString(v)
if err != nil {
return Vertex{}, err
}
return NewVertexFromBytes(vertex)
}
// String returns a human readable version of the Vertex which is the
// hex-encoding of the serialized compressed public key.
func (v Vertex) String() string {
return fmt.Sprintf("%x", v[:])
}
// Record returns a TLV record that can be used to encode/decode a Vertex
// to/from a TLV stream.
func (v *Vertex) Record() tlv.Record {
return tlv.MakeStaticRecord(
0, v, VertexSize, encodeVertex, decodeVertex,
)
}
func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error {
if b, ok := val.(*Vertex); ok {
_, err := w.Write(b[:])
return err
}
return tlv.NewTypeForEncodingErr(val, "Vertex")
}
func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error {
if b, ok := val.(*Vertex); ok {
_, err := io.ReadFull(r, b[:])
return err
}
return tlv.NewTypeForDecodingErr(val, "Vertex", l, VertexSize)
}
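// exampleVertexRoundTrip is an illustrative sketch (an assumed helper, not
// part of the original source) showing the constructors above agreeing with
// one another: a Vertex built from a public key and one rebuilt from its
// hex string are identical.
func exampleVertexRoundTrip(pub *btcec.PublicKey) (Vertex, error) {
	v1 := NewVertex(pub)

	// String hex-encodes the 33 compressed pubkey bytes, which
	// NewVertexFromStr decodes back into a Vertex.
	v2, err := NewVertexFromStr(v1.String())
	if err != nil {
		return Vertex{}, err
	}

	if v1 != v2 {
		return Vertex{}, fmt.Errorf("vertex mismatch: %v vs %v",
			v1, v2)
	}

	return v1, nil
}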
// Hop represents an intermediate or final node of the route. This naming
// is in line with the definition given in BOLT #4: Onion Routing Protocol.
// The struct houses the channel along which this hop can be reached and
// the values necessary to create the HTLC that needs to be sent to the
// next hop. It is also used to encode the per-hop payload included within
// the Sphinx packet.
type Hop struct {
// PubKeyBytes is the raw bytes of the public key of the target node.
PubKeyBytes Vertex
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel.
ChannelID uint64
// OutgoingTimeLock is the timelock value that should be used when
// crafting the _outgoing_ HTLC from this hop.
OutgoingTimeLock uint32
// AmtToForward is the amount that this hop will forward to the next
// hop. This value is less than the value that the incoming HTLC
// carries as a fee will be subtracted by the hop.
AmtToForward lnwire.MilliSatoshi
// MPP encapsulates the data required for option_mpp. This field should
// only be set for the final hop.
MPP *record.MPP
// AMP encapsulates the data required for option_amp. This field should
// only be set for the final hop.
AMP *record.AMP
// CustomRecords if non-nil are a set of additional TLV records that
// should be included in the forwarding instructions for this node.
CustomRecords record.CustomSet
// LegacyPayload if true, then this signals that this node doesn't
// understand the new TLV payload, so we must instead use the legacy
// payload.
//
// NOTE: we should no longer ever create a Hop with Legacy set to true.
// The only reason we are keeping this member is that it could be the
// case that we have serialised hops persisted to disk where
// LegacyPayload is true.
LegacyPayload bool
// Metadata is additional data that is sent along with the payment to
// the payee.
Metadata []byte
	// EncryptedData is an encrypted data blob included for hops that are
	// part of a blinded route.
EncryptedData []byte
// BlindingPoint is an ephemeral public key used by introduction nodes
// in blinded routes to unblind their portion of the route and pass on
// the next ephemeral key to the next blinded node to do the same.
BlindingPoint *btcec.PublicKey
// TotalAmtMsat is the total amount for a blinded payment, potentially
// spread over more than one HTLC. This field should only be set for
// the final hop in a blinded path.
TotalAmtMsat lnwire.MilliSatoshi
}
// Copy returns a deep copy of the Hop.
func (h *Hop) Copy() *Hop {
c := *h
if h.MPP != nil {
m := *h.MPP
c.MPP = &m
}
if h.AMP != nil {
a := *h.AMP
c.AMP = &a
}
if h.BlindingPoint != nil {
b := *h.BlindingPoint
c.BlindingPoint = &b
}
return &c
}
// PackHopPayload writes to the passed io.Writer the series of bytes that can
// be placed directly into the per-hop payload (EOB) for this hop. This will
// include the required routing fields, as well as serializing any of the
// passed optional TLVRecords. nextChanID is the unique channel ID that
// references the _outgoing_ channel ID that follows this hop. The finalHop
// bool is used to signal whether this hop is the final hop in a route.
// Previously, a zero nextChanID would be used for this purpose, but with the
// addition of blinded routes, which allow zero nextChanID values for
// intermediate hops, we add an explicit signal.
func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64,
finalHop bool) error {
// If this is a legacy payload, then we'll exit here as this method
// shouldn't be called.
if h.LegacyPayload {
return fmt.Errorf("cannot pack hop payloads for legacy " +
"payloads")
}
// Otherwise, we'll need to make a new stream that includes our
// required routing fields, as well as these optional values.
var records []tlv.Record
// Hops that are not part of a blinded path will have an amount and
// a CLTV expiry field. In a blinded route (where encrypted data is
// non-nil), these values may be omitted for intermediate nodes.
// Validate these fields against the structure of the payload so that
// we know they're included (or excluded) correctly.
isBlinded := h.EncryptedData != nil
if err := optionalBlindedField(
h.AmtToForward == 0, isBlinded, finalHop,
); err != nil {
return fmt.Errorf("%w: amount to forward: %v", err,
h.AmtToForward)
}
if err := optionalBlindedField(
h.OutgoingTimeLock == 0, isBlinded, finalHop,
); err != nil {
return fmt.Errorf("%w: outgoing timelock: %v", err,
h.OutgoingTimeLock)
}
// Once we've validated that these TLVs are set as we expect, we can
// go ahead and include them if non-zero.
amt := uint64(h.AmtToForward)
if amt != 0 {
records = append(
records, record.NewAmtToFwdRecord(&amt),
)
}
if h.OutgoingTimeLock != 0 {
records = append(
records, record.NewLockTimeRecord(&h.OutgoingTimeLock),
)
}
// Validate channel TLV is present as expected based on location in
// route and whether this hop is blinded.
err := validateNextChanID(nextChanID != 0, isBlinded, finalHop)
if err != nil {
return fmt.Errorf("%w: channel id: %v", err, nextChanID)
}
if nextChanID != 0 {
records = append(records,
record.NewNextHopIDRecord(&nextChanID),
)
}
// If an MPP record is destined for this hop, ensure that we only ever
// attach it to the final hop. Otherwise the route was constructed
// incorrectly.
if h.MPP != nil {
if finalHop {
records = append(records, h.MPP.Record())
} else {
return ErrIntermediateMPPHop
}
}
// Add encrypted data and blinding point if present.
if h.EncryptedData != nil {
records = append(records, record.NewEncryptedDataRecord(
&h.EncryptedData,
))
}
if h.BlindingPoint != nil {
records = append(records, record.NewBlindingPointRecord(
&h.BlindingPoint,
))
}
// If an AMP record is destined for this hop, ensure that we only ever
// attach it if we also have an MPP record. We can infer that this is
// already a final hop if MPP is non-nil otherwise we would have exited
// above.
if h.AMP != nil {
if h.MPP != nil {
records = append(records, h.AMP.Record())
} else {
return ErrAMPMissingMPP
}
}
// If metadata is specified, generate a tlv record for it.
if h.Metadata != nil {
records = append(records,
record.NewMetadataRecord(&h.Metadata),
)
}
if h.TotalAmtMsat != 0 {
totalAmtInt := uint64(h.TotalAmtMsat)
records = append(records,
record.NewTotalAmtMsatBlinded(&totalAmtInt),
)
}
// Append any custom types destined for this hop.
tlvRecords := tlv.MapToRecords(h.CustomRecords)
records = append(records, tlvRecords...)
// To ensure we produce a canonical stream, we'll sort the records
// before encoding them as a stream in the hop payload.
tlv.SortRecords(records)
tlvStream, err := tlv.NewStream(records...)
if err != nil {
return err
}
return tlvStream.Encode(w)
}
// optionalBlindedField validates fields that we expect to be non-zero for all
// hops in a regular route, but may be zero for intermediate nodes in a blinded
// route. It will validate the following cases:
// - Not blinded: require non-zero values.
// - Intermediate blinded node: require zero values.
// - Final blinded node: require non-zero values.
func optionalBlindedField(isZero, blindedHop, finalHop bool) error {
switch {
// We are not in a blinded route and the TLV is not set when it should
// be.
case !blindedHop && isZero:
return ErrMissingField
// We are not in a blinded route and the TLV is set as expected.
case !blindedHop:
return nil
// In a blinded route the final hop is expected to have TLV values set.
case finalHop && isZero:
return ErrMissingField
	// We are at an intermediate hop in a blinded route and the field is
	// not zero.
case !finalHop && !isZero:
return ErrUnexpectedField
}
return nil
}
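// The sketch below (illustrative only, not part of the original source)
// exercises each documented case.
func exampleOptionalBlindedField() {
	// Regular route with a zero value: the field is required, so this
	// returns ErrMissingField.
	_ = optionalBlindedField(true, false, false)

	// Intermediate blinded hop with a non-zero value: the field should
	// have been omitted, so this returns ErrUnexpectedField.
	_ = optionalBlindedField(false, true, false)

	// Final blinded hop with a non-zero value: valid, returns nil.
	_ = optionalBlindedField(false, true, true)
}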
// validateNextChanID validates the presence of the nextChanID TLV field in
// a payload. For regular payments, it is expected to be present for all hops
// except the final hop. For blinded paths, it is not expected to be included
// at all (as this value is provided in encrypted data).
func validateNextChanID(nextChanIDIsSet, isBlinded, finalHop bool) error {
switch {
// Hops in a blinded route should not have a next channel ID set.
case isBlinded && nextChanIDIsSet:
return ErrUnexpectedField
// Otherwise, blinded hops are allowed to have a zero value.
case isBlinded:
return nil
// The final hop in a regular route is expected to have a zero value.
case finalHop && nextChanIDIsSet:
return ErrUnexpectedField
// Intermediate hops in regular routes require non-zero value.
case !finalHop && !nextChanIDIsSet:
return ErrMissingField
default:
return nil
}
}
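// A companion sketch (illustrative only, not part of the original source)
// for the nextChanID validation cases.
func exampleValidateNextChanID() {
	// A blinded hop carrying a next channel ID returns
	// ErrUnexpectedField.
	_ = validateNextChanID(true, true, false)

	// An intermediate hop of a regular route missing one returns
	// ErrMissingField.
	_ = validateNextChanID(false, false, false)

	// The final hop of a regular route with a zero value is valid.
	_ = validateNextChanID(false, false, true)
}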
// PayloadSize returns the total size this hop's payload would take up in the
// onion packet.
func (h *Hop) PayloadSize(nextChanID uint64) uint64 {
if h.LegacyPayload {
return sphinx.LegacyHopDataSize
}
var payloadSize uint64
addRecord := func(tlvType tlv.Type, length uint64) {
payloadSize += tlv.VarIntSize(uint64(tlvType)) +
tlv.VarIntSize(length) + length
}
// Add amount size.
if h.AmtToForward != 0 {
addRecord(record.AmtOnionType, tlv.SizeTUint64(
uint64(h.AmtToForward),
))
}
// Add lock time size.
if h.OutgoingTimeLock != 0 {
addRecord(
record.LockTimeOnionType,
tlv.SizeTUint64(uint64(h.OutgoingTimeLock)),
)
}
// Add next hop if present.
if nextChanID != 0 {
addRecord(record.NextHopOnionType, 8)
}
// Add mpp if present.
if h.MPP != nil {
addRecord(record.MPPOnionType, h.MPP.PayloadSize())
}
// Add amp if present.
if h.AMP != nil {
addRecord(record.AMPOnionType, h.AMP.PayloadSize())
}
// Add encrypted data and blinding point if present.
if h.EncryptedData != nil {
addRecord(
record.EncryptedDataOnionType,
uint64(len(h.EncryptedData)),
)
}
if h.BlindingPoint != nil {
addRecord(
record.BlindingPointOnionType,
btcec.PubKeyBytesLenCompressed,
)
}
// Add metadata if present.
if h.Metadata != nil {
addRecord(record.MetadataOnionType, uint64(len(h.Metadata)))
}
if h.TotalAmtMsat != 0 {
addRecord(
record.TotalAmtMsatBlindedType,
			tlv.SizeTUint64(uint64(h.TotalAmtMsat)),
)
}
// Add custom records.
for k, v := range h.CustomRecords {
addRecord(tlv.Type(k), uint64(len(v)))
}
// Add the size required to encode the payload length.
payloadSize += tlv.VarIntSize(payloadSize)
// Add HMAC.
payloadSize += sphinx.HMACSize
return payloadSize
}
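// exampleHopPayloadSize is an illustrative sketch (an assumed helper, not
// part of the original source) relating PayloadSize to PackHopPayload for
// TLV (non-legacy) hops: the predicted size equals the encoded payload
// length, plus the varint that encodes that length, plus the per-hop HMAC.
func exampleHopPayloadSize(h *Hop, nextChanID uint64,
	finalHop bool) (bool, error) {

	var b bytes.Buffer
	if err := h.PackHopPayload(&b, nextChanID, finalHop); err != nil {
		return false, err
	}

	encodedLen := uint64(b.Len())
	predicted := h.PayloadSize(nextChanID)

	return predicted == encodedLen+tlv.VarIntSize(encodedLen)+
		sphinx.HMACSize, nil
}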
// Route represents a path through the channel graph which runs over one or
// more channels in succession. This struct carries all the information
// required to craft the Sphinx onion packet, and send the payment along the
// first hop in the path. A route is only selected as valid if all the channels
// have sufficient capacity to carry the initial payment amount after fees are
// accounted for.
type Route struct {
// TotalTimeLock is the cumulative (final) time lock across the entire
// route. This is the CLTV value that should be extended to the first
// hop in the route. All other hops will decrement the time-lock as
// advertised, leaving enough time for all hops to wait for or present
// the payment preimage to complete the payment.
TotalTimeLock uint32
// TotalAmount is the total amount of funds required to complete a
// payment over this route. This value includes the cumulative fees at
// each hop. As a result, the HTLC extended to the first-hop in the
// route will need to have at least this many satoshis, otherwise the
// route will fail at an intermediate node due to an insufficient
// amount of fees.
TotalAmount lnwire.MilliSatoshi
// SourcePubKey is the pubkey of the node where this route originates
// from.
SourcePubKey Vertex
	// Hops contains the specific forwarding details at each hop.
Hops []*Hop
// FirstHopAmount is the amount that should actually be sent to the
// first hop in the route. This is only different from TotalAmount above
// for custom channels where the on-chain amount doesn't necessarily
// reflect all the value of an outgoing payment.
FirstHopAmount tlv.RecordT[
tlv.TlvType0, tlv.BigSizeT[lnwire.MilliSatoshi],
]
// FirstHopWireCustomRecords is a set of custom records that should be
// included in the wire message sent to the first hop. This is only set
// on custom channels and is used to include additional information
// about the actual value of the payment.
//
// NOTE: Since these records already represent TLV records, and we
// enforce them to be in the custom range (e.g. >= 65536), we don't use
// another parent record type here. Instead, when serializing the Route
// we merge the TLV records together with the custom records and encode
// everything as a single TLV stream.
FirstHopWireCustomRecords lnwire.CustomRecords
}
// Copy returns a deep copy of the Route.
func (r *Route) Copy() *Route {
c := *r
c.Hops = make([]*Hop, len(r.Hops))
for i := range r.Hops {
c.Hops[i] = r.Hops[i].Copy()
}
return &c
}
// HopFee returns the fee charged by the route hop indicated by hopIndex.
//
// This calculation takes into account the possibility that the route contains
// some blinded hops, which will not have the amount to forward set. We take
// note of various points in the blinded route.
//
// Given the following route where Carol is the introduction node and B2 is
// the recipient, Carol and B1's hops will not have an amount to forward set:
// Alice --- Bob ---- Carol (introduction) ----- B1 ----- B2
//
// We locate ourselves in the route as follows:
// * Regular Hop (eg Alice - Bob):
//
// incomingAmt !=0
// outgoingAmt !=0
// -> Fee = incomingAmt - outgoingAmt
//
// * Introduction Hop (eg Bob - Carol):
//
// incomingAmt !=0
// outgoingAmt = 0
// -> Fee = incomingAmt - receiverAmt
//
// This has the impact of attributing the full fees for the blinded route to
// the introduction node.
//
// * Blinded Intermediate Hop (eg Carol - B1):
//
// incomingAmt = 0
// outgoingAmt = 0
// -> Fee = 0
//
// * Final Blinded Hop (B1 - B2):
//
// incomingAmt = 0
// outgoingAmt !=0
// -> Fee = 0
func (r *Route) HopFee(hopIndex int) lnwire.MilliSatoshi {
var incomingAmt lnwire.MilliSatoshi
if hopIndex == 0 {
incomingAmt = r.TotalAmount
} else {
incomingAmt = r.Hops[hopIndex-1].AmtToForward
}
outgoingAmt := r.Hops[hopIndex].AmtToForward
switch {
// If both incoming and outgoing amounts are set, we're in a normal
// hop
case incomingAmt != 0 && outgoingAmt != 0:
return incomingAmt - outgoingAmt
// If the incoming amount is zero, we're at an intermediate hop in
// a blinded route, so the fee is zero.
case incomingAmt == 0:
return 0
// If we have a non-zero incoming amount and a zero outgoing amount,
// we're at the introduction hop so we express the fees for the full
// blinded route at this hop.
default:
return incomingAmt - r.ReceiverAmt()
}
}
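// exampleHopFee is an illustrative sketch with assumed amounts (not part of
// the original source) that applies HopFee to the documented route
// Alice --- Bob ---- Carol (introduction) ----- B1 ----- B2, as paid by
// Alice. Carol and B1 have no amount to forward set, so the fee for the
// whole blinded portion lands on Carol's hop.
func exampleHopFee() {
	r := &Route{
		// The total amount leaving Alice, fees included.
		TotalAmount: 1_050,
		Hops: []*Hop{
			{AmtToForward: 1_030}, // Bob.
			{AmtToForward: 0},     // Carol (introduction).
			{AmtToForward: 0},     // B1 (blinded).
			{AmtToForward: 1_000}, // B2 (recipient).
		},
	}

	// Bob: 1050-1030 = 20. Carol: 1030-1000 = 30, the fee for the full
	// blinded route. B1 and B2: 0.
	for i := range r.Hops {
		fmt.Printf("hop %d fee: %v\n", i, r.HopFee(i))
	}
}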
// TotalFees is the sum of the fees paid at each hop within the final route. In
// the case of a one-hop payment, this value will be zero as we don't need to
// pay a fee to ourself.
func (r *Route) TotalFees() lnwire.MilliSatoshi {
if len(r.Hops) == 0 {
return 0
}
return r.TotalAmount - r.ReceiverAmt()
}
// ReceiverAmt is the amount received by the final hop of this route.
func (r *Route) ReceiverAmt() lnwire.MilliSatoshi {
if len(r.Hops) == 0 {
return 0
}
return r.Hops[len(r.Hops)-1].AmtToForward
}
// FinalHop returns the last hop of the route, or nil if the route is empty.
func (r *Route) FinalHop() *Hop {
if len(r.Hops) == 0 {
return nil
}
return r.Hops[len(r.Hops)-1]
}
// NewRouteFromHops creates a new Route structure from the minimally required
// information to perform the payment. It infers fee amounts and populates the
// node, chan and prev/next hop maps.
func NewRouteFromHops(amtToSend lnwire.MilliSatoshi, timeLock uint32,
sourceVertex Vertex, hops []*Hop) (*Route, error) {
if len(hops) == 0 {
return nil, ErrNoRouteHopsProvided
}
// First, we'll create a route struct and populate it with the fields
// for which the values are provided as arguments of this function.
	// TotalFees is determined based on the difference between the amount
	// that is sent from the source and the final amount that is received
	// by the destination.
route := &Route{
SourcePubKey: sourceVertex,
Hops: hops,
TotalTimeLock: timeLock,
TotalAmount: amtToSend,
}
return route, nil
}
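// A minimal usage sketch (illustrative only, with assumed placeholder
// values): build a two-hop route and read the derived fee figures from it.
func exampleNewRouteFromHops(src, hop1, dest Vertex) error {
	r, err := NewRouteFromHops(
		1_001_000, // Amount sent, fees included.
		144,       // Total absolute timelock.
		src,
		[]*Hop{
			{PubKeyBytes: hop1, AmtToForward: 1_000_500},
			{PubKeyBytes: dest, AmtToForward: 1_000_000},
		},
	)
	if err != nil {
		return err
	}

	// TotalFees: 1_001_000 - 1_000_000 = 1_000 msat.
	fmt.Printf("fees: %v, receiver amount: %v\n", r.TotalFees(),
		r.ReceiverAmt())

	return nil
}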
// ToSphinxPath converts a complete route into a sphinx PaymentPath that
// contains the per-hop payloads used to encode the HTLC routing data for each
// hop in the route. This method also accepts an optional EOB payload for the
// final hop.
func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) {
var path sphinx.PaymentPath
// We can only construct a route if there are hops provided.
if len(r.Hops) == 0 {
return nil, ErrNoRouteHopsProvided
}
// Check maximum route length.
if len(r.Hops) > sphinx.NumMaxHops {
return nil, ErrMaxRouteHopsExceeded
}
// For each hop encoded within the route, we'll convert the hop struct
// to an OnionHop with matching per-hop payload within the path as used
// by the sphinx package.
for i, hop := range r.Hops {
pub, err := btcec.ParsePubKey(hop.PubKeyBytes[:])
if err != nil {
return nil, err
}
// As a base case, the next hop is set to all zeroes in order
		// to indicate that the "last hop" has no further hops after it.
nextHop := uint64(0)
// If we aren't on the last hop, then we set the "next address"
// field to be the channel that directly follows it.
finalHop := i == len(r.Hops)-1
if !finalHop {
nextHop = r.Hops[i+1].ChannelID
}
var payload sphinx.HopPayload
// If this is the legacy payload, then we can just include the
// hop data as normal.
if hop.LegacyPayload {
// Before we encode this value, we'll pack the next hop
// into the NextAddress field of the hop info to ensure
			// we point to the right node.
hopData := sphinx.HopData{
ForwardAmount: uint64(hop.AmtToForward),
OutgoingCltv: hop.OutgoingTimeLock,
}
binary.BigEndian.PutUint64(
hopData.NextAddress[:], nextHop,
)
payload, err = sphinx.NewLegacyHopPayload(&hopData)
if err != nil {
return nil, err
}
} else {
// For non-legacy payloads, we'll need to pack the
// routing information, along with any extra TLV
// information into the new per-hop payload format.
// We'll also pass in the chan ID of the hop this
// channel should be forwarded to so we can construct a
// valid payload.
var b bytes.Buffer
err := hop.PackHopPayload(&b, nextHop, finalHop)
if err != nil {
return nil, err
}
payload, err = sphinx.NewTLVHopPayload(b.Bytes())
if err != nil {
return nil, err
}
}
path[i] = sphinx.OnionHop{
NodePub: *pub,
HopPayload: payload,
}
}
return &path, nil
}
// String returns a human readable representation of the route.
func (r *Route) String() string {
var b strings.Builder
amt := r.TotalAmount
for i, hop := range r.Hops {
if i > 0 {
b.WriteString(" -> ")
}
b.WriteString(fmt.Sprintf("%v (%v)",
strconv.FormatUint(hop.ChannelID, 10),
amt,
))
amt = hop.AmtToForward
}
return fmt.Sprintf("%v, cltv %v",
b.String(), r.TotalTimeLock,
)
}
package routing
import (
"context"
"fmt"
"math"
"math/big"
"sort"
"sync"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/amp"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/zpay32"
)
const (
// DefaultPayAttemptTimeout is the default payment attempt timeout. The
// payment attempt timeout defines the duration after which we stop
// trying more routes for a payment.
DefaultPayAttemptTimeout = time.Second * 60
// MinCLTVDelta is the minimum CLTV value accepted by LND for all
// timelock deltas. This includes both forwarding CLTV deltas set on
// channel updates, as well as final CLTV deltas used to create BOLT 11
// payment requests.
//
// NOTE: For payment requests, BOLT 11 stipulates that a final CLTV
// delta of 9 should be used when no value is decoded. This however
// leads to inflexibility in upgrading this default parameter, since it
// can create inconsistencies around the assumed value between sender
// and receiver. Specifically, if the receiver assumes a higher value
// than the sender, the receiver will always see the received HTLCs as
// invalid due to their timelock not meeting the required delta.
//
// We skirt this by always setting an explicit CLTV delta when creating
// invoices. This allows LND nodes to freely update the minimum without
// creating incompatibilities during the upgrade process. For some time
// LND has used an explicit default final CLTV delta of 40 blocks for
// bitcoin, though we now clamp the lower end of this
// range for user-chosen deltas to 18 blocks to be conservative.
MinCLTVDelta = 18
// MaxCLTVDelta is the maximum CLTV value accepted by LND for all
// timelock deltas.
MaxCLTVDelta = math.MaxUint16
)
var (
// ErrRouterShuttingDown is returned if the router is in the process of
// shutting down.
ErrRouterShuttingDown = fmt.Errorf("router shutting down")
// ErrSelfIntro is a failure returned when the source node of a
// route request is also the introduction node. This is not yet
	// supported because LND does not support blinded forwarding.
ErrSelfIntro = errors.New("introduction point as own node not " +
"supported")
// ErrHintsAndBlinded is returned if a route request has both
// bolt 11 route hints and a blinded path set.
ErrHintsAndBlinded = errors.New("bolt 11 route hints and blinded " +
"paths are mutually exclusive")
// ErrExpiryAndBlinded is returned if a final cltv and a blinded path
// are provided, as the cltv should be provided within the blinded
// path.
ErrExpiryAndBlinded = errors.New("final cltv delta and blinded " +
"paths are mutually exclusive")
	// ErrTargetAndBlinded is returned if a target destination and a
// blinded path are both set (as the target is inferred from the
// blinded path).
ErrTargetAndBlinded = errors.New("target node and blinded paths " +
"are mutually exclusive")
// ErrNoTarget is returned when the target node for a route is not
// provided by either a blinded route or a cleartext pubkey.
ErrNoTarget = errors.New("destination not set in target or blinded " +
"path")
	// ErrSkipTempErr is returned when a non-MPP payment is made yet the
	// skipTempErr flag is set.
ErrSkipTempErr = errors.New("cannot skip temp error for non-MPP")
)
// PaymentAttemptDispatcher is used by the router to send payment attempts onto
// the network, and receive their results.
type PaymentAttemptDispatcher interface {
// SendHTLC is a function that directs a link-layer switch to
// forward a fully encoded payment to the first hop in the route
// denoted by its public key. A non-nil error is to be returned if the
// payment was unsuccessful.
SendHTLC(firstHop lnwire.ShortChannelID,
attemptID uint64,
htlcAdd *lnwire.UpdateAddHTLC) error
// GetAttemptResult returns the result of the payment attempt with
// the given attemptID. The paymentHash should be set to the payment's
// overall hash, or in case of AMP payments the payment's unique
// identifier.
//
// The method returns a channel where the payment result will be sent
// when available, or an error is encountered during forwarding. When a
// result is received on the channel, the HTLC is guaranteed to no
// longer be in flight. The switch shutting down is signaled by
// closing the channel. If the attemptID is unknown,
// ErrPaymentIDNotFound will be returned.
GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash,
deobfuscator htlcswitch.ErrorDecrypter) (
<-chan *htlcswitch.PaymentResult, error)
// CleanStore calls the underlying result store, telling it is safe to
// delete all entries except the ones in the keepPids map. This should
// be called periodically to let the switch clean up payment results
// that we have handled.
// NOTE: New payment attempts MUST NOT be made after the keepPids map
// has been created and this method has returned.
CleanStore(keepPids map[uint64]struct{}) error
// HasAttemptResult reads the network result store to fetch the
// specified attempt. Returns true if the attempt result exists.
//
// NOTE: This method is used and should only be used by the router to
// resume payments during startup. It can be viewed as a subset of
// `GetAttemptResult` in terms of its functionality, and can be removed
// once we move the construction of `UpdateAddHTLC` and
// `ErrorDecrypter` into `htlcswitch`.
HasAttemptResult(attemptID uint64) (bool, error)
}
// PaymentSessionSource is an interface that defines a source for the router to
// retrieve new payment sessions.
type PaymentSessionSource interface {
// NewPaymentSession creates a new payment session that will produce
// routes to the given target. An optional set of routing hints can be
// provided in order to populate additional edges to explore when
// finding a path to the payment's destination.
NewPaymentSession(p *LightningPayment,
firstHopBlob fn.Option[tlv.Blob],
ts fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession,
error)
// NewPaymentSessionEmpty creates a new paymentSession instance that is
// empty, and will be exhausted immediately. Used for failure reporting
// to missioncontrol for resumed payment we don't want to make more
// attempts for.
NewPaymentSessionEmpty() PaymentSession
}
// MissionControlQuerier is an interface that exposes failure reporting and
// probability estimation.
type MissionControlQuerier interface {
// ReportPaymentFail reports a failed payment to mission control as
// input for future probability estimates. It returns a bool indicating
// whether this error is a final error and no further payment attempts
// need to be made.
ReportPaymentFail(attemptID uint64, rt *route.Route,
failureSourceIdx *int, failure lnwire.FailureMessage) (
*channeldb.FailureReason, error)
// ReportPaymentSuccess reports a successful payment to mission control
// as input for future probability estimates.
ReportPaymentSuccess(attemptID uint64, rt *route.Route) error
// GetProbability is expected to return the success probability of a
// payment from fromNode along edge.
GetProbability(fromNode, toNode route.Vertex,
amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64
}
// FeeSchema is the set fee configuration for a Lightning Node on the network.
// Using the coefficients described within the schema, the required fee to
// forward outgoing payments can be derived.
type FeeSchema struct {
	// BaseFee is the base amount of milli-satoshis that will be charged
	// for ANY payment forwarded.
BaseFee lnwire.MilliSatoshi
// FeeRate is the rate that will be charged for forwarding payments.
// This value should be interpreted as the numerator for a fraction
// (fixed point arithmetic) whose denominator is 1 million. As a result
// the effective fee rate charged per mSAT will be: (amount *
// FeeRate/1,000,000).
FeeRate uint32
// InboundFee is the inbound fee schedule that applies to forwards
// coming in through a channel to which this FeeSchema pertains.
InboundFee fn.Option[models.InboundFee]
}
// ChannelPolicy holds the parameters that determine the policy we enforce
// when forwarding payments on a channel. These parameters are communicated
// to the rest of the network in ChannelUpdate messages.
type ChannelPolicy struct {
// FeeSchema holds the fee configuration for a channel.
FeeSchema
// TimeLockDelta is the required HTLC timelock delta to be used
// when forwarding payments.
TimeLockDelta uint32
// MaxHTLC is the maximum HTLC size including fees we are allowed to
// forward over this channel.
MaxHTLC lnwire.MilliSatoshi
// MinHTLC is the minimum HTLC size including fees we are allowed to
// forward over this channel.
MinHTLC *lnwire.MilliSatoshi
}
// Config defines the configuration for the ChannelRouter. ALL elements within
// the configuration MUST be non-nil for the ChannelRouter to carry out its
// duties.
type Config struct {
// SelfNode is the public key of the node that this channel router
// belongs to.
SelfNode route.Vertex
// RoutingGraph is a graph source that will be used for pathfinding.
RoutingGraph Graph
// Chain is the router's source to the most up-to-date blockchain data.
// All incoming advertised channels will be checked against the chain
// to ensure that the channels advertised are still open.
Chain lnwallet.BlockChainIO
// Payer is an instance of a PaymentAttemptDispatcher and is used by
// the router to send payment attempts onto the network, and receive
// their results.
Payer PaymentAttemptDispatcher
// Control keeps track of the status of ongoing payments, ensuring we
// can properly resume them across restarts.
Control ControlTower
// MissionControl is a shared memory of sorts that executions of
// payment path finding use in order to remember which vertexes/edges
// were pruned from prior attempts. During SendPayment execution,
// errors sent by nodes are mapped into a vertex or edge to be pruned.
// Each run will then take into account this set of pruned
// vertexes/edges to reduce route failure and pass on graph information
// gained to the next execution.
MissionControl MissionControlQuerier
// SessionSource defines a source for the router to retrieve new payment
// sessions.
SessionSource PaymentSessionSource
// GetLink is a method that allows the router to query the lower link
// layer to determine the up-to-date available bandwidth at a
// prospective link to be traversed. If the link isn't available, then
// a value of zero should be returned. Otherwise, the current up-to-
// date knowledge of the available bandwidth of the link should be
// returned.
GetLink getLinkQuery
// NextPaymentID is a method that guarantees to return a new, unique ID
// each time it is called. This is used by the router to generate a
// unique payment ID for each payment it attempts to send, such that
// the switch can properly handle the HTLC.
NextPaymentID func() (uint64, error)
// PathFindingConfig defines global path finding parameters.
PathFindingConfig PathFindingConfig
	// Clock is a mockable time provider.
Clock clock.Clock
// ApplyChannelUpdate can be called to apply a new channel update to the
// graph that we received from a payment failure.
ApplyChannelUpdate func(msg *lnwire.ChannelUpdate1) bool
// ClosedSCIDs is used by the router to fetch closed channels.
//
// TODO(yy): remove it once the root cause of stuck payments is found.
ClosedSCIDs map[lnwire.ShortChannelID]struct{}
// TrafficShaper is an optional traffic shaper that can be used to
// control the outgoing channel of a payment.
TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper]
}
// EdgeLocator is a struct used to identify a specific edge.
type EdgeLocator struct {
// ChannelID is the channel of this edge.
ChannelID uint64
// Direction takes the value of 0 or 1 and is identical in definition to
// the channel direction flag. A value of 0 means the direction from the
// lower node pubkey to the higher.
Direction uint8
}
// String returns a human-readable version of the edgeLocator values.
func (e *EdgeLocator) String() string {
return fmt.Sprintf("%v:%v", e.ChannelID, e.Direction)
}
// ChannelRouter is the layer 3 router within the Lightning stack. Below the
// ChannelRouter is the HtlcSwitch, and below that is the Bitcoin blockchain
// itself. The primary role of the ChannelRouter is to respond to queries for
// potential routes that can support a payment amount, and also general graph
// reachability questions. The router will prune the channel graph
// automatically as new blocks are discovered which spend certain known funding
// outpoints, thereby closing their respective channels.
type ChannelRouter struct {
started uint32 // To be used atomically.
stopped uint32 // To be used atomically.
// cfg is a copy of the configuration struct that the ChannelRouter was
// initialized with.
cfg *Config
quit chan struct{}
wg sync.WaitGroup
}
// New creates a new instance of the ChannelRouter with the specified
// configuration parameters. As part of initialization, if the router detects
// that the channel graph isn't fully in sync with the latest UTXO set (since
// the channel graph is a subset of the UTXO set), then the router will
// proceed to fully sync to the latest state of the UTXO set.
func New(cfg Config) (*ChannelRouter, error) {
return &ChannelRouter{
cfg: &cfg,
quit: make(chan struct{}),
}, nil
}
// Start launches all the goroutines the ChannelRouter requires to carry out
// its duties. If the router has already been started, then this method is a
// noop.
func (r *ChannelRouter) Start() error {
if !atomic.CompareAndSwapUint32(&r.started, 0, 1) {
return nil
}
log.Info("Channel Router starting")
// If any payments are still in flight, we resume, to make sure their
// results are properly handled.
if err := r.resumePayments(); err != nil {
log.Error("Failed to resume payments during startup")
}
return nil
}
// Stop signals the ChannelRouter to gracefully halt all routines. This method
// will *block* until all goroutines have exited. If the channel router has
// already stopped then this method will return immediately.
func (r *ChannelRouter) Stop() error {
if !atomic.CompareAndSwapUint32(&r.stopped, 0, 1) {
return nil
}
log.Info("Channel Router shutting down...")
defer log.Debug("Channel Router shutdown complete")
close(r.quit)
r.wg.Wait()
return nil
}
// RouteRequest contains the parameters for a pathfinding request. It may
// describe a request to make a regular payment or one to a blinded path
// (indicated by a non-nil BlindedPayment field).
type RouteRequest struct {
// Source is the node that the path originates from.
Source route.Vertex
// Target is the node that the path terminates at. If the route
// includes a blinded path, target will be the blinded node id of the
// final hop in the blinded route.
Target route.Vertex
	// Amount is the amount in millisatoshis to be delivered to the target
// node.
Amount lnwire.MilliSatoshi
// TimePreference expresses the caller's time preference for
// pathfinding.
TimePreference float64
// Restrictions provides a set of additional Restrictions that the
// route must adhere to.
Restrictions *RestrictParams
// CustomRecords is a set of custom tlv records to include for the
// final hop.
CustomRecords record.CustomSet
// RouteHints contains an additional set of edges to include in our
// view of the graph. This may either be a set of hints for private
// channels or a "virtual" hop hint that represents a blinded route.
RouteHints RouteHints
// FinalExpiry is the cltv delta for the final hop. If paying to a
// blinded path, this value is a duplicate of the delta provided
// in blinded payment.
FinalExpiry uint16
	// BlindedPathSet contains a set of optional blinded paths and
	// parameters used to reach a target node via blinded paths. This
	// field is mutually exclusive with the Target field.
BlindedPathSet *BlindedPaymentPathSet
}
// RouteHints is an alias type for a set of route hints, with the source node
// as the map's key and the details of the hint(s) in the edge policy.
type RouteHints map[route.Vertex][]AdditionalEdge
// NewRouteRequest produces a new route request for a regular payment or one
// to a blinded route, validating that the target, routeHints and finalExpiry
// parameters are mutually exclusive with the blindedPayment parameter (which
// contains these values for blinded payments).
func NewRouteRequest(source route.Vertex, target *route.Vertex,
amount lnwire.MilliSatoshi, timePref float64,
restrictions *RestrictParams, customRecords record.CustomSet,
routeHints RouteHints, blindedPathSet *BlindedPaymentPathSet,
finalExpiry uint16) (*RouteRequest, error) {
var (
// Assume that we're starting off with a regular payment.
requestHints = routeHints
requestExpiry = finalExpiry
err error
)
if blindedPathSet != nil {
if blindedPathSet.IsIntroNode(source) {
return nil, ErrSelfIntro
}
// Check that the values for a clear path have not been set,
// as this is an ambiguous signal from the caller.
if routeHints != nil {
return nil, ErrHintsAndBlinded
}
if finalExpiry != 0 {
return nil, ErrExpiryAndBlinded
}
requestExpiry = blindedPathSet.FinalCLTVDelta()
requestHints, err = blindedPathSet.ToRouteHints()
if err != nil {
return nil, err
}
}
requestTarget, err := getTargetNode(target, blindedPathSet)
if err != nil {
return nil, err
}
return &RouteRequest{
Source: source,
Target: requestTarget,
Amount: amount,
TimePreference: timePref,
Restrictions: restrictions,
CustomRecords: customRecords,
RouteHints: requestHints,
FinalExpiry: requestExpiry,
BlindedPathSet: blindedPathSet,
}, nil
}
func getTargetNode(target *route.Vertex,
blindedPathSet *BlindedPaymentPathSet) (route.Vertex, error) {
var (
blinded = blindedPathSet != nil
targetSet = target != nil
)
switch {
case blinded && targetSet:
return route.Vertex{}, ErrTargetAndBlinded
case blinded:
return route.NewVertex(blindedPathSet.TargetPubKey()), nil
case targetSet:
return *target, nil
default:
return route.Vertex{}, ErrNoTarget
}
}
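// A brief sketch (illustrative only, not part of the original source) of
// the resolution above when no blinded path set is supplied.
func exampleGetTargetNode(target route.Vertex) {
	// A cleartext target alone resolves to that target.
	resolved, err := getTargetNode(&target, nil)
	log.Debugf("resolved=%v, err=%v", resolved, err)

	// With neither a target nor a blinded path set, ErrNoTarget is
	// returned.
	_, err = getTargetNode(nil, nil)
	log.Debugf("err=%v (expected %v)", err, ErrNoTarget)
}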
// FindRoute attempts to query the ChannelRouter for the optimum path to a
// particular target destination to which it is able to send `amt` after
// factoring in channel capacities and cumulative fees along the route.
func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64,
error) {
log.Debugf("Searching for path to %v, sending %v", req.Target,
req.Amount)
// We'll attempt to obtain a set of bandwidth hints that can help us
// eliminate certain routes early on in the path finding process.
bandwidthHints, err := newBandwidthManager(
r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink,
fn.None[tlv.Blob](), r.cfg.TrafficShaper,
)
if err != nil {
return nil, 0, err
}
// We'll fetch the current block height, so we can properly calculate
// the required HTLC time locks within the route.
_, currentHeight, err := r.cfg.Chain.GetBestBlock()
if err != nil {
return nil, 0, err
}
// Now that we know the destination is reachable within the graph, we'll
// execute our path finding algorithm.
finalHtlcExpiry := currentHeight + int32(req.FinalExpiry)
// Validate time preference.
timePref := req.TimePreference
if timePref < -1 || timePref > 1 {
return nil, 0, errors.New("time preference out of range")
}
path, probability, err := findPath(
&graphParams{
additionalEdges: req.RouteHints,
bandwidthHints: bandwidthHints,
graph: r.cfg.RoutingGraph,
},
req.Restrictions, &r.cfg.PathFindingConfig,
r.cfg.SelfNode, req.Source, req.Target, req.Amount,
req.TimePreference, finalHtlcExpiry,
)
if err != nil {
return nil, 0, err
}
// Create the route with absolute time lock values.
route, err := newRoute(
req.Source, path, uint32(currentHeight),
finalHopParams{
amt: req.Amount,
totalAmt: req.Amount,
cltvDelta: req.FinalExpiry,
records: req.CustomRecords,
}, req.BlindedPathSet,
)
if err != nil {
return nil, 0, err
}
	log.Tracef("Obtained path to send %v to %x: %v",
req.Amount, req.Target, lnutils.SpewLogClosure(route))
return route, probability, nil
}
// probabilitySource defines the signature of a function that can be used to
// query the success probability of sending a given amount between the two
// given vertices.
type probabilitySource func(route.Vertex, route.Vertex, lnwire.MilliSatoshi,
btcutil.Amount) float64
// BlindedPathRestrictions are a set of constraints to adhere to when
// choosing a set of blinded paths to this node.
type BlindedPathRestrictions struct {
// MinDistanceFromIntroNode is the minimum number of _real_ (non-dummy)
// hops to include in a blinded path. Since we post-fix dummy hops, this
// is the minimum distance between our node and the introduction node
// of the path. This doesn't include our node, so if the minimum is 1,
// then the path will contain at minimum our node along with an
// introduction node hop.
MinDistanceFromIntroNode uint8
// NumHops is the number of hops that each blinded path should consist
	// of. If paths are found with a number of hops less than NumHops, then
// dummy hops will be padded on to the route. This value doesn't
// include our node, so if the maximum is 1, then the path will contain
// our node along with an introduction node hop.
NumHops uint8
// MaxNumPaths is the maximum number of blinded paths to select.
MaxNumPaths uint8
// NodeOmissionSet is a set of nodes that should not be used within any
// of the blinded paths that we generate.
NodeOmissionSet fn.Set[route.Vertex]
}
// FindBlindedPaths finds a selection of paths to the destination node that can
// be used in blinded payment paths.
func (r *ChannelRouter) FindBlindedPaths(destination route.Vertex,
amt lnwire.MilliSatoshi, probabilitySrc probabilitySource,
restrictions *BlindedPathRestrictions) ([]*route.Route, error) {
// First, find a set of candidate paths given the destination node and
// path length restrictions.
paths, err := findBlindedPaths(
r.cfg.RoutingGraph, destination, &blindedPathRestrictions{
minNumHops: restrictions.MinDistanceFromIntroNode,
maxNumHops: restrictions.NumHops,
nodeOmissionSet: restrictions.NodeOmissionSet,
},
)
if err != nil {
return nil, err
}
// routeWithProbability groups a route with the probability of a
// payment of the given amount succeeding on that path.
type routeWithProbability struct {
route *route.Route
probability float64
}
// Iterate over all the candidate paths and determine the success
// probability of each path given the data we have about forwards
// between any two nodes on a path.
routes := make([]*routeWithProbability, 0, len(paths))
for _, path := range paths {
if len(path) < 1 {
return nil, fmt.Errorf("a blinded path must have at " +
"least one hop")
}
var (
introNode = path[0].vertex
prevNode = introNode
hops = make(
[]*route.Hop, 0, len(path)-1,
)
totalRouteProbability = float64(1)
)
// For each set of hops on the path, get the success probability
// of a forward between those two vertices and use that to
// update the overall route probability.
for j := 1; j < len(path); j++ {
probability := probabilitySrc(
prevNode, path[j].vertex, amt,
path[j-1].edgeCapacity,
)
totalRouteProbability *= probability
hops = append(hops, &route.Hop{
PubKeyBytes: path[j].vertex,
ChannelID: path[j-1].channelID,
})
prevNode = path[j].vertex
}
		// Don't bother adding a route if its success probability is
		// less than the minimum that can be assigned to any single
		// pair.
if totalRouteProbability <= DefaultMinRouteProbability {
continue
}
routes = append(routes, &routeWithProbability{
route: &route.Route{
SourcePubKey: introNode,
Hops: hops,
},
probability: totalRouteProbability,
})
}
// Sort the routes based on probability.
sort.Slice(routes, func(i, j int) bool {
return routes[i].probability > routes[j].probability
})
// Now just choose the best paths up until the maximum number of allowed
// paths.
bestRoutes := make([]*route.Route, 0, restrictions.MaxNumPaths)
for _, route := range routes {
if len(bestRoutes) >= int(restrictions.MaxNumPaths) {
break
}
bestRoutes = append(bestRoutes, route.route)
}
return bestRoutes, nil
}
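// exampleRouteProbability is an illustrative sketch with assumed per-pair
// values (not part of the original source) of the aggregation used above: a
// path's success probability is the product of its per-pair probabilities,
// and a path is kept only if that product exceeds
// DefaultMinRouteProbability.
func exampleRouteProbability() float64 {
	pairProbs := []float64{0.9, 0.8, 0.95}

	totalRouteProbability := float64(1)
	for _, p := range pairProbs {
		totalRouteProbability *= p
	}

	// 0.9 * 0.8 * 0.95 = 0.684.
	return totalRouteProbability
}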
// generateNewSessionKey generates a new ephemeral private key to be used for a
// payment attempt.
func generateNewSessionKey() (*btcec.PrivateKey, error) {
// Generate a new random session key to ensure that we don't trigger
// any replay.
//
// TODO(roasbeef): add more sources of randomness?
return btcec.NewPrivateKey()
}
// LightningPayment describes a payment to be sent through the network to the
// final destination.
type LightningPayment struct {
	// Target is the node that the payment should be routed towards.
Target route.Vertex
// Amount is the value of the payment to send through the network in
// milli-satoshis.
Amount lnwire.MilliSatoshi
// FeeLimit is the maximum fee in millisatoshis that the payment should
// accept when sending it through the network. The payment will fail
// if there isn't a route with lower fees than this limit.
FeeLimit lnwire.MilliSatoshi
// CltvLimit is the maximum time lock that is allowed for attempts to
// complete this payment.
CltvLimit uint32
// paymentHash is the r-hash value to use within the HTLC extended to
// the first hop. This won't be set for AMP payments.
paymentHash *lntypes.Hash
// amp is an optional field that is set if and only if this is an AMP
// payment.
amp *AMPOptions
// FinalCLTVDelta is the CLTV expiry delta to use for the _final_ hop
// in the route. This means that the final hop will have a CLTV delta
// of at least: currentHeight + FinalCLTVDelta.
FinalCLTVDelta uint16
// PayAttemptTimeout is a timeout value that we'll use to determine
// when we should abandon the payment attempt after consecutive
// payment failures. This prevents us from attempting to send a payment
// indefinitely. A zero value means the payment will never time out.
//
// TODO(halseth): make wallclock time to allow resume after startup.
PayAttemptTimeout time.Duration
// RouteHints represents the different routing hints that can be used to
// assist a payment in reaching its destination successfully. These
// hints will act as intermediate hops along the route.
//
// NOTE: This is optional unless required by the payment. When providing
// multiple routes, ensure the hop hints within each route are chained
// together and sorted in forward order so that the payment can reach
// the destination successfully. This is mutually exclusive to the
// BlindedPathSet field.
RouteHints [][]zpay32.HopHint
// BlindedPathSet holds the information about a set of blinded paths to
// the payment recipient. This is mutually exclusive to the RouteHints
// field.
BlindedPathSet *BlindedPaymentPathSet
// OutgoingChannelIDs is the list of channels that are allowed for the
// first hop. If nil, any channel may be used.
OutgoingChannelIDs []uint64
// LastHop is the pubkey of the last node before the final destination
// is reached. If nil, any node may be used.
LastHop *route.Vertex
// DestFeatures specifies the set of features we assume the final node
// has for pathfinding. Typically, these will be taken directly from an
// invoice, but they can also be manually supplied or assumed by the
// sender. If a nil feature vector is provided, the router will try to
// fall back to the graph in order to load a feature vector for a node
// in the public graph.
DestFeatures *lnwire.FeatureVector
// PaymentAddr is the payment address specified by the receiver. This
// field should be a random 32-byte nonce presented in the receiver's
// invoice to prevent probing of the destination.
PaymentAddr fn.Option[[32]byte]
// PaymentRequest is an optional payment request that this payment is
// attempting to complete.
PaymentRequest []byte
// DestCustomRecords are TLV records that are to be sent to the final
// hop in the new onion payload format. If the destination does not
// understand this new onion payload format, then the payment will
// fail.
DestCustomRecords record.CustomSet
// FirstHopCustomRecords are the TLV records that are to be sent to the
// first hop of this payment. These records will be transmitted via the
// wire message and therefore do not affect the onion payload size.
FirstHopCustomRecords lnwire.CustomRecords
// MaxParts is the maximum number of partial payments that may be used
// to complete the full amount.
MaxParts uint32
// MaxShardAmt is the largest shard that we'll attempt to split using.
// If this field is set, and we need to split, rather than attempting
// half of the original payment amount, we'll use this value if half
// the payment amount is greater than it.
//
// NOTE: This field is _optional_.
MaxShardAmt *lnwire.MilliSatoshi
// TimePref is the time preference for this payment. Set to -1 to
// optimize for fees only, to 1 to optimize for reliability only or a
// value in between for a mix.
TimePref float64
// Metadata is additional data that is sent along with the payment to
// the payee.
Metadata []byte
}
// AMPOptions houses information that must be known in order to send an AMP
// payment.
type AMPOptions struct {
SetID [32]byte
RootShare [32]byte
}
// SetPaymentHash sets the given hash as the payment's overall hash. This
// should only be used for non-AMP payments.
func (l *LightningPayment) SetPaymentHash(hash lntypes.Hash) error {
if l.amp != nil {
return fmt.Errorf("cannot set payment hash for AMP payment")
}
l.paymentHash = &hash
return nil
}
// SetAMP sets the given AMP options for the payment.
func (l *LightningPayment) SetAMP(amp *AMPOptions) error {
if l.paymentHash != nil {
return fmt.Errorf("cannot set amp options for payment " +
"with payment hash")
}
l.amp = amp
return nil
}
// Identifier returns a 32-byte slice that uniquely identifies this single
// payment. For non-AMP payments this will be the payment hash, for AMP
// payments this will be the used SetID.
func (l *LightningPayment) Identifier() [32]byte {
if l.amp != nil {
return l.amp.SetID
}
return *l.paymentHash
}
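// Illustrative sketch (not part of the original file): SetPaymentHash and
// SetAMP are mutually exclusive, and Identifier resolves to whichever of
// the two was set.
func examplePaymentIdentifier(hash lntypes.Hash) ([32]byte, error) {
	p := &LightningPayment{}
	if err := p.SetPaymentHash(hash); err != nil {
		return [32]byte{}, err
	}

	// Setting AMP options now fails because a payment hash is already
	// set.
	if err := p.SetAMP(&AMPOptions{}); err == nil {
		return [32]byte{}, fmt.Errorf("expected an error")
	}

	// For non-AMP payments, Identifier returns the payment hash.
	return p.Identifier(), nil
}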
// SendPayment attempts to send a payment as described within the passed
// LightningPayment. This function is blocking and will return either: when the
// payment is successful, or all candidate routes have been attempted and
// resulted in a failed payment. If the payment succeeds, then a non-nil Route
// will be returned which describes the path the successful payment traversed
// within the network to reach the destination. Additionally, the payment
// preimage will also be returned.
func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
*route.Route, error) {
paySession, shardTracker, err := r.PreparePayment(payment)
if err != nil {
return [32]byte{}, nil, err
}
log.Tracef("Dispatching SendPayment for lightning payment: %v",
spewPayment(payment))
return r.sendPayment(
context.Background(), payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker,
payment.FirstHopCustomRecords,
)
}
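// Illustrative sketch (not part of the original file): dispatching a
// blocking payment with a fee limit. The target, amount, and hash are
// placeholders supplied by the caller.
func exampleSendPayment(r *ChannelRouter, target route.Vertex,
	hash lntypes.Hash) (*route.Route, error) {

	payment := &LightningPayment{
		Target:         target,
		Amount:         lnwire.MilliSatoshi(50_000),
		FeeLimit:       lnwire.MilliSatoshi(1_000),
		FinalCLTVDelta: 40,
	}
	if err := payment.SetPaymentHash(hash); err != nil {
		return nil, err
	}

	// SendPayment blocks until the payment succeeds or all candidate
	// routes have been exhausted.
	_, rt, err := r.SendPayment(payment)
	return rt, err
}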
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
// result needs to be retrieved via the control tower.
func (r *ChannelRouter) SendPaymentAsync(ctx context.Context,
payment *LightningPayment, ps PaymentSession, st shards.ShardTracker) {
// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
r.wg.Add(1)
go func() {
defer r.wg.Done()
log.Tracef("Dispatching SendPayment for lightning payment: %v",
spewPayment(payment))
_, _, err := r.sendPayment(
ctx, payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, ps, st,
payment.FirstHopCustomRecords,
)
if err != nil {
log.Errorf("Payment %x failed: %v",
payment.Identifier(), err)
}
}()
}
// spewPayment returns a log closure that provides a spewed string
// representation of the passed payment.
func spewPayment(payment *LightningPayment) lnutils.LogClosure {
return lnutils.NewLogClosure(func() string {
// Make a copy of the payment with a nilled Curve
// before spewing.
var routeHints [][]zpay32.HopHint
for _, routeHint := range payment.RouteHints {
var hopHints []zpay32.HopHint
for _, hopHint := range routeHint {
h := hopHint.Copy()
hopHints = append(hopHints, h)
}
routeHints = append(routeHints, hopHints)
}
p := *payment
p.RouteHints = routeHints
return spew.Sdump(p)
})
}
// PreparePayment creates the payment session and registers the payment with the
// control tower.
func (r *ChannelRouter) PreparePayment(payment *LightningPayment) (
PaymentSession, shards.ShardTracker, error) {
// Assemble any custom data we want to send to the first hop only.
var firstHopData fn.Option[tlv.Blob]
if len(payment.FirstHopCustomRecords) > 0 {
if err := payment.FirstHopCustomRecords.Validate(); err != nil {
return nil, nil, fmt.Errorf("invalid first hop custom "+
"records: %w", err)
}
firstHopBlob, err := payment.FirstHopCustomRecords.Serialize()
if err != nil {
return nil, nil, fmt.Errorf("unable to serialize "+
"first hop custom records: %w", err)
}
firstHopData = fn.Some(firstHopBlob)
}
// Before starting the HTLC routing attempt, we'll create a fresh
// payment session which will report our errors back to mission
// control.
paySession, err := r.cfg.SessionSource.NewPaymentSession(
payment, firstHopData, r.cfg.TrafficShaper,
)
if err != nil {
return nil, nil, err
}
// Record this payment hash with the ControlTower, ensuring it is not
// already in-flight.
//
// TODO(roasbeef): store records as part of creation info?
info := &channeldb.PaymentCreationInfo{
PaymentIdentifier: payment.Identifier(),
Value: payment.Amount,
CreationTime: r.cfg.Clock.Now(),
PaymentRequest: payment.PaymentRequest,
FirstHopCustomRecords: payment.FirstHopCustomRecords,
}
// Create a new ShardTracker that we'll use during the life cycle of
// this payment.
var shardTracker shards.ShardTracker
switch {
// If this is an AMP payment, we'll use the AMP shard tracker.
case payment.amp != nil:
addr := payment.PaymentAddr.UnwrapOr([32]byte{})
shardTracker = amp.NewShardTracker(
payment.amp.RootShare, payment.amp.SetID, addr,
payment.Amount,
)
// Otherwise we'll use the simple tracker that will map each attempt to
// the same payment hash.
default:
shardTracker = shards.NewSimpleShardTracker(
payment.Identifier(), nil,
)
}
err = r.cfg.Control.InitPayment(payment.Identifier(), info)
if err != nil {
return nil, nil, err
}
return paySession, shardTracker, nil
}
// SendToRoute sends a payment using the provided route and fails the payment
// when an error is returned from the attempt.
func (r *ChannelRouter) SendToRoute(htlcHash lntypes.Hash, rt *route.Route,
firstHopCustomRecords lnwire.CustomRecords) (*channeldb.HTLCAttempt,
error) {
return r.sendToRoute(htlcHash, rt, false, firstHopCustomRecords)
}
// SendToRouteSkipTempErr sends a payment using the provided route and fails
// the payment ONLY when a terminal error is returned from the attempt.
func (r *ChannelRouter) SendToRouteSkipTempErr(htlcHash lntypes.Hash,
rt *route.Route,
firstHopCustomRecords lnwire.CustomRecords) (*channeldb.HTLCAttempt,
error) {
return r.sendToRoute(htlcHash, rt, true, firstHopCustomRecords)
}
// sendToRoute attempts to send a payment with the given hash through the
// provided route. This function is blocking and will return the attempt
// information as it is stored in the database. For a successful htlc, this
// information will contain the preimage. If an error occurs after the attempt
// was initiated, both return values will be non-nil. If skipTempErr is true,
// the payment won't be failed unless a terminal error has occurred.
func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
skipTempErr bool,
firstHopCustomRecords lnwire.CustomRecords) (*channeldb.HTLCAttempt,
error) {
// Helper function to fail a payment. It makes sure the payment is only
// failed once so that the failure reason is not overwritten.
failPayment := func(paymentIdentifier lntypes.Hash,
reason channeldb.FailureReason) error {
payment, fetchErr := r.cfg.Control.FetchPayment(
paymentIdentifier,
)
if fetchErr != nil {
return fetchErr
}
// NOTE: We cannot rely on the payment status to be failed here
// because it can still be in-flight although the payment is
// already failed.
_, failedReason := payment.TerminalInfo()
if failedReason != nil {
return nil
}
return r.cfg.Control.FailPayment(paymentIdentifier, reason)
}
log.Debugf("SendToRoute for payment %v with skipTempErr=%v",
htlcHash, skipTempErr)
// Calculate amount paid to receiver.
amt := rt.ReceiverAmt()
// If this is meant as an MPP shard, we set the amount in the
// creation info to the total amount of the payment.
finalHop := rt.Hops[len(rt.Hops)-1]
mpp := finalHop.MPP
if mpp != nil {
amt = mpp.TotalMsat()
}
// For non-MPP payments there's no such thing as a temp error, as
// there's only one HTLC attempt being made. When this HTLC fails, the
// payment fails and hence cannot be retried.
if skipTempErr && mpp == nil {
return nil, ErrSkipTempErr
}
// For non-AMP payments the overall payment identifier will be the same
// hash as used for this HTLC.
paymentIdentifier := htlcHash
// For AMP-payments, we'll use the setID as the unique ID for the
// overall payment.
amp := finalHop.AMP
if amp != nil {
paymentIdentifier = amp.SetID()
}
// Record this payment hash with the ControlTower, ensuring it is not
// already in-flight.
info := &channeldb.PaymentCreationInfo{
PaymentIdentifier: paymentIdentifier,
Value: amt,
CreationTime: r.cfg.Clock.Now(),
PaymentRequest: nil,
FirstHopCustomRecords: firstHopCustomRecords,
}
err := r.cfg.Control.InitPayment(paymentIdentifier, info)
switch {
// If this is an MPP attempt and the hash is already registered with
// the database, we can go on to launch the shard.
case mpp != nil && errors.Is(err, channeldb.ErrPaymentInFlight):
case mpp != nil && errors.Is(err, channeldb.ErrPaymentExists):
// Any other error is not tolerated.
case err != nil:
return nil, err
}
log.Tracef("Dispatching SendToRoute for HTLC hash %v: %v", htlcHash,
lnutils.SpewLogClosure(rt))
// Since the HTLC hashes and preimages are specified manually over the
// RPC for SendToRoute requests, we don't have to worry about creating
// a ShardTracker that can generate hashes for AMP payments. Instead, we
// create a simple tracker that can just return the hash for the single
// shard we'll now launch.
shardTracker := shards.NewSimpleShardTracker(htlcHash, nil)
// Create a payment lifecycle using the given route with,
// - zero fee limit as we are not requesting routes.
// - nil payment session (since we already have a route).
// - no payment timeout.
// - no current block height.
p := newPaymentLifecycle(
r, 0, paymentIdentifier, nil, shardTracker, 0,
firstHopCustomRecords,
)
// Allow the traffic shaper to add custom records to the outgoing HTLC
// and also adjust the amount if needed.
err = p.amendFirstHopData(rt)
if err != nil {
return nil, err
}
// We found a route to try, create a new HTLC attempt to try.
//
// NOTE: we use a zero `remainingAmt` here to simulate the same effect
// as setting lastShard to false, which was used by the previous
// implementation.
attempt, err := p.registerAttempt(rt, 0)
if err != nil {
return nil, err
}
// Once the attempt is created, send it to the htlcswitch. Notice that
// the `err` returned here has already been processed by
// `handleSwitchErr`, which means if there's a terminal failure, the
// payment has been failed.
result, err := p.sendAttempt(attempt)
if err != nil {
return nil, err
}
// Since for SendToRoute we won't retry in case the shard fails, we'll
// mark the payment failed with the control tower immediately if the
// skipTempErr is false.
reason := channeldb.FailureReasonError
// If we failed to send the HTLC, we need to further decide if we want
// to fail the payment.
if result.err != nil {
// If skipTempErr, we'll return the attempt and the temp error.
if skipTempErr {
return result.attempt, result.err
}
err := failPayment(paymentIdentifier, reason)
if err != nil {
return nil, err
}
return result.attempt, result.err
}
// The attempt was successfully sent, wait for the result to be
// available.
result, err = p.collectAndHandleResult(attempt)
if err != nil {
return nil, err
}
// We got a successful result.
if result.err == nil {
return result.attempt, nil
}
// An error returned from collecting the result, we'll mark the payment
// as failed if we don't skip temp error.
if !skipTempErr {
err := failPayment(paymentIdentifier, reason)
if err != nil {
return nil, err
}
}
return result.attempt, result.err
}
// sendPayment attempts to send a payment to the passed payment hash. This
// function is blocking and will return either: when the payment is successful,
// or all candidate routes have been attempted and resulted in a failed
// payment. If the payment succeeds, then a non-nil Route will be returned
// which describes the path the successful payment traversed within the network
// to reach the destination. Additionally, the payment preimage will also be
// returned.
//
// This method relies on the ControlTower's internal payment state machine to
// carry out its execution. After restarts, it is safe, and assumed, that the
// router will call this method for every payment still in-flight according to
// the ControlTower.
func (r *ChannelRouter) sendPayment(ctx context.Context,
feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
paymentAttemptTimeout time.Duration, paySession PaymentSession,
shardTracker shards.ShardTracker,
firstHopCustomRecords lnwire.CustomRecords) ([32]byte, *route.Route,
error) {
// If the user provides a timeout, we will additionally wrap the context
// in a deadline.
cancel := func() {}
if paymentAttemptTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, paymentAttemptTimeout)
}
// Since resumePayment is a blocking call, we'll cancel this
// context if the payment completes before the optional
// deadline.
defer cancel()
// We'll also fetch the current block height, so we can properly
// calculate the required HTLC time locks within the route.
_, currentHeight, err := r.cfg.Chain.GetBestBlock()
if err != nil {
return [32]byte{}, nil, err
}
// Validate the custom records before we attempt to send the payment.
// TODO(ziggie): Move this check before registering the payment in the
// db (InitPayment).
if err := firstHopCustomRecords.Validate(); err != nil {
return [32]byte{}, nil, err
}
// Now set up a paymentLifecycle struct with these params, such that we
// can resume the payment from the current state.
p := newPaymentLifecycle(
r, feeLimit, identifier, paySession, shardTracker,
currentHeight, firstHopCustomRecords,
)
return p.resumePayment(ctx)
}
// extractChannelUpdate examines the error and extracts the channel update.
func (r *ChannelRouter) extractChannelUpdate(
failure lnwire.FailureMessage) *lnwire.ChannelUpdate1 {
var update *lnwire.ChannelUpdate1
switch onionErr := failure.(type) {
case *lnwire.FailExpiryTooSoon:
update = &onionErr.Update
case *lnwire.FailAmountBelowMinimum:
update = &onionErr.Update
case *lnwire.FailFeeInsufficient:
update = &onionErr.Update
case *lnwire.FailIncorrectCltvExpiry:
update = &onionErr.Update
case *lnwire.FailChannelDisabled:
update = &onionErr.Update
case *lnwire.FailTemporaryChannelFailure:
update = onionErr.Update
}
return update
}
// ErrNoChannel is returned when a route cannot be built because there are no
// channels that satisfy all requirements.
type ErrNoChannel struct {
position int
}
// Error returns a human-readable string describing the error.
func (e ErrNoChannel) Error() string {
return fmt.Sprintf("no matching outgoing channel available for "+
"node index %v", e.position)
}
// BuildRoute returns a fully specified route based on a list of pubkeys. If
// amount is nil, the minimum routable amount is used. To force a specific
// outgoing channel, use the outgoingChan parameter.
func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi],
hops []route.Vertex, outgoingChan *uint64, finalCltvDelta int32,
payAddr fn.Option[[32]byte], firstHopBlob fn.Option[[]byte]) (
*route.Route, error) {
log.Tracef("BuildRoute called: hopsCount=%v, amt=%v", len(hops), amt)
var outgoingChans map[uint64]struct{}
if outgoingChan != nil {
outgoingChans = map[uint64]struct{}{
*outgoingChan: {},
}
}
// We'll attempt to obtain a set of bandwidth hints that helps us select
// the best outgoing channel to use in case no outgoing channel is set.
bandwidthHints, err := newBandwidthManager(
r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, firstHopBlob,
r.cfg.TrafficShaper,
)
if err != nil {
return nil, err
}
sourceNode := r.cfg.SelfNode
// We check that each node in the route has a connection to others that
// can forward in principle.
unifiers, err := getEdgeUnifiers(
r.cfg.SelfNode, hops, outgoingChans, r.cfg.RoutingGraph,
)
if err != nil {
return nil, err
}
var (
receiverAmt lnwire.MilliSatoshi
senderAmt lnwire.MilliSatoshi
pathEdges []*unifiedEdge
)
// We determine the edges compatible with the requested amount, as well
// as the amount to send, which can be used to determine the final
// receiver amount, if a minimal amount was requested.
pathEdges, senderAmt, err = senderAmtBackwardPass(
unifiers, amt, bandwidthHints,
)
if err != nil {
return nil, err
}
// For the minimal amount search, we need to do a forward pass to find a
// larger receiver amount due to possible min HTLC bumps, otherwise we
// just use the requested amount.
receiverAmt, err = fn.ElimOption(
amt,
func() fn.Result[lnwire.MilliSatoshi] {
return fn.NewResult(
receiverAmtForwardPass(senderAmt, pathEdges),
)
},
fn.Ok[lnwire.MilliSatoshi],
).Unpack()
if err != nil {
return nil, err
}
// Fetch the current block height outside the routing transaction, to
// prevent the rpc call blocking the database.
_, height, err := r.cfg.Chain.GetBestBlock()
if err != nil {
return nil, err
}
// Build and return the final route.
return newRoute(
sourceNode, pathEdges, uint32(height),
finalHopParams{
amt: receiverAmt,
totalAmt: receiverAmt,
cltvDelta: uint16(finalCltvDelta),
records: nil,
paymentAddr: payAddr,
}, nil,
)
}
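// Illustrative sketch (not part of the original file): building a route
// through a fixed list of hops for the minimum routable amount, letting
// the router pick any first-hop channel. The final CLTV delta of 40 is
// an arbitrary example value.
func exampleBuildRoute(r *ChannelRouter,
	hops []route.Vertex) (*route.Route, error) {

	return r.BuildRoute(
		fn.None[lnwire.MilliSatoshi](), hops, nil, 40,
		fn.None[[32]byte](), fn.None[[]byte](),
	)
}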
// resumePayments fetches inflight payments and resumes their payment
// lifecycles.
func (r *ChannelRouter) resumePayments() error {
// Get all payments that are inflight.
log.Debugf("Scanning for inflight payments")
payments, err := r.cfg.Control.FetchInFlightPayments()
if err != nil {
return err
}
log.Debugf("Scanning finished, found %d inflight payments",
len(payments))
// Before we restart existing payments and start accepting more
// payments to be made, we clean the network result store of the
// Switch. We do this here at startup to ensure no more payments can be
// made concurrently, so we know the toKeep map will be up-to-date
// until the cleaning has finished.
toKeep := make(map[uint64]struct{})
for _, p := range payments {
for _, a := range p.HTLCs {
toKeep[a.AttemptID] = struct{}{}
// Try to fail the attempt if the route contains a dead
// channel.
r.failStaleAttempt(a, p.Info.PaymentIdentifier)
}
}
log.Debugf("Cleaning network result store.")
if err := r.cfg.Payer.CleanStore(toKeep); err != nil {
return err
}
// launchPayment is a helper closure that handles resuming the payment.
launchPayment := func(payment *channeldb.MPPayment) {
defer r.wg.Done()
// Get the hashes used for the outstanding HTLCs.
htlcs := make(map[uint64]lntypes.Hash)
for _, a := range payment.HTLCs {
a := a
// We check whether the individual attempts have their
// HTLC hash set, if not we'll fall back to the overall
// payment hash.
hash := payment.Info.PaymentIdentifier
if a.Hash != nil {
hash = *a.Hash
}
htlcs[a.AttemptID] = hash
}
payHash := payment.Info.PaymentIdentifier
// Since we are not supporting creating more shards after a
// restart (only receiving the result of the shards already
// outstanding), we create a simple shard tracker that will map
// the attempt IDs to hashes used for the HTLCs. This will be
// enough also for AMP payments, since we only need the hashes
// for the individual HTLCs to regenerate the circuits, and we
// don't currently persist the root share necessary to
// re-derive them.
shardTracker := shards.NewSimpleShardTracker(payHash, htlcs)
// We create a dummy, empty payment session such that we won't
// make another payment attempt when the result for the
// in-flight attempt is received.
paySession := r.cfg.SessionSource.NewPaymentSessionEmpty()
// We pass in a non-timeout context, to indicate we don't need
// it to timeout. It will stop immediately after the existing
// attempt has finished anyway. We also set a zero fee limit,
// as no more routes should be tried.
noTimeout := time.Duration(0)
_, _, err := r.sendPayment(
context.Background(), 0, payHash, noTimeout, paySession,
shardTracker, payment.Info.FirstHopCustomRecords,
)
if err != nil {
log.Errorf("Resuming payment %v failed: %v", payHash,
err)
return
}
log.Infof("Resumed payment %v completed", payHash)
}
for _, payment := range payments {
log.Infof("Resuming payment %v", payment.Info.PaymentIdentifier)
r.wg.Add(1)
go launchPayment(payment)
}
return nil
}
// failStaleAttempt will fail an HTLC attempt if it's using an unknown channel
// in its route. It first consults the switch to see if there's already a
// network result stored for this attempt. If not, it will check whether the
// first hop of this attempt is using an active channel known to us. If
// inactive, this attempt will be failed.
//
// NOTE: there's an unknown bug that caused the network result for a particular
// attempt to NOT be saved, resulting in a payment being stuck forever.
// More info:
// - https://github.com/lightningnetwork/lnd/issues/8146
// - https://github.com/lightningnetwork/lnd/pull/8174
func (r *ChannelRouter) failStaleAttempt(a channeldb.HTLCAttempt,
payHash lntypes.Hash) {
// We can only fail inflight HTLCs so we skip the settled/failed ones.
if a.Failure != nil || a.Settle != nil {
return
}
// First, check if we've already had a network result for this attempt.
// If no result is found, we'll check whether the reference link is
// still known to us.
ok, err := r.cfg.Payer.HasAttemptResult(a.AttemptID)
if err != nil {
log.Errorf("Failed to fetch network result for attempt=%v",
a.AttemptID)
return
}
// There's already a network result, no need to fail it here as the
// payment lifecycle will take care of it, so we can exit early.
if ok {
log.Debugf("Already have network result for attempt=%v",
a.AttemptID)
return
}
// We now need to decide whether this attempt should be failed here.
// For very old payments, it's possible that the network results were
// never saved, causing the payments to be stuck inflight. We now check
// whether the first hop is referencing an active channel ID and, if
// not, we will fail the attempt as it has no way to be retried again.
var shouldFail bool
// Validate that the attempt has hop info. If this attempt has no hop
// info it indicates an error in our db.
if len(a.Route.Hops) == 0 {
log.Errorf("Found empty hop for attempt=%v", a.AttemptID)
shouldFail = true
} else {
// Get the short channel ID.
chanID := a.Route.Hops[0].ChannelID
scid := lnwire.NewShortChanIDFromInt(chanID)
// Check whether this link is active. If so, we won't fail the
// attempt but keep waiting for its result.
_, err := r.cfg.GetLink(scid)
if err == nil {
return
}
// We should get the link not found err. If not, we will log an
// error and skip failing this attempt since an unknown error
// occurred.
if !errors.Is(err, htlcswitch.ErrChannelLinkNotFound) {
log.Errorf("Failed to get link for attempt=%v for "+
"payment=%v: %v", a.AttemptID, payHash, err)
return
}
// The channel link is not active, we now check whether this
// channel is already closed. If so, we fail the HTLC attempt
// as there's no need to wait for its network result because
// there's no link. If the channel is still pending, we'll keep
// waiting for the result as we may get a contract resolution
// for this HTLC.
if _, ok := r.cfg.ClosedSCIDs[scid]; ok {
shouldFail = true
}
}
// Exit if there's no need to fail.
if !shouldFail {
return
}
log.Errorf("Failing stale attempt=%v for payment=%v", a.AttemptID,
payHash)
// Fail the attempt in db. If there's an error, there's nothing we can
// do here but logging it.
failInfo := &channeldb.HTLCFailInfo{
Reason: channeldb.HTLCFailUnknown,
FailTime: r.cfg.Clock.Now(),
}
_, err = r.cfg.Control.FailAttempt(payHash, a.AttemptID, failInfo)
if err != nil {
log.Errorf("Fail attempt=%v got error: %v", a.AttemptID, err)
}
}
// getEdgeUnifiers returns a list of edge unifiers for the given route.
func getEdgeUnifiers(source route.Vertex, hops []route.Vertex,
outgoingChans map[uint64]struct{},
graph Graph) ([]*edgeUnifier, error) {
// Allocate a list that will contain the edge unifiers for this route.
unifiers := make([]*edgeUnifier, len(hops))
// Traverse hops backwards to accumulate fees in the running amounts.
for i := len(hops) - 1; i >= 0; i-- {
toNode := hops[i]
var fromNode route.Vertex
if i == 0 {
fromNode = source
} else {
fromNode = hops[i-1]
}
// Build unified policies for this hop based on the channels
// known in the graph. Inbound fees are only active if the edge
// is not the last hop.
isExitHop := i == len(hops)-1
u := newNodeEdgeUnifier(
source, toNode, !isExitHop, outgoingChans,
)
err := u.addGraphPolicies(graph)
if err != nil {
return nil, err
}
// Exit if there are no channels.
edgeUnifier, ok := u.edgeUnifiers[fromNode]
if !ok {
log.Errorf("Cannot find policy for node %v", fromNode)
return nil, ErrNoChannel{position: i}
}
unifiers[i] = edgeUnifier
}
return unifiers, nil
}
// senderAmtBackwardPass returns a list of unified edges for the given route and
// determines the amount that should be sent to fulfill min HTLC requirements.
// The minimal sender amount can be searched for by using amt=None.
func senderAmtBackwardPass(unifiers []*edgeUnifier,
amt fn.Option[lnwire.MilliSatoshi],
bandwidthHints bandwidthHints) ([]*unifiedEdge, lnwire.MilliSatoshi,
error) {
if len(unifiers) == 0 {
return nil, 0, fmt.Errorf("no unifiers provided")
}
var unifiedEdges = make([]*unifiedEdge, len(unifiers))
// We traverse the route backwards and handle the last hop separately.
edgeUnifier := unifiers[len(unifiers)-1]
// incomingAmt tracks the amount that is forwarded on the edges of a
// route. The last hop only forwards the amount that the receiver should
// receive, as there are no fees paid to the last node.
// For minimum amount routes, aim to deliver at least 1 msat to
// the destination. There are nodes in the wild that have a
// min_htlc channel policy of zero, which could lead to a zero
// amount payment being made.
incomingAmt := amt.UnwrapOr(1)
// If using min amt, increase the amount if needed to fulfill min HTLC
// requirements.
if amt.IsNone() {
min := edgeUnifier.minAmt()
if min > incomingAmt {
incomingAmt = min
}
}
// Get an edge for the specific amount that we want to forward.
edge := edgeUnifier.getEdge(incomingAmt, bandwidthHints, 0)
if edge == nil {
log.Errorf("Cannot find policy with amt=%v "+
"for hop %v", incomingAmt, len(unifiers)-1)
return nil, 0, ErrNoChannel{position: len(unifiers) - 1}
}
unifiedEdges[len(unifiers)-1] = edge
// Handle the rest of the route except the last hop.
for i := len(unifiers) - 2; i >= 0; i-- {
edgeUnifier = unifiers[i]
// If using min amt, increase the amount if needed to fulfill
// min HTLC requirements.
if amt.IsNone() {
min := edgeUnifier.minAmt()
if min > incomingAmt {
incomingAmt = min
}
}
// A --current hop-- B --next hop: incomingAmt-- C
// The outbound fee paid to the current end node B is based on
// the amount that the next hop forwards and B's policy for that
// hop.
outboundFee := unifiedEdges[i+1].policy.ComputeFee(
incomingAmt,
)
netAmount := incomingAmt + outboundFee
// We need to select an edge that can forward the requested
// amount.
edge = edgeUnifier.getEdge(
netAmount, bandwidthHints, outboundFee,
)
if edge == nil {
return nil, 0, ErrNoChannel{position: i}
}
// The fee paid to B depends on the current hop's inbound fee
// policy and on the outbound fee for the next hop as any
// inbound fee discount is capped by the outbound fee such that
// the total fee for B can't become negative.
inboundFee := calcCappedInboundFee(edge, netAmount, outboundFee)
fee := lnwire.MilliSatoshi(int64(outboundFee) + inboundFee)
log.Tracef("Select channel %v at position %v",
edge.policy.ChannelID, i)
// Finally, we update the amount that needs to flow into node B
// from A, which is the next hop's forwarding amount plus the
// fee for B: A --current hop: incomingAmt-- B --next hop-- C
incomingAmt += fee
unifiedEdges[i] = edge
}
return unifiedEdges, incomingAmt, nil
}
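// Worked example (illustrative, with assumed policy values): for a route
// A -> B -> C delivering 100_000 msat to C, suppose B's outbound policy
// on the B->C channel charges a base fee of 1_000 msat plus 100 ppm and
// B advertises no inbound fee. The backward pass starts with
// incomingAmt = 100_000 for the last hop, computes the outbound fee as
// 1_000 + 100_000*100/1_000_000 = 1_010 msat, and hence requires
// 101_010 msat to flow from A into B.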
// receiverAmtForwardPass returns the amount that a receiver will receive after
// deducting all fees from the sender amount.
func receiverAmtForwardPass(runningAmt lnwire.MilliSatoshi,
unifiedEdges []*unifiedEdge) (lnwire.MilliSatoshi, error) {
if len(unifiedEdges) == 0 {
return 0, fmt.Errorf("no edges to forward through")
}
inEdge := unifiedEdges[0]
if !inEdge.amtInRange(runningAmt) {
log.Errorf("Amount %v not in range for hop index %v",
runningAmt, 0)
return 0, ErrNoChannel{position: 0}
}
// Now that we arrived at the start of the route and found out the route
// total amount, we make a forward pass. Because the amount may have
// been increased in the backward pass, fees need to be recalculated and
// amount ranges re-checked.
for i := 1; i < len(unifiedEdges); i++ {
inEdge := unifiedEdges[i-1]
outEdge := unifiedEdges[i]
// Decrease the amount to send while going forward.
runningAmt = outgoingFromIncoming(runningAmt, inEdge, outEdge)
if !outEdge.amtInRange(runningAmt) {
log.Errorf("Amount %v not in range for hop index %v",
runningAmt, i)
return 0, ErrNoChannel{position: i}
}
}
return runningAmt, nil
}
// incomingFromOutgoing computes the incoming amount based on the outgoing
// amount by adding fees to the outgoing amount, replicating the path finding
// and routing process, see also CheckHtlcForward.
func incomingFromOutgoing(outgoingAmt lnwire.MilliSatoshi,
incoming, outgoing *unifiedEdge) lnwire.MilliSatoshi {
outgoingFee := outgoing.policy.ComputeFee(outgoingAmt)
// Net amount is the amount the inbound fees are calculated with.
netAmount := outgoingAmt + outgoingFee
inboundFee := incoming.inboundFees.CalcFee(netAmount)
// The inbound fee is not allowed to reduce the incoming amount below
// the outgoing amount.
if int64(outgoingFee)+inboundFee < 0 {
return outgoingAmt
}
return netAmount + lnwire.MilliSatoshi(inboundFee)
}
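// Worked example (illustrative, with assumed fee parameters): forwarding
// an outgoing amount of 999_001 msat over a policy with a base fee of
// 1_000 msat and a rate of 1_000 ppm, with zero inbound fees, gives an
// outgoing fee of 1_000 + floor(999_001*1_000/1_000_000) = 1_999 msat, a
// net amount of 1_001_000 msat, and thus an incoming amount of
// 1_001_000 msat.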
// outgoingFromIncoming computes the outgoing amount based on the incoming
// amount by subtracting fees from the incoming amount. Note that this is not
// exactly the inverse of incomingFromOutgoing, because of some rounding.
func outgoingFromIncoming(incomingAmt lnwire.MilliSatoshi,
incoming, outgoing *unifiedEdge) lnwire.MilliSatoshi {
// Convert all quantities to big.Int to be able to handle negative
// values. The formulas to compute the outgoing amount involve terms
// with PPM*PPM*A, which can easily overflow an int64.
A := big.NewInt(int64(incomingAmt))
Ro := big.NewInt(int64(outgoing.policy.FeeProportionalMillionths))
Bo := big.NewInt(int64(outgoing.policy.FeeBaseMSat))
Ri := big.NewInt(int64(incoming.inboundFees.Rate))
Bi := big.NewInt(int64(incoming.inboundFees.Base))
PPM := big.NewInt(1_000_000)
// The following discussion was contributed by user feelancer21, see
//nolint:ll
// https://github.com/feelancer21/lnd/commit/f6f05fa930985aac0d27c3f6681aada1b599162a.
// The incoming amount Ai based on the outgoing amount Ao is computed by
// Ai = max(Ai(Ao), Ao), which caps the incoming amount such that the
// total node fee (Ai - Ao) is non-negative. This is commonly enforced
// by routing nodes.
// The function Ai(Ao) is given by:
// Ai(Ao) = (Ao + Bo + Ao*Ro/PPM) + (Bi + (Ao + Ao*Ro/PPM + Bo)*Ri/PPM),
// where the first term is the net amount (the outgoing amount plus the
// outbound fee), and the second is the inbound fee computed based on
// the net amount.
// Ai(Ao) can potentially become smaller than Ao due to inbound fee
// discounts, which is why the above mentioned capping is needed. We can
// abbreviate Ai(Ao) with Ai(Ao) = m*Ao + n, where m and n are:
// m := (1 + Ro/PPM) * (1 + Ri/PPM)
// n := Bi + Bo*(1 + Ri/PPM)
// If we know that m > 0, this is equivalent to Ri/PPM > -1, because Ri
// is the only factor that can become negative. A value of Ri/PPM = -1,
// means that the routing node is willing to give up on 100% of the
// net amount (based on the fee rate), which is likely to not happen in
// practice. This condition will be important for a later trick.
// If we want to compute the outgoing amount based on the incoming
// amount, which is the reverse problem, we need to solve Ai =
// max(Ai(Ao), Ao) for Ao(Ai). Given an incoming amount A,
// we look for an Ao such that A = max(Ai(Ao), Ao).
// The max function separates this into two cases. The case to take is
// not clear yet, because we don't know Ao, but later we see a trick
// how to determine which case is the one to take.
// first case: Ai(Ao) <= Ao:
// Therefore, A = max(Ai(Ao), Ao) = Ao, we find Ao = A.
// This also leads to Ai(A) <= A by substitution into the condition.
// second case: Ai(Ao) > Ao:
// Therefore, A = max(Ai(Ao), Ao) = Ai(Ao) = m*Ao + n. Solving for Ao
// gives Ao = (A - n)/m.
//
// We know
// Ai(Ao) > Ao <=> A = Ai(Ao) > Ao = (A - n)/m,
// so A > (A - n)/m.
//
// **Assuming m > 0**, by multiplying with m, we can transform this to
// A * m + n > A.
//
// We know Ai(A) = A*m + n, therefore Ai(A) > A.
//
// This means that if we apply the incoming amount calculation to the
// **incoming** amount, and this condition holds, then we know that we
// deal with the second case, being able to compute the outgoing amount
// based off the formula Ao = (A - n)/m, otherwise we will just return
// the incoming amount.
// In case the inbound fee rate is less than -1 (-100%), we fail to
// compute the outbound amount and return the incoming amount. This also
// protects against zero division later.
// We compute m in terms of big.Int to be safe from overflows and to be
// consistent with later calculations.
// m := (PPM*PPM + Ri*PPM + Ro*PPM + Ro*Ri)/(PPM*PPM)
// Compute terms in (PPM*PPM + Ri*PPM + Ro*PPM + Ro*Ri).
m1 := new(big.Int).Mul(PPM, PPM)
m2 := new(big.Int).Mul(Ri, PPM)
m3 := new(big.Int).Mul(Ro, PPM)
m4 := new(big.Int).Mul(Ro, Ri)
// Add up terms m1..m4.
m := big.NewInt(0)
m.Add(m, m1)
m.Add(m, m2)
m.Add(m, m3)
m.Add(m, m4)
// Since we compare to 0, we can multiply by PPM*PPM to avoid the
// division.
if m.Int64() <= 0 {
return incomingAmt
}
// In order to decide if the total fee is negative, we apply the fee
// to the *incoming* amount as mentioned before.
// We compute the test amount in terms of big.Int to be safe from
// overflows and to be consistent with later calculations.
// testAmtF := A*m + n =
// = A + Bo + Bi + (PPM*(A*Ri + A*Ro + Ro*Ri) + A*Ri*Ro)/(PPM*PPM)
// Compute terms in (A*Ri + A*Ro + Ro*Ri).
t1 := new(big.Int).Mul(A, Ri)
t2 := new(big.Int).Mul(A, Ro)
t3 := new(big.Int).Mul(Ro, Ri)
// Sum up terms t1-t3.
t4 := big.NewInt(0)
t4.Add(t4, t1)
t4.Add(t4, t2)
t4.Add(t4, t3)
// Compute PPM*(A*Ri + A*Ro + Ro*Ri).
t6 := new(big.Int).Mul(PPM, t4)
// Compute A*Ri*Ro.
t7 := new(big.Int).Mul(A, Ri)
t7.Mul(t7, Ro)
// Compute (PPM*(A*Ri + A*Ro + Ro*Ri) + A*Ri*Ro)/(PPM*PPM).
num := new(big.Int).Add(t6, t7)
denom := new(big.Int).Mul(PPM, PPM)
fraction := new(big.Int).Div(num, denom)
// Sum up all terms.
testAmt := big.NewInt(0)
testAmt.Add(testAmt, A)
testAmt.Add(testAmt, Bo)
testAmt.Add(testAmt, Bi)
testAmt.Add(testAmt, fraction)
// Protect against negative values for the integer cast to Msat.
if testAmt.Int64() < 0 {
return incomingAmt
}
// If the second case holds, we have to compute the outgoing amount.
if lnwire.MilliSatoshi(testAmt.Int64()) > incomingAmt {
// Compute the outgoing amount by integer ceiling division. This
// precision is needed because PPM*PPM*A and other terms can
// easily overflow with int64, which happens with about
// A = 10_000 sat.
// out := (A - n) / m = numerator / denominator
// numerator := PPM*(PPM*(A - Bo - Bi) - Bo*Ri)
// denominator := PPM*(PPM + Ri + Ro) + Ri*Ro
var numerator big.Int
// Compute (A - Bo - Bi).
temp1 := new(big.Int).Sub(A, Bo)
temp2 := new(big.Int).Sub(temp1, Bi)
// Compute terms in (PPM*(A - Bo - Bi) - Bo*Ri).
temp3 := new(big.Int).Mul(PPM, temp2)
temp4 := new(big.Int).Mul(Bo, Ri)
// Compute PPM*(PPM*(A - Bo - Bi) - Bo*Ri)
temp5 := new(big.Int).Sub(temp3, temp4)
numerator.Mul(PPM, temp5)
var denominator big.Int
// Compute (PPM + Ri + Ro).
temp1 = new(big.Int).Add(PPM, Ri)
temp2 = new(big.Int).Add(temp1, Ro)
// Compute PPM*(PPM + Ri + Ro) + Ri*Ro.
temp3 = new(big.Int).Mul(PPM, temp2)
temp4 = new(big.Int).Mul(Ri, Ro)
denominator.Add(temp3, temp4)
// We overestimate the outgoing amount by taking the ceiling of
// the division. This means that we may round slightly up by a
// MilliSatoshi, but this helps to ensure that we don't hit min
// HTLC constraints in the context of finding the minimum amount
// of a route.
// ceil = floor((numerator + denominator - 1) / denominator)
ceil := new(big.Int).Add(&numerator, &denominator)
ceil.Sub(ceil, big.NewInt(1))
ceil.Div(ceil, &denominator)
return lnwire.MilliSatoshi(ceil.Int64())
}
// Otherwise the inbound fee made up for the outbound fee, which is why
// we just return the incoming amount.
return incomingAmt
}
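// Worked example (illustrative, with assumed fee parameters): with
// Bo = 1_000 msat, Ro = 1_000 ppm and zero inbound fees (Bi = Ri = 0),
// we get m = 1 + Ro/PPM = 1.001 and n = Bo = 1_000. For an incoming
// amount A = 1_001_000 msat, the test amount A*m + n = 1_003_001 msat
// exceeds A, so the second case applies and the outgoing amount is
// Ao = (A - n)/m = 1_000_000/1.001, which the ceiling division rounds to
// 999_001 msat. Feeding that result back into incomingFromOutgoing
// recovers the original 1_001_000 msat.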
package shards
import (
"fmt"
"sync"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/record"
)
// PaymentShard is an interface representing a shard tracked by the
// ShardTracker. It contains options that are specific to the given shard that
// might differ from the overall payment.
type PaymentShard interface {
// Hash returns the hash used for the HTLC representing this shard.
Hash() lntypes.Hash
// MPP returns any extra MPP records that should be set for the final
// hop on the route used by this shard.
MPP() *record.MPP
// AMP returns any extra AMP records that should be set for the final
// hop on the route used by this shard.
AMP() *record.AMP
}
// ShardTracker is an interface representing a tracker that keeps track of the
// inflight shards of a payment, and is able to assign new shards the correct
// options such as hash and extra records.
type ShardTracker interface {
// NewShard registers a new attempt with the ShardTracker and returns a
// new shard representing this attempt. This attempt's shard should be
// canceled if it ends up not being used by the overall payment, i.e.
// if the attempt fails.
NewShard(uint64, bool) (PaymentShard, error)
// CancelShard cancels the shard corresponding to the given attempt
// ID. This lets the ShardTracker free up any slots used by this shard,
// and in case of AMP payments return the share used by this shard to
// the root share.
CancelShard(uint64) error
// GetHash retrieves the hash used by the shard of the given attempt
// ID. This will return an error if the attempt ID is unknown.
GetHash(uint64) (lntypes.Hash, error)
}
// Shard is a struct used for simple shards where we only need to map it
// to a single hash.
type Shard struct {
hash lntypes.Hash
}
// Hash returns the hash used for the HTLC representing this shard.
func (s *Shard) Hash() lntypes.Hash {
return s.hash
}
// MPP returns any extra MPP records that should be set for the final hop on
// the route used by this shard.
func (s *Shard) MPP() *record.MPP {
return nil
}
// AMP returns any extra AMP records that should be set for the final hop on
// the route used by this shard.
func (s *Shard) AMP() *record.AMP {
return nil
}
// SimpleShardTracker is an implementation of the ShardTracker interface that
// simply maps attempt IDs to hashes. New shards will be given a static payment
// hash. This should be used for regular and MPP payments, in addition to
// resumed payments where all the attempt's hashes have already been created.
type SimpleShardTracker struct {
hash lntypes.Hash
shards map[uint64]lntypes.Hash
sync.Mutex
}
// A compile time check to ensure SimpleShardTracker implements the
// ShardTracker interface.
var _ ShardTracker = (*SimpleShardTracker)(nil)
// NewSimpleShardTracker creates a new instance of the SimpleShardTracker with
// the given payment hash and existing attempts.
func NewSimpleShardTracker(paymentHash lntypes.Hash,
shards map[uint64]lntypes.Hash) ShardTracker {
if shards == nil {
shards = make(map[uint64]lntypes.Hash)
}
return &SimpleShardTracker{
hash: paymentHash,
shards: shards,
}
}
// NewShard registers a new attempt with the ShardTracker and returns a
// new shard representing this attempt. This attempt's shard should be canceled
// if it ends up not being used by the overall payment, i.e. if the attempt
// fails.
func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
m.Lock()
m.shards[id] = m.hash
m.Unlock()
return &Shard{
hash: m.hash,
}, nil
}
// CancelShard cancels the shard corresponding to the given attempt ID.
func (m *SimpleShardTracker) CancelShard(id uint64) error {
m.Lock()
delete(m.shards, id)
m.Unlock()
return nil
}
// GetHash retrieves the hash used by the shard of the given attempt ID. This
// will return an error if the attempt ID is unknown.
func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) {
m.Lock()
hash, ok := m.shards[id]
m.Unlock()
if !ok {
return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+
"not found", id)
}
return hash, nil
}
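// Illustrative sketch (not part of the original file): a tracker for a
// simple (non-AMP) payment maps every attempt to the same hash.
func exampleSimpleTracker(payHash lntypes.Hash) (lntypes.Hash, error) {
	tracker := NewSimpleShardTracker(payHash, nil)

	shard, err := tracker.NewShard(1, true)
	if err != nil {
		return lntypes.Hash{}, err
	}

	// Simple shards carry no extra MPP or AMP records.
	_ = shard.MPP() // nil
	_ = shard.AMP() // nil

	// GetHash returns the same hash that was registered for the
	// attempt.
	return tracker.GetHash(1)
}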
package routing
import (
"math"
"github.com/btcsuite/btcd/btcutil"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// nodeEdgeUnifier holds all edge unifiers for connections towards a node.
type nodeEdgeUnifier struct {
// edgeUnifiers contains an edge unifier for every from node.
edgeUnifiers map[route.Vertex]*edgeUnifier
// sourceNode is the sender of a payment. The rules to pick the final
// policy are different for local channels.
sourceNode route.Vertex
// toNode is the node for which the edge unifiers are instantiated.
toNode route.Vertex
// useInboundFees indicates whether to take inbound fees into account.
useInboundFees bool
// outChanRestr is an optional outgoing channel restriction for the
// local channel to use.
outChanRestr map[uint64]struct{}
}
// newNodeEdgeUnifier instantiates a new nodeEdgeUnifier object. Channel
// policies can be added to this object.
func newNodeEdgeUnifier(sourceNode, toNode route.Vertex, useInboundFees bool,
outChanRestr map[uint64]struct{}) *nodeEdgeUnifier {
return &nodeEdgeUnifier{
edgeUnifiers: make(map[route.Vertex]*edgeUnifier),
toNode: toNode,
useInboundFees: useInboundFees,
sourceNode: sourceNode,
outChanRestr: outChanRestr,
}
}
// addPolicy adds a single channel policy. Capacity may be zero if unknown
// (light clients). We expect a non-nil payload size function and will request a
// graceful shutdown if it is not provided as this indicates that edges are
// incorrectly specified.
func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex,
edge *models.CachedEdgePolicy, inboundFee models.InboundFee,
capacity btcutil.Amount, hopPayloadSizeFn PayloadSizeFunc,
blindedPayment *BlindedPayment) {
localChan := fromNode == u.sourceNode
// Skip channels if there is an outgoing channel restriction.
if localChan && u.outChanRestr != nil {
if _, ok := u.outChanRestr[edge.ChannelID]; !ok {
log.Debugf("Skipped adding policy for restricted edge "+
"%v", edge.ChannelID)
return
}
}
// Update the edgeUnifiers map.
unifier, ok := u.edgeUnifiers[fromNode]
if !ok {
unifier = &edgeUnifier{
localChan: localChan,
}
u.edgeUnifiers[fromNode] = unifier
}
// In case no payload size function was provided, a graceful shutdown
// is requested because this function is not being used as intended.
if hopPayloadSizeFn == nil {
log.Criticalf("No payloadsize function was provided for the "+
"edge (chanid=%v) when adding it to the edge unifier "+
"of node: %v", edge.ChannelID, fromNode)
return
}
// Zero inbound fee for exit hops.
if !u.useInboundFees {
inboundFee = models.InboundFee{}
}
unifier.edges = append(unifier.edges, newUnifiedEdge(
edge, capacity, inboundFee, hopPayloadSizeFn, blindedPayment,
))
}
// addGraphPolicies adds all policies that are known for the toNode in the
// graph.
func (u *nodeEdgeUnifier) addGraphPolicies(g Graph) error {
cb := func(channel *graphdb.DirectedChannel) error {
// If there is no edge policy for this candidate node, skip.
// Note that we are searching backwards so this node would have
// come prior to the pivot node in the route.
if channel.InPolicy == nil {
log.Debugf("Skipped adding edge %v due to nil policy",
channel.ChannelID)
return nil
}
// Add this policy to the corresponding edgeUnifier. We default
// to the clear hop payload size function because
// `addGraphPolicies` is only used for cleartext intermediate
// hops in a route.
inboundFee := models.NewInboundFeeFromWire(
channel.InboundFee,
)
u.addPolicy(
channel.OtherNode, channel.InPolicy, inboundFee,
channel.Capacity, defaultHopPayloadSize, nil,
)
return nil
}
// Iterate over all channels of the to node.
return g.ForEachNodeDirectedChannel(u.toNode, cb)
}
// unifiedEdge is the individual channel data that is kept inside an edgeUnifier
// object.
type unifiedEdge struct {
policy *models.CachedEdgePolicy
capacity btcutil.Amount
inboundFees models.InboundFee
// hopPayloadSizeFn supplies an edge with the ability to calculate the
// exact payload size if this edge would be included in a route. This
// is needed because hops of a blinded path differ in their payload
// structure compared to cleartext hops.
hopPayloadSizeFn PayloadSizeFunc
// blindedPayment if set, is the BlindedPayment that this edge was
// derived from originally.
blindedPayment *BlindedPayment
}
// newUnifiedEdge constructs a new unifiedEdge.
func newUnifiedEdge(policy *models.CachedEdgePolicy, capacity btcutil.Amount,
inboundFees models.InboundFee, hopPayloadSizeFn PayloadSizeFunc,
blindedPayment *BlindedPayment) *unifiedEdge {
return &unifiedEdge{
policy: policy,
capacity: capacity,
inboundFees: inboundFees,
hopPayloadSizeFn: hopPayloadSizeFn,
blindedPayment: blindedPayment,
}
}
// amtInRange checks whether an amount falls within the valid range for a
// channel.
func (u *unifiedEdge) amtInRange(amt lnwire.MilliSatoshi) bool {
// If the capacity is available (non-light clients), skip channels that
// are too small.
if u.capacity > 0 &&
amt > lnwire.NewMSatFromSatoshis(u.capacity) {
log.Tracef("Not enough capacity: amt=%v, capacity=%v",
amt, u.capacity)
return false
}
// Skip channels for which this htlc is too large.
if u.policy.MessageFlags.HasMaxHtlc() &&
amt > u.policy.MaxHTLC {
log.Tracef("Exceeds policy's MaxHTLC: amt=%v, MaxHTLC=%v",
amt, u.policy.MaxHTLC)
return false
}
// Skip channels for which this htlc is too small.
if amt < u.policy.MinHTLC {
log.Tracef("below policy's MinHTLC: amt=%v, MinHTLC=%v",
amt, u.policy.MinHTLC)
return false
}
return true
}
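// Worked example (illustrative): for a channel with a capacity of
// 1_000_000 sat (1_000_000_000 msat), a MaxHTLC of 900_000_000 msat and
// a MinHTLC of 1_000 msat, an amount of 950_000_000 msat is rejected by
// the MaxHTLC check even though it fits within the raw capacity.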
// edgeUnifier is an object that covers all channels between a pair of nodes.
type edgeUnifier struct {
edges []*unifiedEdge
localChan bool
}
// getEdge returns the optimal unified edge to use for this connection given a
// specific amount to send. It differentiates between local and network
// channels.
func (u *edgeUnifier) getEdge(netAmtReceived lnwire.MilliSatoshi,
bandwidthHints bandwidthHints,
nextOutFee lnwire.MilliSatoshi) *unifiedEdge {
if u.localChan {
return u.getEdgeLocal(
netAmtReceived, bandwidthHints, nextOutFee,
)
}
return u.getEdgeNetwork(netAmtReceived, nextOutFee)
}
// calcCappedInboundFee calculates the inbound fee for a channel, taking into
// account the total node fee for the "to" node.
func calcCappedInboundFee(edge *unifiedEdge, amt lnwire.MilliSatoshi,
nextOutFee lnwire.MilliSatoshi) int64 {
// Calculate the inbound fee charged for the amount that passes over the
// channel.
inboundFee := edge.inboundFees.CalcFee(amt)
// Take into account that the total node fee cannot be negative.
if inboundFee < -int64(nextOutFee) {
inboundFee = -int64(nextOutFee)
}
return inboundFee
}
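// Worked example (illustrative): if the inbound fee computed for the
// amount is a discount of -500 msat while the next outbound fee is only
// 300 msat, the discount is capped at -300 msat so that the total fee
// for the node never becomes negative.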
// getEdgeLocal returns the optimal unified edge to use for this local
// connection given a specific amount to send.
func (u *edgeUnifier) getEdgeLocal(netAmtReceived lnwire.MilliSatoshi,
bandwidthHints bandwidthHints,
nextOutFee lnwire.MilliSatoshi) *unifiedEdge {
var (
bestEdge *unifiedEdge
maxBandwidth lnwire.MilliSatoshi
)
for _, edge := range u.edges {
// Calculate the inbound fee charged at the receiving node.
inboundFee := calcCappedInboundFee(
edge, netAmtReceived, nextOutFee,
)
// Add inbound fee to get to the amount that is sent over the
// local channel.
amt := netAmtReceived + lnwire.MilliSatoshi(inboundFee)
// Check valid amount range for the channel. We skip this test
// for payments with custom HTLC data, as the amount sent on
// the BTC layer may differ from the amount that is actually
// forwarded in custom channels.
if bandwidthHints.firstHopCustomBlob().IsNone() &&
!edge.amtInRange(amt) {
log.Debugf("Amount %v not in range for edge %v",
netAmtReceived, edge.policy.ChannelID)
continue
}
// For local channels, there is no fee to pay or an extra time
// lock. We only consider the currently available bandwidth for
// channel selection. The disabled flag is ignored for local
// channels.
// Retrieve bandwidth for this local channel. If not
// available, assume this channel has enough bandwidth.
//
// TODO(joostjager): Possibly change to skipping this
// channel. The bandwidth hint is expected to be
// available.
bandwidth, ok := bandwidthHints.availableChanBandwidth(
edge.policy.ChannelID, amt,
)
if !ok {
log.Debugf("Cannot get bandwidth for edge %v, use max "+
"instead", edge.policy.ChannelID)
bandwidth = lnwire.MaxMilliSatoshi
}
// TODO(yy): if the above `!ok` branch is taken, `bandwidth` is set to
// the max value, which ends up making `maxBandwidth` the largest value,
// and this edge will be the chosen one. This is wrong in two ways,
// 1. we need to understand why `availableChanBandwidth` cannot
// find bandwidth for this edge as something is wrong with this
// channel, and,
// 2. this edge is likely NOT the local channel with the
// highest available bandwidth.
//
// Skip channels that can't carry the payment.
if amt > bandwidth {
log.Debugf("Skipped edge %v: not enough bandwidth, "+
"bandwidth=%v, amt=%v", edge.policy.ChannelID,
bandwidth, amt)
continue
}
// We pick the local channel with the highest available
// bandwidth, to maximize the success probability. It can be
// that the channel state changes between querying the bandwidth
// hints and sending out the htlc.
if bandwidth < maxBandwidth {
log.Debugf("Skipped edge %v: not max bandwidth, "+
"bandwidth=%v, maxBandwidth=%v",
edge.policy.ChannelID, bandwidth, maxBandwidth)
continue
}
maxBandwidth = bandwidth
// Update best edge.
bestEdge = newUnifiedEdge(
edge.policy, edge.capacity, edge.inboundFees,
edge.hopPayloadSizeFn, edge.blindedPayment,
)
}
return bestEdge
}
// getEdgeNetwork returns the optimal unified edge to use for this connection
// given a specific amount to send. The goal is to return a unified edge with a
// policy that maximizes the probability of a successful forward in a non-strict
// forwarding context.
func (u *edgeUnifier) getEdgeNetwork(netAmtReceived lnwire.MilliSatoshi,
nextOutFee lnwire.MilliSatoshi) *unifiedEdge {
var (
bestPolicy *unifiedEdge
maxFee int64 = math.MinInt64
maxTimelock uint16
maxCapMsat lnwire.MilliSatoshi
hopPayloadSizeFn PayloadSizeFunc
)
for _, edge := range u.edges {
// Calculate the inbound fee charged at the receiving node.
inboundFee := calcCappedInboundFee(
edge, netAmtReceived, nextOutFee,
)
// Add inbound fee to get to the amount that is sent over the
// channel.
amt := netAmtReceived + lnwire.MilliSatoshi(inboundFee)
// Check valid amount range for the channel.
if !edge.amtInRange(amt) {
log.Debugf("Amount %v not in range for edge %v",
amt, edge.policy.ChannelID)
continue
}
// For network channels, skip the disabled ones.
edgeFlags := edge.policy.ChannelFlags
isDisabled := edgeFlags&lnwire.ChanUpdateDisabled != 0
if isDisabled {
log.Debugf("Skipped edge %v due to it being disabled",
edge.policy.ChannelID)
continue
}
// Track the maximal capacity for usable channels. If we don't
// know the capacity, we fall back to MaxHTLC.
capMsat := lnwire.NewMSatFromSatoshis(edge.capacity)
if capMsat == 0 && edge.policy.MessageFlags.HasMaxHtlc() {
log.Tracef("No capacity available for channel %v, "+
"using MaxHtlcMsat (%v) as a fallback.",
edge.policy.ChannelID, edge.policy.MaxHTLC)
capMsat = edge.policy.MaxHTLC
}
maxCapMsat = max(capMsat, maxCapMsat)
// Track the maximum time lock of all channels that are
// candidates for non-strict forwarding at the routing node.
maxTimelock = max(maxTimelock, edge.policy.TimeLockDelta)
outboundFee := int64(edge.policy.ComputeFee(amt))
fee := outboundFee + inboundFee
// Use the policy that results in the highest fee for this
// specific amount.
if fee < maxFee {
log.Debugf("Skipped edge %v due to it produces less "+
"fee: fee=%v, maxFee=%v",
edge.policy.ChannelID, fee, maxFee)
continue
}
maxFee = fee
bestPolicy = newUnifiedEdge(
edge.policy, 0, edge.inboundFees, nil,
edge.blindedPayment,
)
// The payload size function for edges to a connected peer is
// always the same, hence there is no need to find the maximum.
// This also holds for blinded edges, where we only have one
// edge to a blinded peer.
hopPayloadSizeFn = edge.hopPayloadSizeFn
}
// Return early if no channel matches.
if bestPolicy == nil {
return nil
}
// We have already picked the highest fee that could be required for
// non-strict forwarding. To also cover the case where a lower fee
// channel requires a longer time lock, we modify the policy by setting
// the maximum encountered time lock. Note that this results in a
// synthetic policy that is not actually present on the routing node.
//
// The reason we do this, is that we try to maximize the chance that we
// get forwarded. Because we penalize pair-wise, there won't be a second
// chance for this node pair. But this is all only needed for nodes that
// have distinct policies for channels to the same peer.
policyCopy := *bestPolicy.policy
policyCopy.TimeLockDelta = maxTimelock
modifiedEdge := newUnifiedEdge(
&policyCopy, maxCapMsat.ToSatoshis(), bestPolicy.inboundFees,
hopPayloadSizeFn, bestPolicy.blindedPayment,
)
return modifiedEdge
}
// minAmt returns the minimum amount that can be forwarded on this connection.
func (u *edgeUnifier) minAmt() lnwire.MilliSatoshi {
minAmount := lnwire.MaxMilliSatoshi
for _, edge := range u.edges {
minAmount = min(minAmount, edge.policy.MinHTLC)
}
return minAmount
}
package shachain
import (
"crypto/sha256"
"errors"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// element pairs a hash with the index corresponding to it. An element is
// the output of the shachain PRF. By comparing two indexes we're able to
// mutate the hash in such a way as to derive another element.
type element struct {
index index
hash chainhash.Hash
}
// newElementFromStr creates a new element from the given hash string.
func newElementFromStr(s string, index index) (*element, error) {
hash, err := hashFromString(s)
if err != nil {
return nil, err
}
return &element{
index: index,
hash: *hash,
}, nil
}
// derive computes one shachain element from another by applying a series of
// bit flips and hashing operations based on the starting and ending index.
func (e *element) derive(toIndex index) (*element, error) {
fromIndex := e.index
positions, err := fromIndex.deriveBitTransformations(toIndex)
if err != nil {
return nil, err
}
buf := e.hash.CloneBytes()
for _, position := range positions {
// Flip the bit and then hash the current state.
byteNumber := position / 8
bitNumber := position % 8
buf[byteNumber] ^= (1 << bitNumber)
h := sha256.Sum256(buf)
buf = h[:]
}
hash, err := chainhash.NewHash(buf)
if err != nil {
return nil, err
}
return &element{
index: toIndex,
hash: *hash,
}, nil
}
// isEqual returns true if two elements are identical and false otherwise.
func (e *element) isEqual(e2 *element) bool {
return (e.index == e2.index) &&
(&e.hash).IsEqual(&e2.hash)
}
const (
// maxHeight is used to determine the maximum allowable index and the
// length of the array required in order to derive all previous hashes
// by index. The entries of this array are also known as buckets.
maxHeight uint8 = 48
// rootIndex is an index which corresponds to the root hash.
rootIndex index = 0
)
// startIndex is the index of the first element in the shachain PRF.
var startIndex index = (1 << maxHeight) - 1
// index is a number which identifies the hash number and serves as a way to
// determine the hashing operation required to derive one hash from another.
// index is initialized with the startIndex and decreases down to zero with
// successive derivations.
type index uint64
// newIndex creates an index instance. Internally, indexes decrease from
// some maximum number down to zero, but for simplicity and backward
// compatibility with previous logic the constructor works in the opposite
// direction, taking a number that counts up from zero.
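//
// For example, with maxHeight = 48, startIndex is 2^48 - 1 =
// 281474976710655, so newIndex(0) == 281474976710655 and
// newIndex(1) == 281474976710654.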
func newIndex(v uint64) index {
return startIndex - index(v)
}
// deriveBitTransformations checks that the 'to' index is derivable from the
// 'from' index by checking that one index is a prefix of the other. The
// bit positions where the zeroes should be changed to ones in order for the
// indexes to become the same are returned. This set of bits is needed in
// order to derive one hash from another.
//
// NOTE: The index 'to' is derivable from index 'from' iff index 'from' lies
// left and above index 'to' on graph below, for example:
// 1. 7(0b111) -> 7
// 2. 6(0b110) -> 6,7
// 3. 5(0b101) -> 5
// 4. 4(0b100) -> 4,5,6,7
// 5. 3(0b011) -> 3
// 6. 2(0b010) -> 2, 3
// 7. 1(0b001) -> 1
//
// ^ bucket number
// |
// 3 | x
// | |
// 2 | | x
// | | |
// 1 | | x | x
// | | | | |
// 0 | | x | x | x | x
// | | | | | | | | |
// +---|---|---|---|---|---|---|---|---> index
// 0 1 2 3 4 5 6 7
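//
// Worked example: to derive index 5 (0b101) from index 4 (0b100),
// countTrailingZeros(4) = 2 and getPrefix(5, 2) = 0b100 = 4, so the
// prefixes match and 5 is derivable from 4. Of the remaining bit
// positions 1 and 0 of 'to', only bit 0 is set, so the returned
// positions are [0]: flip bit 0 of element 4's hash and apply SHA-256
// once to obtain element 5's hash.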
func (from index) deriveBitTransformations(to index) ([]uint8, error) {
var positions []uint8
if from == to {
return positions, nil
}
// + --------------- +
// | № | from | to |
// + -- + ---- + --- +
// | 48 | 1 | 1 |
// | 47 | 0 | 0 | [48-5] - same part of 'from' and 'to'
// | 46 | 0 | 0 | indexes which also is called prefix.
// ....
// | 5 | 1 | 1 |
// | 4 | 0 | 1 | <--- position after which indexes becomes
// | 3 | 0 | 0 | different, after this position
// | 2 | 0 | 1 | bits in 'from' index all should be
// | 1 | 0 | 0 | zeros or such indexes considered to be
// | 0 | 0 | 1 | not derivable.
// + -- + ---- + --- +
zeros := countTrailingZeros(from)
if uint64(from) != getPrefix(to, zeros) {
return nil, errors.New("prefixes are different - indexes " +
"aren't derivable")
}
// The remaining part of the 'to' index represents the positions which we
// will then use to derive one element from another.
for position := zeros - 1; ; position-- {
if getBit(to, position) == 1 {
positions = append(positions, position)
}
if position == 0 {
break
}
}
return positions, nil
}
package shachain
import (
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// Producer is an interface which serves as an abstraction over the data
// structure responsible for efficiently generating the secret for a
// particular index based on a root seed. Secrets should be generated in such
// a way that a secret store can efficiently store and retrieve them. This is
// typically implemented as a tree-based PRF.
type Producer interface {
// AtIndex produces a secret by evaluating the PRF using the initial
// seed and a particular index.
AtIndex(uint64) (*chainhash.Hash, error)
// Encode writes a binary serialization of the Producer implementation
// to the passed io.Writer.
Encode(io.Writer) error
}
// RevocationProducer is an implementation of Producer interface using the
// shachain PRF construct. Starting with a single 32-byte element generated
// from a CSPRNG, shachain is able to efficiently generate a nearly unbounded
// number of secrets while maintaining a constant amount of storage. The
// original description of shachain can be found here:
// https://github.com/rustyrussell/ccan/blob/master/ccan/crypto/shachain/design.txt
// with supplementary material here:
// https://github.com/lightningnetwork/lightning-rfc/blob/master/03-transactions.md#per-commitment-secret-requirements
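//
// A minimal usage sketch (rootHash is a hypothetical CSPRNG-generated
// 32-byte chainhash.Hash):
//
//	producer := NewRevocationProducer(rootHash)
//	secret, err := producer.AtIndex(0)
//	if err != nil {
//		// handle the error
//	}
//	_ = secret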
type RevocationProducer struct {
// root is the element from which we may generate all hashes which
// correspond to the index domain [281474976710655,0].
root *element
}
// A compile time check to ensure RevocationProducer implements the Producer
// interface.
var _ Producer = (*RevocationProducer)(nil)
// NewRevocationProducer creates a new instance of the shachain producer.
func NewRevocationProducer(root chainhash.Hash) *RevocationProducer {
return &RevocationProducer{
root: &element{
index: rootIndex,
hash: root,
}}
}
// NewRevocationProducerFromBytes deserializes an instance of a
// RevocationProducer encoded in the passed byte slice, returning a fully
// initialized instance of a RevocationProducer.
func NewRevocationProducerFromBytes(data []byte) (*RevocationProducer, error) {
root, err := chainhash.NewHash(data)
if err != nil {
return nil, err
}
return &RevocationProducer{
root: &element{
index: rootIndex,
hash: *root,
},
}, nil
}
// AtIndex produces a secret by evaluating the PRF using the initial
// seed and a particular index.
//
// NOTE: Part of the Producer interface.
func (p *RevocationProducer) AtIndex(v uint64) (*chainhash.Hash, error) {
ind := newIndex(v)
element, err := p.root.derive(ind)
if err != nil {
return nil, err
}
return &element.hash, nil
}
// Encode writes a binary serialization of the Producer implementation to the
// passed io.Writer.
//
// NOTE: Part of the Producer interface.
func (p *RevocationProducer) Encode(w io.Writer) error {
_, err := w.Write(p.root.hash[:])
return err
}
package shachain
import (
"encoding/binary"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/go-errors/errors"
)
// Store is an interface which serves as an abstraction over the data
// structure responsible for efficiently storing and restoring hash secrets
// by their indexes.
//
// Description: The Lightning Network wants a chain of (say 1 million)
// unguessable 256 bit values; we generate them and send them one at a time to
// a remote node. We don't want the remote node to have to store all the
// values, so it's better if they can derive them once they see them.
type Store interface {
// LookUp function is used to restore/lookup/fetch the previous secret
// by its index.
LookUp(uint64) (*chainhash.Hash, error)
// AddNextEntry attempts to store the given hash within its internal
// storage in an efficient manner.
//
// NOTE: The hashes derived from the shachain MUST be inserted in the
// order they're produced by a shachain.Producer.
AddNextEntry(*chainhash.Hash) error
// Encode writes a binary serialization of the shachain elements
// currently saved by implementation of shachain.Store to the passed
// io.Writer.
Encode(io.Writer) error
}
// RevocationStore is a concrete implementation of the Store interface. The
// revocation store is able to store N derived shachain elements in a
// space-efficient manner, with a space complexity of O(log N). The original
// description of the storage methodology can be found here:
// https://github.com/lightningnetwork/lightning-rfc/blob/master/03-transactions.md#efficient-per-commitment-secret-storage
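//
// A minimal sketch pairing a store with a producer (rootHash is
// hypothetical). Hashes MUST be added in production order, after which
// any previously added index can be restored from O(log N) state:
//
//	producer := NewRevocationProducer(rootHash)
//	store := NewRevocationStore()
//	for i := uint64(0); i < 10; i++ {
//		secret, err := producer.AtIndex(i)
//		if err != nil {
//			// handle the error
//		}
//		if err := store.AddNextEntry(secret); err != nil {
//			// hash isn't derivable from the previous ones
//		}
//	}
//	restored, err := store.LookUp(3) // any index in [0, 9]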
type RevocationStore struct {
// lenBuckets stores the number of currently active buckets.
lenBuckets uint8
// buckets is an array of elements from which we may derive all
// previous elements, each bucket corresponds to the element with the
// particular number of trailing zeros.
buckets [maxHeight]element
// index is an available index which will be assigned to the new
// element.
index index
}
// A compile time check to ensure RevocationStore implements the Store
// interface.
var _ Store = (*RevocationStore)(nil)
// NewRevocationStore creates a new shachain store.
func NewRevocationStore() *RevocationStore {
return &RevocationStore{
lenBuckets: 0,
index: startIndex,
}
}
// NewRevocationStoreFromBytes recreates the initial store state from the given
// binary shachain store representation.
func NewRevocationStoreFromBytes(r io.Reader) (*RevocationStore, error) {
store := &RevocationStore{}
if err := binary.Read(r, binary.BigEndian, &store.lenBuckets); err != nil {
return nil, err
}
for i := uint8(0); i < store.lenBuckets; i++ {
var hashIndex index
err := binary.Read(r, binary.BigEndian, &hashIndex)
if err != nil {
return nil, err
}
var nextHash chainhash.Hash
if _, err := io.ReadFull(r, nextHash[:]); err != nil {
return nil, err
}
store.buckets[i] = element{
index: hashIndex,
hash: nextHash,
}
}
if err := binary.Read(r, binary.BigEndian, &store.index); err != nil {
return nil, err
}
return store, nil
}
// LookUp is used to restore/lookup/fetch the previous secret by its index.
// If the secret which corresponds to the given index was not previously
// placed in the store, we will not be able to derive it and the function
// will fail.
//
// NOTE: This function is part of the Store interface.
func (store *RevocationStore) LookUp(v uint64) (*chainhash.Hash, error) {
ind := newIndex(v)
// Try to derive the index from one of the existing bucket elements.
for i := uint8(0); i < store.lenBuckets; i++ {
element, err := store.buckets[i].derive(ind)
if err != nil {
continue
}
return &element.hash, nil
}
return nil, errors.Errorf("unable to derive hash #%v", ind)
}
// AddNextEntry attempts to store the given hash within its internal storage in
// an efficient manner.
//
// NOTE: The hashes derived from the shachain MUST be inserted in the order
// they're produced by a shachain.Producer.
//
// NOTE: This function is part of the Store interface.
func (store *RevocationStore) AddNextEntry(hash *chainhash.Hash) error {
newElement := &element{
index: store.index,
hash: *hash,
}
bucket := countTrailingZeros(newElement.index)
for i := uint8(0); i < bucket; i++ {
e, err := newElement.derive(store.buckets[i].index)
if err != nil {
return err
}
if !e.isEqual(&store.buckets[i]) {
return errors.New("hash isn't derivable from " +
"previous ones")
}
}
store.buckets[bucket] = *newElement
if bucket+1 > store.lenBuckets {
store.lenBuckets = bucket + 1
}
store.index--
return nil
}
// Encode writes a binary serialization of the shachain elements currently
// saved by implementation of shachain.Store to the passed io.Writer.
//
// NOTE: This function is part of the Store interface.
func (store *RevocationStore) Encode(w io.Writer) error {
err := binary.Write(w, binary.BigEndian, store.lenBuckets)
if err != nil {
return err
}
for i := uint8(0); i < store.lenBuckets; i++ {
element := store.buckets[i]
err := binary.Write(w, binary.BigEndian, element.index)
if err != nil {
return err
}
if _, err = w.Write(element.hash[:]); err != nil {
return err
}
}
return binary.Write(w, binary.BigEndian, store.index)
}
package shachain
import (
"encoding/hex"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// getBit returns the bit of the given index at the given position.
func getBit(index index, position uint8) uint8 {
return uint8((uint64(index) >> position) & 1)
}
// getPrefix returns the given index with all bits below the given position
// set to zero.
func getPrefix(index index, position uint8) uint64 {
// + -------------------------- +
// | № | value | mask | return |
// + -- + ----- + ---- + ------ +
// | 63 | 1 | 0 | 0 |
// | 62 | 0 | 0 | 0 |
// | 61 | 1 | 0 | 0 |
// ....
// | 4 | 1 | 0 | 0 |
// | 3 | 1 | 0 | 0 |
// | 2 | 1 | 1 | 1 | <--- position
// | 1 | 0 | 1 | 0 |
// | 0 | 1 | 1 | 1 |
// + -- + ----- + ---- + ------ +
var zero uint64
mask := (zero - 1) - uint64((1<<position)-1)
return (uint64(index) & mask)
}
// countTrailingZeros counts the number of trailing zero bits; this function
// is used to determine the bucket number of an element.
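// For example, countTrailingZeros(0b1011000) == 3, which places the
// corresponding element in bucket 3.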
func countTrailingZeros(index index) uint8 {
var zeros uint8
for ; zeros < maxHeight; zeros++ {
if getBit(index, zeros) != 0 {
break
}
}
return zeros
}
// hashFromString takes a hex-encoded string as input and creates an
// instance of chainhash.Hash. The chainhash.NewHashFromStr function is not
// suitable because it reverses the given hash.
func hashFromString(s string) (*chainhash.Hash, error) {
// Return an error if hash string is too long.
if len(s) > chainhash.MaxHashStringSize {
return nil, chainhash.ErrHashStrSize
}
// Hex decoder expects the hash to be a multiple of two.
if len(s)%2 != 0 {
s = "0" + s
}
// Convert string hash to bytes.
buf, err := hex.DecodeString(s)
if err != nil {
return nil, err
}
hash, err := chainhash.NewHash(buf)
if err != nil {
return nil, err
}
return hash, nil
}
package subscribe
import (
"errors"
"sync"
"sync/atomic"
"github.com/lightningnetwork/lnd/queue"
)
// ErrServerShuttingDown is an error returned in case the server is in the
// process of shutting down.
var ErrServerShuttingDown = errors.New("subscription server shutting down")
// Client is used to get notified about updates the caller has subscribed to.
type Client struct {
// cancel should be called in case the client no longer wants to
// subscribe for updates from the server.
cancel func()
updates *queue.ConcurrentQueue
quit chan struct{}
}
// Updates returns a read-only channel where the updates the client has
// subscribed to will be delivered.
func (c *Client) Updates() <-chan interface{} {
return c.updates.ChanOut()
}
// Quit is a channel that will be closed in case the server decides to no
// longer deliver updates to this client.
func (c *Client) Quit() <-chan struct{} {
return c.quit
}
// Cancel should be called in case the client no longer wants to
// subscribe for updates from the server.
func (c *Client) Cancel() {
c.cancel()
}
// Server is a struct that manages a set of subscriptions and their
// corresponding clients. Any update will be delivered to all active clients.
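//
// A minimal usage sketch (the update payload is hypothetical):
//
//	server := NewServer()
//	if err := server.Start(); err != nil {
//		// handle the error
//	}
//	defer func() { _ = server.Stop() }()
//
//	client, err := server.Subscribe()
//	if err != nil {
//		return // server is shutting down
//	}
//	defer client.Cancel()
//
//	go func() { _ = server.SendUpdate("some event") }()
//
//	select {
//	case update := <-client.Updates():
//		_ = update
//	case <-client.Quit():
//		// the server stopped delivering updates
//	}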
type Server struct {
clientCounter uint64 // To be used atomically.
started uint32 // To be used atomically.
stopped uint32 // To be used atomically.
clients map[uint64]*Client
clientUpdates chan *clientUpdate
updates chan interface{}
quit chan struct{}
wg sync.WaitGroup
}
// clientUpdate is an internal message sent to the subscriptionHandler to
// either register a new client for subscription or cancel an existing
// subscription.
type clientUpdate struct {
// cancel indicates if the update to the client is cancelling an
// existing client's subscription. If not, then this update will
// subscribe a new client.
cancel bool
// clientID is the unique identifier for this client. Any further
// updates (deleting or adding) to this notification client will be
// dispatched according to the target clientID.
clientID uint64
// client is the new client that will receive updates. Will be nil in
// case this is a cancellation update.
client *Client
}
// NewServer returns a new Server.
func NewServer() *Server {
return &Server{
clients: make(map[uint64]*Client),
clientUpdates: make(chan *clientUpdate),
updates: make(chan interface{}),
quit: make(chan struct{}),
}
}
// Start starts the Server, making it ready to accept subscriptions and
// updates.
func (s *Server) Start() error {
if !atomic.CompareAndSwapUint32(&s.started, 0, 1) {
return nil
}
s.wg.Add(1)
go s.subscriptionHandler()
return nil
}
// Stop stops the server.
func (s *Server) Stop() error {
if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) {
return nil
}
close(s.quit)
s.wg.Wait()
return nil
}
// Subscribe returns a Client that will receive updates any time the Server is
// made aware of a new event.
func (s *Server) Subscribe() (*Client, error) {
// We'll first atomically obtain the next ID for this client from the
// incrementing client ID counter.
clientID := atomic.AddUint64(&s.clientCounter, 1)
// Create the client that will be returned. The Cancel method is
// populated to send the cancellation intent to the
// subscriptionHandler.
client := &Client{
updates: queue.NewConcurrentQueue(20),
quit: make(chan struct{}),
cancel: func() {
select {
case s.clientUpdates <- &clientUpdate{
cancel: true,
clientID: clientID,
}:
case <-s.quit:
return
}
},
}
select {
case s.clientUpdates <- &clientUpdate{
cancel: false,
clientID: clientID,
client: client,
}:
case <-s.quit:
return nil, ErrServerShuttingDown
}
return client, nil
}
// SendUpdate is called to send the passed update to all currently active
// subscription clients.
func (s *Server) SendUpdate(update interface{}) error {
select {
case s.updates <- update:
return nil
case <-s.quit:
return ErrServerShuttingDown
}
}
// subscriptionHandler is the main handler for the Server. It will handle
// incoming updates and subscriptions, and forward the incoming updates to the
// registered clients.
//
// NOTE: MUST be run as a goroutine.
func (s *Server) subscriptionHandler() {
defer s.wg.Done()
for {
select {
// If a client update is received, then either a new
// subscription becomes active, or we cancel an existing one.
case update := <-s.clientUpdates:
clientID := update.clientID
// In case this is a cancellation, stop the client's
// underlying queue, and remove the client from the set
// of active subscription clients.
if update.cancel {
client, ok := s.clients[update.clientID]
if ok {
client.updates.Stop()
close(client.quit)
delete(s.clients, clientID)
}
continue
}
// If this was not a cancellation, start the underlying
// queue and add the client to our set of subscription
// clients. It will be notified about any new updates
// the server receives.
update.client.updates.Start()
s.clients[update.clientID] = update.client
// A new update was received, forward it to all active clients.
case upd := <-s.updates:
for _, client := range s.clients {
select {
case client.updates.ChanIn() <- upd:
case <-client.quit:
case <-s.quit:
return
}
}
// In case the server is shutting down, stop the clients and
// close the quit channels to notify them.
case <-s.quit:
for _, client := range s.clients {
client.updates.Stop()
close(client.quit)
}
return
}
}
}
package sweep
import (
"sort"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
// UtxoAggregator defines an interface that takes a list of inputs and
// aggregates them into groups. Each group is used as the inputs to create a
// sweeping transaction.
type UtxoAggregator interface {
// ClusterInputs takes a list of inputs and groups them into input
// sets. Each input set will be used to create a sweeping transaction.
ClusterInputs(inputs InputsMap) []InputSet
}
// BudgetAggregator is a budget-based aggregator that creates clusters based on
// deadlines and budgets of inputs.
type BudgetAggregator struct {
// estimator is used when crafting sweep transactions to estimate the
// necessary fee relative to the expected size of the sweep
// transaction.
estimator chainfee.Estimator
// maxInputs specifies the maximum number of inputs allowed in a single
// sweep tx.
maxInputs uint32
// auxSweeper is an optional interface that can be used to modify the
// way sweep transactions are generated.
auxSweeper fn.Option[AuxSweeper]
}
// Compile-time constraint to ensure BudgetAggregator implements UtxoAggregator.
var _ UtxoAggregator = (*BudgetAggregator)(nil)
// NewBudgetAggregator creates a new instance of a BudgetAggregator.
func NewBudgetAggregator(estimator chainfee.Estimator,
maxInputs uint32, auxSweeper fn.Option[AuxSweeper]) *BudgetAggregator {
return &BudgetAggregator{
estimator: estimator,
maxInputs: maxInputs,
auxSweeper: auxSweeper,
}
}
// clusterGroup defines an alias for a set of inputs that are to be grouped.
type clusterGroup map[int32][]SweeperInput
// ClusterInputs creates a list of input sets from pending inputs.
// 1. filter out inputs whose budget cannot cover min relay fee.
// 2. filter a list of exclusive inputs.
// 3. group the inputs into clusters based on their deadline height.
// 4. sort the inputs in each cluster by their budget.
// 5. optionally split a cluster if it exceeds the max input limit.
// 6. create input sets from each of the clusters.
// 7. create input sets for each of the exclusive inputs.
func (b *BudgetAggregator) ClusterInputs(inputs InputsMap) []InputSet {
// Filter out inputs that have a budget below min relay fee.
filteredInputs := b.filterInputs(inputs)
// Create clusters to group inputs based on their deadline height.
clusters := make(clusterGroup, len(filteredInputs))
// exclusiveInputs is a set of inputs that are not to be included in
// any cluster. These inputs can only be swept independently as there's
// no guarantee which input will be confirmed first, which means
// grouping exclusive inputs may jeopardize non-exclusive inputs.
exclusiveInputs := make(map[wire.OutPoint]clusterGroup)
// Iterate all the inputs and group them based on their specified
// deadline heights.
for _, input := range filteredInputs {
// Get deadline height, and use the specified default deadline
// height if it's not set.
height := input.DeadlineHeight
// Put exclusive inputs in their own set.
if input.params.ExclusiveGroup != nil {
log.Tracef("Input %v is exclusive", input.OutPoint())
exclusiveInputs[input.OutPoint()] = clusterGroup{
height: []SweeperInput{*input},
}
continue
}
cluster, ok := clusters[height]
if !ok {
cluster = make([]SweeperInput, 0)
}
cluster = append(cluster, *input)
clusters[height] = cluster
}
// Now that we have the clusters, we can create the input sets.
//
// NOTE: cannot pre-allocate the slice since we don't know the number
// of input sets in advance.
inputSets := make([]InputSet, 0)
for height, cluster := range clusters {
// Sort the inputs by their economical value.
sortedInputs := b.sortInputs(cluster)
// Split on locktimes if they are different.
splitClusters := splitOnLocktime(sortedInputs)
// Create input sets from the cluster.
for _, cluster := range splitClusters {
sets := b.createInputSets(cluster, height)
inputSets = append(inputSets, sets...)
}
}
// Create input sets from the exclusive inputs.
for _, cluster := range exclusiveInputs {
for height, input := range cluster {
sets := b.createInputSets(input, height)
inputSets = append(inputSets, sets...)
}
}
return inputSets
}
// createInputSets takes a set of inputs which share the same deadline
// height and turns them into a list of `InputSet`s; each set is then used to
// create a sweep transaction.
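//
// For example, with maxInputs = 10, a cluster of 25 inputs sharing the
// same deadline height is split into three sets of 10, 10 and 5 inputs,
// each of which becomes its own sweeping transaction.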
//
// TODO(yy): by the time we call this method, all the invalid/uneconomical
// inputs have been filtered out, all the inputs have been sorted based on
// their budgets, and we are about to create input sets. The only thing
// missing is that we need to group the inputs even further based on whether
// their budgets can cover the starting fee rate used for this input set.
func (b *BudgetAggregator) createInputSets(inputs []SweeperInput,
deadlineHeight int32) []InputSet {
// sets holds the InputSets that we will return.
sets := make([]InputSet, 0)
// Copy the inputs to a new slice so we can modify it.
remainingInputs := make([]SweeperInput, len(inputs))
copy(remainingInputs, inputs)
// If the number of inputs is greater than the max inputs allowed, we
// will split them into smaller clusters.
for uint32(len(remainingInputs)) > b.maxInputs {
log.Tracef("Cluster has %v inputs, max is %v, dividing...",
len(inputs), b.maxInputs)
// Copy the inputs to be put into the new set, and update the
// remaining inputs by removing currentInputs.
currentInputs := make([]SweeperInput, b.maxInputs)
copy(currentInputs, remainingInputs[:b.maxInputs])
remainingInputs = remainingInputs[b.maxInputs:]
// Create an InputSet using the max allowed number of inputs.
set, err := NewBudgetInputSet(
currentInputs, deadlineHeight, b.auxSweeper,
)
if err != nil {
log.Errorf("unable to create input set: %v", err)
continue
}
sets = append(sets, set)
}
// Create an InputSet from the remaining inputs.
if len(remainingInputs) > 0 {
set, err := NewBudgetInputSet(
remainingInputs, deadlineHeight, b.auxSweeper,
)
if err != nil {
log.Errorf("unable to create input set: %v", err)
return nil
}
sets = append(sets, set)
}
return sets
}
// filterInputs filters out inputs that have,
// - a budget below the min relay fee.
// - a budget below its requested starting fee.
// - a required output that's below the dust limit.
func (b *BudgetAggregator) filterInputs(inputs InputsMap) InputsMap {
// Get the current min relay fee for this round.
minFeeRate := b.estimator.RelayFeePerKW()
// filteredInputs stores a map of inputs whose budget can at least pay
// the minimal fee.
filteredInputs := make(InputsMap, len(inputs))
// Iterate all the inputs and filter out the ones whose budget cannot
// cover the min fee.
for _, pi := range inputs {
op := pi.OutPoint()
// Get the size of the witness and skip if there's an error.
witnessSize, _, err := pi.WitnessType().SizeUpperBound()
if err != nil {
log.Warnf("Skipped input=%v: cannot get its size: %v",
op, err)
continue
}
//nolint:ll
// Calculate the size if the input is included in the tx.
//
// NOTE: When including this input, we need to account the
// non-witness data which is expressed in vb.
//
// TODO(yy): This is not accurate for tapscript input. We need
// to unify calculations used in the `TxWeightEstimator` inside
// `input/size.go` and `weightEstimator` in
// `weight_estimator.go`. And calculate the expected weights
// similar to BOLT-3:
// https://github.com/lightning/bolts/blob/master/03-transactions.md#appendix-a-expected-weights
wu := lntypes.VByte(input.InputSize).ToWU() + witnessSize
// Skip inputs that have too little budget.
minFee := minFeeRate.FeeForWeight(wu)
if pi.params.Budget < minFee {
log.Warnf("Skipped input=%v: has budget=%v, but the "+
"min fee requires %v (feerate=%v), size=%v", op,
pi.params.Budget, minFee,
minFeeRate.FeePerVByte(), wu.ToVB())
continue
}
// Skip inputs whose budget cannot cover their starting fee.
startingFeeRate := pi.params.StartingFeeRate.UnwrapOr(
chainfee.SatPerKWeight(0),
)
startingFee := startingFeeRate.FeeForWeight(wu)
if pi.params.Budget < startingFee {
log.Errorf("Skipped input=%v: has budget=%v, but the "+
"starting fee requires %v (feerate=%v), "+
"size=%v", op, pi.params.Budget, startingFee,
startingFeeRate.FeePerVByte(), wu.ToVB())
continue
}
// If the input comes with a required tx out that is below
// dust, we won't add it.
//
// NOTE: only HtlcSecondLevelAnchorInput returns non-nil
// RequiredTxOut.
reqOut := pi.RequiredTxOut()
if reqOut != nil {
if isDustOutput(reqOut) {
log.Errorf("Rejected input=%v due to dust "+
"required output=%v", op, reqOut.Value)
continue
}
}
filteredInputs[op] = pi
}
return filteredInputs
}
// sortInputs sorts the inputs based on their economical value.
//
// NOTE: besides the forced inputs, the sorting won't make any difference
// because all the inputs are added to the same set. The exception is when
// the number of inputs exceeds the maxInputs limit, which requires us to
// split them into smaller clusters. In that case, the sorting will make a
// difference as the budgets of the clusters will be different.
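//
// For example (hypothetical params), inputs with (budget, immediate) of
// (100, false), (300, true), (200, false) and (50, true) sort to
// (300, true), (50, true), (200, false), (100, false): forced inputs
// first, then by descending budget within each group.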
func (b *BudgetAggregator) sortInputs(inputs []SweeperInput) []SweeperInput {
// sortedInputs is the final list of inputs sorted by their economical
// value.
sortedInputs := make([]SweeperInput, 0, len(inputs))
// Copy the inputs.
sortedInputs = append(sortedInputs, inputs...)
// Sort the inputs based on their budgets.
//
// NOTE: We could implement a more sophisticated algorithm, as the budget
// left after fees depends on both budget and size: compare
// b1 - s1 * r > b2 - s2 * r, where b1 and b2 are budgets, s1 and s2 are the
// sizes of the inputs, and r is the fee rate.
sort.Slice(sortedInputs, func(i, j int) bool {
left := sortedInputs[i].params.Budget
right := sortedInputs[j].params.Budget
// Make sure forced inputs are always put in the front.
leftForce := sortedInputs[i].params.Immediate
rightForce := sortedInputs[j].params.Immediate
// If both are forced inputs, we return the one with the higher
// budget. If neither are forced inputs, we also return the one
// with the higher budget.
if leftForce == rightForce {
return left > right
}
// Otherwise, it's either the left or the right is forced. We
// can simply return `leftForce` here as, if it's true, the
// left is forced and should be put in the front. Otherwise,
// the right is forced and should be put in the front.
return leftForce
})
return sortedInputs
}
// splitOnLocktime splits the list of inputs based on their locktime.
//
// TODO(yy): this is a temporary hack as the blocks are not synced
// between the contractcourt and the sweeper.
func splitOnLocktime(inputs []SweeperInput) map[uint32][]SweeperInput {
result := make(map[uint32][]SweeperInput)
noLocktimeInputs := make([]SweeperInput, 0, len(inputs))
// mergeLocktime is the locktime that we use to merge all the
// no-locktime inputs into.
var mergeLocktime uint32
// Iterate all inputs and split them based on their locktimes.
for _, inp := range inputs {
locktime, required := inp.RequiredLockTime()
if !required {
log.Tracef("No locktime required for input=%v",
inp.OutPoint())
noLocktimeInputs = append(noLocktimeInputs, inp)
continue
}
log.Tracef("Split input=%v on locktime=%v", inp.OutPoint(),
locktime)
// Get the slice - the slice will be initialized if not found.
inputList := result[locktime]
// Add the input to the list.
inputList = append(inputList, inp)
// Update the map.
result[locktime] = inputList
// Update the merge locktime.
mergeLocktime = locktime
}
// If there are locktime inputs, we will merge the no locktime inputs
// to the last locktime group found.
if len(result) > 0 {
log.Tracef("No locktime inputs has been merged to locktime=%v",
mergeLocktime)
result[mergeLocktime] = append(
result[mergeLocktime], noLocktimeInputs...,
)
} else {
// Otherwise just return the no locktime inputs.
result[mergeLocktime] = noLocktimeInputs
}
return result
}
// isDustOutput checks if the given output is considered dust.
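// For example, under common relay policies the dust limit is roughly
// 294 sats for a P2WPKH output and 330 sats for P2WSH and P2TR outputs,
// so a 200-sat required output would be treated as dust here.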
func isDustOutput(output *wire.TxOut) bool {
// Fetch the dust limit for this output.
dustLimit := lnwallet.DustLimitForSize(len(output.PkScript))
// If the output is below the dust limit, we consider it dust.
return btcutil.Amount(output.Value) < dustLimit
}
package sweep
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/chain"
"github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// ErrInvalidBumpResult is returned when the bump result is invalid.
ErrInvalidBumpResult = errors.New("invalid bump result")
// ErrNotEnoughBudget is returned when the fee bumper decides the
// current budget cannot cover the fee.
ErrNotEnoughBudget = errors.New("not enough budget")
// ErrLocktimeImmature is returned when sweeping an input whose
// locktime is not reached.
ErrLocktimeImmature = errors.New("immature input")
// ErrTxNoOutput is returned when an output cannot be created during tx
// preparation, usually due to the output being dust.
ErrTxNoOutput = errors.New("tx has no output")
// ErrUnknownSpent is returned when an unknown tx has spent an input in
// the sweeping tx.
ErrUnknownSpent = errors.New("unknown spend of input")
// ErrInputMissing is returned when a given input no longer exists,
// e.g., spending from an orphan tx.
ErrInputMissing = errors.New("input no longer exists")
)
var (
// dummyChangePkScript is a dummy tapscript change script that's used
// when we don't need a real address, just something that can be used
// for fee estimation.
dummyChangePkScript = []byte{
0x51, 0x20,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}
)
// Bumper defines an interface that can be used by other subsystems for fee
// bumping.
type Bumper interface {
// Broadcast is used to publish the tx created from the given inputs
// specified in the request. It handles the tx creation, broadcasts it,
// and monitors its confirmation status for potential fee bumping. It
// returns a chan that the caller can use to receive updates about the
// broadcast result and potential RBF attempts.
Broadcast(req *BumpRequest) <-chan *BumpResult
}
// BumpEvent represents the event of a fee bumping attempt.
type BumpEvent uint8
const (
// TxPublished is sent when the broadcast attempt is finished.
TxPublished BumpEvent = iota
// TxFailed is sent when the tx has encountered a fee-related error
// during its creation or broadcast, or an internal error from the fee
// bumper. In either case the inputs in this tx should be retried with
// either a different grouping strategy or an increased budget.
//
// TODO(yy): Remove the above usage once we remove sweeping non-CPFP
// anchors.
TxFailed
// TxReplaced is sent when the original tx is replaced by a new one.
TxReplaced
// TxConfirmed is sent when the tx is confirmed.
TxConfirmed
// TxUnknownSpend is sent when at least one of the inputs is spent but
// not by the current sweeping tx, this can happen when,
// - a remote party has replaced our sweeping tx by spending the
// input(s), e.g., via the direct preimage spend on our outgoing HTLC.
// - a third party has replaced our sweeping tx, e.g., the anchor output
// after 16 blocks.
// - A previous sweeping tx has confirmed but the fee bumper is not
// aware of it, e.g., a restart happens right after the sweeping tx is
// broadcast and confirmed.
TxUnknownSpend
// TxFatal is sent when the inputs in this tx cannot be retried. Txns
// will end up in this state if they have encountered a non-fee related
// error, which means they cannot be retried with increased budget.
TxFatal
// sentinelEvent is used to check if an event is unknown.
sentinelEvent
)
// String returns a human-readable string for the event.
func (e BumpEvent) String() string {
switch e {
case TxPublished:
return "Published"
case TxFailed:
return "Failed"
case TxReplaced:
return "Replaced"
case TxConfirmed:
return "Confirmed"
case TxUnknownSpend:
return "UnknownSpend"
case TxFatal:
return "Fatal"
default:
return "Unknown"
}
}
// Unknown returns true if the event is unknown.
func (e BumpEvent) Unknown() bool {
return e >= sentinelEvent
}
// BumpRequest is used by the caller to give the Bumper the necessary info to
// create and manage potential fee bumps for a set of inputs.
type BumpRequest struct {
// Budget gives the total amount that can be used as fees by these
// inputs.
Budget btcutil.Amount
// Inputs is the set of inputs to sweep.
Inputs []input.Input
// DeadlineHeight is the block height at which the tx should be
// confirmed.
DeadlineHeight int32
// DeliveryAddress is the script to send the change output to.
DeliveryAddress lnwallet.AddrWithKey
// MaxFeeRate is the maximum fee rate that can be used for fee bumping.
MaxFeeRate chainfee.SatPerKWeight
// StartingFeeRate is an optional parameter that can be used to specify
// the initial fee rate to use for the fee function.
StartingFeeRate fn.Option[chainfee.SatPerKWeight]
// ExtraTxOut tracks if this bump request has an optional set of extra
// outputs to add to the transaction.
ExtraTxOut fn.Option[SweepOutput]
// Immediate is used to specify that the tx should be broadcast
// immediately.
Immediate bool
}
// MaxFeeRateAllowed returns the maximum fee rate allowed for the given
// request. It calculates the feerate using the supplied budget and the weight,
// compares it with the specified MaxFeeRate, and returns the smaller of the
// two.
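//
// Worked example (hypothetical numbers): with Budget = 100,000 sats and
// an estimated tx weight of 2,000 WU, the budget implies a fee rate of
// 100,000 sats * 1,000 / 2,000 WU = 50,000 sat/kw. If MaxFeeRate is
// 30,000 sat/kw, the smaller 30,000 sat/kw is returned; otherwise the
// budget-implied 50,000 sat/kw is used.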
func (r *BumpRequest) MaxFeeRateAllowed() (chainfee.SatPerKWeight, error) {
// We'll want to know if we have any blobs, as we need to factor this
// into the max fee rate for this bump request.
hasBlobs := fn.Any(r.Inputs, func(i input.Input) bool {
return fn.MapOptionZ(
i.ResolutionBlob(), func(b tlv.Blob) bool {
return len(b) > 0
},
)
})
sweepAddrs := [][]byte{
r.DeliveryAddress.DeliveryAddress,
}
// If we have blobs, then we'll add an extra sweep addr for the size
// estimate below. We know that these blobs will also always be based on
// p2tr addrs.
if hasBlobs {
// We need to pass in a real address, so we'll use a dummy
// tapscript change script that's used elsewhere for tests.
sweepAddrs = append(sweepAddrs, dummyChangePkScript)
}
// Get the size of the sweep tx, which will be used to calculate the
// budget fee rate.
size, err := calcSweepTxWeight(
r.Inputs, sweepAddrs,
)
if err != nil {
return 0, err
}
// Use the budget and MaxFeeRate to decide the max allowed fee rate.
// This is needed as, when the input has a large value and the user
// sets the budget to be proportional to the input value, the fee rate
// can be very high and we need to make sure it doesn't exceed the max
// fee rate.
maxFeeRateAllowed := chainfee.NewSatPerKWeight(r.Budget, size)
if maxFeeRateAllowed > r.MaxFeeRate {
log.Debugf("Budget feerate %v exceeds MaxFeeRate %v, use "+
"MaxFeeRate instead, txWeight=%v", maxFeeRateAllowed,
r.MaxFeeRate, size)
return r.MaxFeeRate, nil
}
log.Debugf("Budget feerate %v below MaxFeeRate %v, use budget feerate "+
"instead, txWeight=%v", maxFeeRateAllowed, r.MaxFeeRate, size)
return maxFeeRateAllowed, nil
}
// calcSweepTxWeight calculates the weight of the sweep tx. It assumes a
// sweeping tx always has a single output (change).
func calcSweepTxWeight(inputs []input.Input,
outputPkScript [][]byte) (lntypes.WeightUnit, error) {
// Use a const fee rate as we only use the weight estimator to
// calculate the size.
const feeRate = 1
// Initialize the tx weight estimator with,
// - nil outputs as we only have one single change output.
// - const fee rate as we don't care about the fees here.
// - 0 maxfeerate as we don't care about fees here.
//
// TODO(yy): we should refactor the weight estimator to not require a
// fee rate and max fee rate and make it a pure tx weight calculator.
_, estimator, err := getWeightEstimate(
inputs, nil, feeRate, 0, outputPkScript,
)
if err != nil {
return 0, err
}
return estimator.weight(), nil
}
// BumpResult is used by the Bumper to send updates about the tx being
// broadcast.
type BumpResult struct {
// Event is the type of event that the result is for.
Event BumpEvent
// Tx is the tx being broadcast.
Tx *wire.MsgTx
// ReplacedTx is the old, replaced tx if a fee bump is attempted.
ReplacedTx *wire.MsgTx
// FeeRate is the fee rate used for the new tx.
FeeRate chainfee.SatPerKWeight
// Fee is the fee paid by the new tx.
Fee btcutil.Amount
// Err is the error that occurred during the broadcast.
Err error
// SpentInputs are the inputs spent by another tx, which caused the
// current tx to fail.
SpentInputs map[wire.OutPoint]*wire.MsgTx
// requestID is the ID of the request that created this record.
requestID uint64
}
// String returns a human-readable string for the result.
func (b *BumpResult) String() string {
desc := fmt.Sprintf("Event=%v", b.Event)
if b.Tx != nil {
desc += fmt.Sprintf(", Tx=%v", b.Tx.TxHash())
}
return fmt.Sprintf("[%s]", desc)
}
// Validate validates the BumpResult so it's safe to use.
func (b *BumpResult) Validate() error {
isFailureEvent := b.Event == TxFailed || b.Event == TxFatal ||
b.Event == TxUnknownSpend
// Every result must have a tx except for the failure events.
if b.Tx == nil && !isFailureEvent {
return fmt.Errorf("%w: nil tx", ErrInvalidBumpResult)
}
// Every result must have a known event.
if b.Event.Unknown() {
return fmt.Errorf("%w: unknown event", ErrInvalidBumpResult)
}
// If it's a replacing event, it must have a replaced tx.
if b.Event == TxReplaced && b.ReplacedTx == nil {
return fmt.Errorf("%w: nil replacing tx", ErrInvalidBumpResult)
}
// If it's a failed or fatal event, it must have an error.
if isFailureEvent && b.Err == nil {
return fmt.Errorf("%w: nil error", ErrInvalidBumpResult)
}
// If it's a confirmed event, it must have a fee rate and fee.
if b.Event == TxConfirmed && (b.FeeRate == 0 || b.Fee == 0) {
return fmt.Errorf("%w: missing fee rate or fee",
ErrInvalidBumpResult)
}
return nil
}
// TxPublisherConfig is the config used to create a new TxPublisher.
type TxPublisherConfig struct {
// Signer is used to create the tx signature.
Signer input.Signer
// Wallet is used primarily to publish the tx.
Wallet Wallet
// Estimator is used to estimate the fee rate for the new tx based on
// its deadline conf target.
Estimator chainfee.Estimator
// Notifier is used to monitor the confirmation status of the tx.
Notifier chainntnfs.ChainNotifier
// AuxSweeper is an optional interface that can be used to modify the
// way sweep transactions are generated.
AuxSweeper fn.Option[AuxSweeper]
}
// TxPublisher is an implementation of the Bumper interface. It utilizes the
// `testmempoolaccept` RPC to bump the fee of txns it created based on the
// fee function selected or configured by the caller. Its purpose is to
// take a list of inputs specified, and create a tx that spends them to a
// specified output. It will then monitor the confirmation status of the tx,
// and if it's not confirmed within a certain time frame, it will attempt to
// bump the fee of the tx by creating a new tx that spends the same inputs to
// the same output, but with a higher fee rate. It will continue to do this
// until the tx is confirmed or the fee rate reaches the maximum fee rate
// specified by the caller.
type TxPublisher struct {
started atomic.Bool
stopped atomic.Bool
// Embed the blockbeat consumer struct to get access to the method
// `NotifyBlockProcessed` and the `BlockbeatChan`.
chainio.BeatConsumer
wg sync.WaitGroup
// cfg specifies the configuration of the TxPublisher.
cfg *TxPublisherConfig
// currentHeight is the current block height.
currentHeight atomic.Int32
// records is a map keyed by the requestCounter and the value is the tx
// being monitored.
records lnutils.SyncMap[uint64, *monitorRecord]
// requestCounter is a monotonically increasing counter used to keep
// track of how many requests have been made.
requestCounter atomic.Uint64
// subscriberChans is a map keyed by the requestCounter, each item is
// the chan that the publisher sends the fee bump result to.
subscriberChans lnutils.SyncMap[uint64, chan *BumpResult]
// quit is used to signal the publisher to stop.
quit chan struct{}
}
// Compile-time constraint to ensure TxPublisher implements Bumper.
var _ Bumper = (*TxPublisher)(nil)
// Compile-time check for the chainio.Consumer interface.
var _ chainio.Consumer = (*TxPublisher)(nil)
// NewTxPublisher creates a new TxPublisher.
func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher {
tp := &TxPublisher{
cfg: &cfg,
records: lnutils.SyncMap[uint64, *monitorRecord]{},
subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{},
quit: make(chan struct{}),
}
// Mount the block consumer.
tp.BeatConsumer = chainio.NewBeatConsumer(tp.quit, tp.Name())
return tp
}
// isNeutrinoBackend checks if the wallet backend is neutrino.
func (t *TxPublisher) isNeutrinoBackend() bool {
return t.cfg.Wallet.BackEnd() == "neutrino"
}
// Broadcast is used to publish the tx created from the given inputs. It
// will register the broadcast request and return a chan that the caller can
// use to subscribe to the broadcast result. The initial broadcast is
// guaranteed to be RBF-compliant unless the specified budget cannot cover
// the fee.
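//
// A minimal caller-side sketch (all field values are hypothetical):
//
//	req := &BumpRequest{
//		Budget:          100_000,
//		Inputs:          inputs,
//		DeadlineHeight:  currentHeight + 144,
//		DeliveryAddress: changeAddr,
//		MaxFeeRate:      maxFeeRate,
//		Immediate:       true,
//	}
//	for result := range publisher.Broadcast(req) {
//		switch result.Event {
//		case TxConfirmed:
//			return // done
//		case TxFailed, TxFatal, TxUnknownSpend:
//			return // handle the error, possibly retry
//		default:
//			// TxPublished/TxReplaced: keep waiting.
//		}
//	}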
//
// NOTE: part of the Bumper interface.
func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult {
log.Tracef("Received broadcast request: %s",
lnutils.SpewLogClosure(req))
// Store the request.
record := t.storeInitialRecord(req)
// Create a chan to send the result to the caller.
subscriber := make(chan *BumpResult, 1)
t.subscriberChans.Store(record.requestID, subscriber)
// Publish the tx immediately if specified.
if req.Immediate {
t.handleInitialBroadcast(record)
}
return subscriber
}
// storeInitialRecord initializes a monitor record and saves it in the map.
func (t *TxPublisher) storeInitialRecord(req *BumpRequest) *monitorRecord {
// Increase the request counter.
//
// NOTE: this is the only place where we increase the counter.
requestID := t.requestCounter.Add(1)
// Register the record.
record := &monitorRecord{
requestID: requestID,
req: req,
}
t.records.Store(requestID, record)
return record
}
// updateRecord updates the given record's tx and fee, and saves it in the
// records map.
func (t *TxPublisher) updateRecord(r *monitorRecord,
sweepCtx *sweepTxCtx) *monitorRecord {
r.tx = sweepCtx.tx
r.fee = sweepCtx.fee
r.outpointToTxIndex = sweepCtx.outpointToTxIndex
// Register the record.
t.records.Store(r.requestID, r)
return r
}
// Name returns the name of the TxPublisher.
//
// NOTE: part of the `chainio.Consumer` interface.
func (t *TxPublisher) Name() string {
return "TxPublisher"
}
// initializeTx initializes a fee function and creates an RBF-compliant tx.
// If successful, the initial tx is stored in the records map.
func (t *TxPublisher) initializeTx(r *monitorRecord) (*monitorRecord, error) {
// Create a fee bumping algorithm to be used for future RBF.
feeAlgo, err := t.initializeFeeFunction(r.req)
if err != nil {
return nil, fmt.Errorf("init fee function: %w", err)
}
// Attach the newly created fee function.
//
// TODO(yy): currently we initialize a monitorRecord before creating the
// fee function, while we could instead create the fee function first,
// then save it to the record. To make this happen we need to change the
// conf target calculation below since we would be initializing the fee
// function one block earlier.
r.feeFunction = feeAlgo
// Create the initial tx to be broadcasted. This tx is guaranteed to
// comply with the RBF restrictions.
record, err := t.createRBFCompliantTx(r)
if err != nil {
return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
}
return record, nil
}
// initializeFeeFunction initializes a fee function to be used for this request
// for future fee bumping.
func (t *TxPublisher) initializeFeeFunction(
req *BumpRequest) (FeeFunction, error) {
// Get the max allowed feerate.
maxFeeRateAllowed, err := req.MaxFeeRateAllowed()
if err != nil {
return nil, err
}
// Get the initial conf target.
confTarget := calcCurrentConfTarget(
t.currentHeight.Load(), req.DeadlineHeight,
)
log.Debugf("Initializing fee function with conf target=%v, budget=%v, "+
"maxFeeRateAllowed=%v", confTarget, req.Budget,
maxFeeRateAllowed)
// Initialize the fee function and return it.
//
// TODO(yy): return based on different req.Strategy?
return NewLinearFeeFunction(
maxFeeRateAllowed, confTarget, t.cfg.Estimator,
req.StartingFeeRate,
)
}
// createRBFCompliantTx creates a tx that is compliant with RBF rules. It
// does so by creating a tx, validating it using `TestMempoolAccept`, and
// bumping its fee, redoing the process until the tx is valid, or returning
// an error when non-RBF related errors occur or the budget has been used up.
func (t *TxPublisher) createRBFCompliantTx(
r *monitorRecord) (*monitorRecord, error) {
f := r.feeFunction
for {
// Create a new tx with the given fee rate and check its
// mempool acceptance.
sweepCtx, err := t.createAndCheckTx(r)
switch {
case err == nil:
// The tx is valid, store it.
record := t.updateRecord(r, sweepCtx)
log.Infof("Created initial sweep tx=%v for %v inputs: "+
"feerate=%v, fee=%v, inputs:\n%v",
sweepCtx.tx.TxHash(), len(r.req.Inputs),
f.FeeRate(), sweepCtx.fee,
inputTypeSummary(r.req.Inputs))
return record, nil
// If the error indicates the fees paid are not enough, we will
// ask the fee function to increase the fee rate and retry.
case errors.Is(err, lnwallet.ErrMempoolFee):
// We should at least start with a feerate above the
// mempool min feerate, so if we get this error, it
// means something is wrong earlier in the pipeline.
log.Errorf("Current fee=%v, feerate=%v, %v",
sweepCtx.fee, f.FeeRate(), err)
fallthrough
// We are not paying enough fees so we increase it.
case errors.Is(err, chain.ErrInsufficientFee):
increased := false
// Keep calling the fee function until the fee rate is
// increased or maxed out.
for !increased {
log.Debugf("Increasing fee for next round, "+
"current fee=%v, feerate=%v",
sweepCtx.fee, f.FeeRate())
// If the fee function tells us that we have
// used up the budget, we will return an error
// indicating this tx cannot be made. The
// sweeper should handle this error and try to
// cluster these inputs differently.
increased, err = f.Increment()
if err != nil {
return nil, err
}
}
// TODO(yy): suppose there's only one bad input, we can do a
// binary search to find out which input is causing this error
// by recreating a tx using half of the inputs and check its
// mempool acceptance.
default:
log.Debugf("Failed to create RBF-compliant tx: %v", err)
return nil, err
}
}
}
// createAndCheckTx creates a tx based on the given inputs, change output
// script, and the fee rate. In addition, it validates the tx's mempool
// acceptance before returning a tx that can be published directly, along with
// its fee.
func (t *TxPublisher) createAndCheckTx(r *monitorRecord) (*sweepTxCtx, error) {
req := r.req
f := r.feeFunction
// Create the sweep tx with max fee rate of 0 as the fee function
// guarantees the fee rate used here won't exceed the max fee rate.
sweepCtx, err := t.createSweepTx(
req.Inputs, req.DeliveryAddress, f.FeeRate(),
)
if err != nil {
return sweepCtx, fmt.Errorf("create sweep tx: %w", err)
}
// Sanity check the budget still covers the fee.
if sweepCtx.fee > req.Budget {
return sweepCtx, fmt.Errorf("%w: budget=%v, fee=%v",
ErrNotEnoughBudget, req.Budget, sweepCtx.fee)
}
// If we had an extra txOut, then we'll update the result to include
// it.
req.ExtraTxOut = sweepCtx.extraTxOut
// Validate the tx's mempool acceptance.
err = t.cfg.Wallet.CheckMempoolAcceptance(sweepCtx.tx)
// Exit early if the tx is valid.
if err == nil {
return sweepCtx, nil
}
// Print an error log if the chain backend doesn't support the mempool
// acceptance test RPC.
if errors.Is(err, rpcclient.ErrBackendVersion) {
log.Errorf("TestMempoolAccept not supported by backend, " +
"consider upgrading it to a newer version")
return sweepCtx, nil
}
// We are running on a backend that doesn't implement the RPC
// testmempoolaccept, e.g., neutrino, so we'll skip the check.
if errors.Is(err, chain.ErrUnimplemented) {
log.Debug("Skipped testmempoolaccept due to not implemented")
return sweepCtx, nil
}
// If the inputs are spent by another tx, we will exit with the latest
// sweepCtx and an error.
if errors.Is(err, chain.ErrMissingInputs) {
log.Debugf("Tx %v missing inputs, it's likely the input has "+
"been spent by others", sweepCtx.tx.TxHash())
// Make sure to update the record with the latest attempt.
t.updateRecord(r, sweepCtx)
return sweepCtx, ErrInputMissing
}
return sweepCtx, fmt.Errorf("tx=%v failed mempool check: %w",
sweepCtx.tx.TxHash(), err)
}
// handleMissingInputs handles the case when the chain backend reports back
// a missing inputs error, which could happen when one of the inputs has been
// spent in another tx, or the input is referencing an orphan. When the input
// is spent, it will be handled via the TxUnknownSpend flow by creating a
// TxUnknownSpend bump result; otherwise, a TxFatal bump result is returned.
func (t *TxPublisher) handleMissingInputs(r *monitorRecord) *BumpResult {
// Get the spending txns.
spends := t.getSpentInputs(r)
// Attach the spending txns.
r.spentInputs = spends
// If there are no spending txns found and the input is missing, the
// input is referencing an orphan tx that's no longer valid, e.g.,
// spending the anchor output from the remote commitment after the local
// commitment has confirmed. In this case we will mark it as fatal and
// exit.
if len(spends) == 0 {
log.Warnf("Failing record=%v: found orphan inputs: %v\n",
r.requestID, inputTypeSummary(r.req.Inputs))
// Create a result that will be sent to the resultChan which is
// listened by the caller.
result := &BumpResult{
Event: TxFatal,
Tx: r.tx,
requestID: r.requestID,
Err: ErrInputMissing,
}
return result
}
// Check that the spending tx matches the sweeping tx - given that the
// current sweeping tx has been failed due to missing inputs, the
// spending tx must be a different tx, thus it should NOT be matched. We
// perform a sanity check here to catch the unexpected state.
if !t.isUnknownSpent(r, spends) {
log.Errorf("Sweeping tx %v has missing inputs, yet the "+
"spending tx is the sweeping tx itself: %v",
r.tx.TxHash(), r.spentInputs)
}
return t.createUnknownSpentBumpResult(r)
}
// broadcast takes a monitored tx and publishes it to the network. Prior to the
// broadcast, it will subscribe the tx's confirmation notification and attach
// the event channel to the record. Any broadcast-related errors will not be
// returned here, instead, they will be put inside the `BumpResult` and
// returned to the caller.
func (t *TxPublisher) broadcast(record *monitorRecord) (*BumpResult, error) {
txid := record.tx.TxHash()
tx := record.tx
log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v",
txid, len(tx.TxIn), t.currentHeight.Load())
// Before we go to broadcast, we'll notify the aux sweeper, if it's
// present of this new broadcast attempt.
err := fn.MapOptionZ(t.cfg.AuxSweeper, func(aux AuxSweeper) error {
return aux.NotifyBroadcast(
record.req, tx, record.fee, record.outpointToTxIndex,
)
})
if err != nil {
return nil, fmt.Errorf("unable to notify aux sweeper: %w", err)
}
// Set the event, and change it to TxFailed if the wallet fails to
// publish it.
event := TxPublished
// Publish the sweeping tx with customized label. If the publish fails,
// this error will be saved in the `BumpResult` and it will be removed
// from being monitored.
err = t.cfg.Wallet.PublishTransaction(
tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil),
)
if err != nil {
// NOTE: we decide to attach this error to the result instead
// of returning it here because by the time the tx reaches
// here, it should have passed the mempool acceptance check. If
// it still fails to be broadcast, it's likely a non-RBF
// related error happened. So we send this error back to the
// caller so that it can handle it properly.
//
// TODO(yy): find out which input is causing the failure.
log.Errorf("Failed to publish tx %v: %v", txid, err)
event = TxFailed
}
result := &BumpResult{
Event: event,
Tx: record.tx,
Fee: record.fee,
FeeRate: record.feeFunction.FeeRate(),
Err: err,
requestID: record.requestID,
}
return result, nil
}
// notifyResult sends the result to the resultChan specified by the requestID.
// This channel is expected to be read by the caller.
func (t *TxPublisher) notifyResult(result *BumpResult) {
id := result.requestID
subscriber, ok := t.subscriberChans.Load(id)
if !ok {
log.Errorf("Result chan for id=%v not found", id)
return
}
log.Debugf("Sending result %v for requestID=%v", result, id)
select {
// Send the result to the subscriber.
//
// TODO(yy): Add timeout in case it's blocking?
case subscriber <- result:
case <-t.quit:
log.Debug("Fee bumper stopped")
}
}
// removeResult removes the tracking of the result. If the result contains a
// non-nil error or the tx is confirmed, the record will be removed from the
// maps.
func (t *TxPublisher) removeResult(result *BumpResult) {
id := result.requestID
var txid chainhash.Hash
if result.Tx != nil {
txid = result.Tx.TxHash()
}
// Remove the record from the maps if there's an error or the tx is
// confirmed. When there's an error, it means this tx has failed its
// broadcast and cannot be retried. There are two cases in which it may fail:
// - when the budget cannot cover the increased fee calculated by the
// fee function, hence the budget is used up;
// - when a non-fee related error is returned from PublishTransaction.
switch result.Event {
case TxFailed:
log.Errorf("Removing monitor record=%v, tx=%v, due to err: %v",
id, txid, result.Err)
case TxConfirmed:
// Remove the record if the tx is confirmed.
log.Debugf("Removing confirmed monitor record=%v, tx=%v", id,
txid)
case TxFatal:
// Remove the record if there's a fatal error.
log.Debugf("Removing monitor record=%v due to fatal err: %v",
id, result.Err)
case TxUnknownSpend:
// Remove the record if there's an unknown spend.
log.Debugf("Removing monitor record=%v due unknown spent: "+
"%v", id, result.Err)
// Do nothing if it's neither failed nor confirmed.
default:
log.Tracef("Skipping record removal for id=%v, event=%v", id,
result.Event)
return
}
t.records.Delete(id)
t.subscriberChans.Delete(id)
}
// handleResult handles the result of a tx broadcast. It will notify the
// subscriber and remove the record if the tx is confirmed or failed to be
// broadcast.
func (t *TxPublisher) handleResult(result *BumpResult) {
// Notify the subscriber.
t.notifyResult(result)
// Remove the record if it's failed or confirmed.
t.removeResult(result)
}
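// A minimal, illustrative sketch of the caller side of this result flow;
// resultChan and quit are assumptions, not names from this file:
//
//	for {
//		select {
//		case result := <-resultChan:
//			switch result.Event {
//			case TxPublished, TxReplaced:
//				// Keep waiting for a confirmation.
//			case TxConfirmed:
//				return // Done sweeping.
//			case TxFailed, TxUnknownSpend, TxFatal:
//				// Hand the inputs back to be re-clustered
//				// or failed for good.
//			}
//		case <-quit:
//			return
//		}
//	}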
// monitorRecord is used to keep track of the tx being monitored by the
// publisher internally.
type monitorRecord struct {
// requestID is the ID of the request that created this record.
requestID uint64
// tx is the tx being monitored.
tx *wire.MsgTx
// req is the original request.
req *BumpRequest
// feeFunction is the fee bumping algorithm used by the publisher.
feeFunction FeeFunction
// fee is the fee paid by the tx.
fee btcutil.Amount
// outpointToTxIndex is a map of outpoint to tx index.
outpointToTxIndex map[wire.OutPoint]int
// spentInputs are the inputs spent by another tx which caused the
// current tx to fail.
spentInputs map[wire.OutPoint]*wire.MsgTx
}
// Start starts the publisher by subscribing to block epoch updates and kicking
// off the monitor loop.
func (t *TxPublisher) Start(beat chainio.Blockbeat) error {
log.Info("TxPublisher starting...")
if t.started.Swap(true) {
return fmt.Errorf("TxPublisher started more than once")
}
// Set the current height.
t.currentHeight.Store(beat.Height())
t.wg.Add(1)
go t.monitor()
log.Debugf("TxPublisher started")
return nil
}
// Stop stops the publisher and waits for the monitor loop to exit.
func (t *TxPublisher) Stop() error {
log.Info("TxPublisher stopping...")
if t.stopped.Swap(true) {
return fmt.Errorf("TxPublisher stopped more than once")
}
close(t.quit)
t.wg.Wait()
log.Debug("TxPublisher stopped")
return nil
}
// monitor is the main loop driven by new blocks. Whenever a new block arrives,
// it will examine all the txns being monitored, and check if any of them needs
// to be bumped. If so, it will attempt to bump the fee of the tx.
//
// NOTE: Must be run as a goroutine.
func (t *TxPublisher) monitor() {
defer t.wg.Done()
for {
select {
case beat := <-t.BlockbeatChan:
height := beat.Height()
log.Debugf("TxPublisher received new block: %v", height)
// Update the best known height for the publisher.
t.currentHeight.Store(height)
// Check all monitored txns to see if any of them needs
// to be bumped.
t.processRecords()
// Notify we've processed the block.
t.NotifyBlockProcessed(beat, nil)
case <-t.quit:
log.Debug("Fee bumper stopped, exit monitor")
return
}
}
}
// processRecords checks all the txns being monitored, and checks if any of
// them needs to be bumped. If so, it will attempt to bump the fee of the tx.
func (t *TxPublisher) processRecords() {
// confirmedRecords stores a map of the records which have been
// confirmed.
confirmedRecords := make(map[uint64]*monitorRecord)
// feeBumpRecords stores a map of records which need to be bumped.
feeBumpRecords := make(map[uint64]*monitorRecord)
// failedRecords stores a map of records which has inputs being spent
// by a third party.
failedRecords := make(map[uint64]*monitorRecord)
// initialRecords stores a map of records which are being created and
// published for the first time.
initialRecords := make(map[uint64]*monitorRecord)
// visitor is a helper closure that visits each record and divides them
// into four groups.
visitor := func(requestID uint64, r *monitorRecord) error {
log.Tracef("Checking monitor recordID=%v", requestID)
// Check whether the inputs have already been spent.
spends := t.getSpentInputs(r)
// If any of the inputs have been spent, the record will be
// marked as failed or confirmed.
if len(spends) != 0 {
// Attach the spending txns.
r.spentInputs = spends
// When tx is nil, it means we haven't tried the initial
// broadcast yet, but the input is already spent. This
// could happen when the node shuts down, a previous
// sweeping tx confirms, then the node comes back online
// and reoffers the inputs. Another case is the remote
// node spending the input quickly before we even attempt
// the sweep. In either case we will fail the record and
// let the sweeper handle it.
if r.tx == nil {
failedRecords[requestID] = r
return nil
}
// Check whether the inputs have been spent by an unknown
// tx.
if t.isUnknownSpent(r, spends) {
failedRecords[requestID] = r
// Move to the next record.
return nil
}
// The tx is ours, we can move it to the confirmed queue
// and stop monitoring it.
confirmedRecords[requestID] = r
// Move to the next record.
return nil
}
// This is the first time we see this record, so we put it in
// the initial queue.
if r.tx == nil {
initialRecords[requestID] = r
return nil
}
// We can only get here when the inputs are not spent and a
// previous sweeping tx has been attempted. In this case we will
// perform an RBF on it in the current block.
feeBumpRecords[requestID] = r
// Return nil to move to the next record.
return nil
}
// Iterate through all the records and divide them into four groups.
t.records.ForEach(visitor)
// Handle the initial broadcast.
for _, r := range initialRecords {
t.handleInitialBroadcast(r)
}
// For records that are confirmed, we'll notify the caller about this
// result.
for _, r := range confirmedRecords {
t.wg.Add(1)
go t.handleTxConfirmed(r)
}
// Get the current height to be used in the following goroutines.
currentHeight := t.currentHeight.Load()
// For records that are not confirmed, we perform a fee bump if needed.
for _, r := range feeBumpRecords {
t.wg.Add(1)
go t.handleFeeBumpTx(r, currentHeight)
}
// For records that are failed, we'll notify the caller about this
// result.
for _, r := range failedRecords {
t.wg.Add(1)
go t.handleUnknownSpent(r)
}
}
// handleTxConfirmed is called when a monitored tx is confirmed. It will
// notify the subscriber then remove the record from the maps.
//
// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
func (t *TxPublisher) handleTxConfirmed(r *monitorRecord) {
defer t.wg.Done()
log.Debugf("Record %v is spent in tx=%v", r.requestID, r.tx.TxHash())
// Create a result that will be sent to the resultChan which is
// listened by the caller.
result := &BumpResult{
Event: TxConfirmed,
Tx: r.tx,
requestID: r.requestID,
Fee: r.fee,
FeeRate: r.feeFunction.FeeRate(),
}
// Notify that this tx is confirmed and remove the record from the map.
t.handleResult(result)
}
// handleInitialTxError takes the error from `initializeTx` and decides the
// bump event. It will construct a BumpResult and handle it.
func (t *TxPublisher) handleInitialTxError(r *monitorRecord, err error) {
// Create a bump result to be sent to the sweeper.
result := &BumpResult{
Err: err,
requestID: r.requestID,
}
// We now decide what type of event to send.
switch {
// When the error is due to a dust output, we'll send a TxFailed so
// these inputs can be retried with a different group in the next
// block.
case errors.Is(err, ErrTxNoOutput):
result.Event = TxFailed
// When the error is due to zero fee rate delta, we'll send a TxFailed
// so these inputs can be retried in the next block.
case errors.Is(err, ErrZeroFeeRateDelta):
result.Event = TxFailed
// When the error is due to budget being used up, we'll send a TxFailed
// so these inputs can be retried with a different group in the next
// block.
case errors.Is(err, ErrMaxPosition):
fallthrough
// If the tx doesn't have enough budget, or if the input amounts
// are not sufficient to cover the budget, we will return a TxFailed
// event so the sweeper can handle it by re-clustering the utxos.
case errors.Is(err, ErrNotEnoughInputs),
errors.Is(err, ErrNotEnoughBudget):
result.Event = TxFailed
// Calculate the starting fee rate to be used when
// retrying the sweep of these inputs.
feeRate, err := t.calculateRetryFeeRate(r)
if err != nil {
result.Event = TxFatal
result.Err = err
}
// Attach the new fee rate.
result.FeeRate = feeRate
// When there are missing inputs, we'll create a TxUnknownSpend bump
// result here so the rest of the inputs can be retried.
case errors.Is(err, ErrInputMissing):
result = t.handleMissingInputs(r)
// Otherwise this is not a fee-related error and the tx cannot be
// retried. In that case we will fail ALL the inputs in this tx, which
// means they will be removed from the sweeper and never be tried
// again.
//
// TODO(yy): Find out which input is causing the failure and fail that
// one only.
default:
result.Event = TxFatal
}
t.handleResult(result)
}
// handleInitialBroadcast is called when a new request is received. It will
// handle the initial tx creation and broadcast. In detail:
// 1. init a fee function based on the given strategy.
// 2. create an RBF-compliant tx and monitor it for confirmation.
// 3. notify the initial broadcast result back to the caller.
func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
log.Debugf("Initial broadcast for requestID=%v", r.requestID)
var (
result *BumpResult
err error
)
// Attempt an initial broadcast which is guaranteed to comply with the
// RBF rules.
//
// Create the initial tx to be broadcasted.
record, err := t.initializeTx(r)
if err != nil {
log.Errorf("Initial broadcast failed: %v", err)
// We now handle the initialization error and exit.
t.handleInitialTxError(r, err)
return
}
// Successfully created the first tx, now broadcast it.
result, err = t.broadcast(record)
if err != nil {
// The broadcast failed, which can only happen if the tx record
// cannot be found or the aux sweeper returns an error. In
// either case, we will send back a TxFailed event so these
// inputs can be retried.
result = &BumpResult{
Event: TxFailed,
Err: err,
requestID: r.requestID,
}
}
t.handleResult(result)
}
// handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will
// attempt to bump the fee of the tx.
//
// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
func (t *TxPublisher) handleFeeBumpTx(r *monitorRecord, currentHeight int32) {
defer t.wg.Done()
log.Debugf("Attempting to fee bump tx=%v in record %v", r.tx.TxHash(),
r.requestID)
oldTxid := r.tx.TxHash()
// Get the current conf target for this record.
confTarget := calcCurrentConfTarget(currentHeight, r.req.DeadlineHeight)
// Ask the fee function whether a bump is needed. We expect the fee
// function to increase its returned fee rate after calling this
// method.
increased, err := r.feeFunction.IncreaseFeeRate(confTarget)
if err != nil {
// TODO(yy): send this error back to the sweeper so it can
// re-group the inputs?
log.Errorf("Failed to increase fee rate for tx %v at "+
"height=%v: %v", oldTxid, t.currentHeight.Load(), err)
return
}
// If the fee rate was not increased, there's no need to bump the fee.
if !increased {
log.Tracef("Skip bumping tx %v at height=%v", oldTxid,
t.currentHeight.Load())
return
}
// The fee function now has a new fee rate, we will use it to bump the
// fee of the tx.
resultOpt := t.createAndPublishTx(r)
// If there's a result, we will notify the caller about the result.
resultOpt.WhenSome(func(result BumpResult) {
// Notify the new result.
t.handleResult(&result)
})
}
// handleUnknownSpent is called when the inputs are spent by an unknown tx. It
// will notify the subscriber then remove the record from the maps and send a
// TxUnknownSpend event to the subscriber.
//
// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
func (t *TxPublisher) handleUnknownSpent(r *monitorRecord) {
defer t.wg.Done()
log.Debugf("Record %v has inputs spent by a tx unknown to the fee "+
"bumper, failing it now:\n%v", r.requestID,
inputTypeSummary(r.req.Inputs))
// Create a result that will be sent to the resultChan which is listened
// by the caller.
result := t.createUnknownSpentBumpResult(r)
// Notify the sweeper about this result in the end.
t.handleResult(result)
}
// createUnknownSpentBumpResult creates and returns a BumpResult given that the
// monitored record has unknown spends.
func (t *TxPublisher) createUnknownSpentBumpResult(
r *monitorRecord) *BumpResult {
// Create a result that will be sent to the resultChan which is listened
// by the caller.
result := &BumpResult{
Event: TxUnknownSpend,
Tx: r.tx,
requestID: r.requestID,
Err: ErrUnknownSpent,
SpentInputs: r.spentInputs,
}
// Calculate the next fee rate for the retry.
feeRate, err := t.calculateRetryFeeRate(r)
if err != nil {
// Overwrite the event and error so the sweeper will
// remove this input.
result.Event = TxFatal
result.Err = err
}
// Attach the new fee rate to be used for the next sweeping attempt.
result.FeeRate = feeRate
return result
}
// createAndPublishTx creates a new tx with a higher fee rate and publishes it
// to the network. It will update the record with the new tx and fee rate if
// successfully created, and return the result when published successfully.
func (t *TxPublisher) createAndPublishTx(
r *monitorRecord) fn.Option[BumpResult] {
// Fetch the old tx.
oldTx := r.tx
// Create a new tx with the new fee rate.
//
// NOTE: The fee function is expected to have increased its returned
// fee rate after the IncreaseFeeRate call in handleFeeBumpTx, so we
// can use it directly here.
sweepCtx, err := t.createAndCheckTx(r)
// If there's an error creating the replacement tx, we need to abort the
// flow and handle it.
if err != nil {
return t.handleReplacementTxError(r, oldTx, err)
}
// The tx has been created without any errors, we now register a new
// record by overwriting the same requestID.
record := t.updateRecord(r, sweepCtx)
// Attempt to broadcast this new tx.
result, err := t.broadcast(record)
if err != nil {
log.Infof("Failed to broadcast replacement tx %v: %v",
sweepCtx.tx.TxHash(), err)
return fn.None[BumpResult]()
}
// If the result error is fee related, we will return no error and let
// the fee bumper retry it at next block.
//
// NOTE: we may get this error if we've bypassed the mempool check,
// which means we are using the neutrino backend.
if errors.Is(result.Err, chain.ErrInsufficientFee) ||
errors.Is(result.Err, lnwallet.ErrMempoolFee) {
log.Debugf("Failed to bump tx %v: %v", oldTx.TxHash(),
result.Err)
return fn.None[BumpResult]()
}
// A successful replacement tx is created, attach the old tx.
result.ReplacedTx = oldTx
// If the new tx failed to be published, we will return the result so
// the caller can handle it.
if result.Event == TxFailed {
return fn.Some(*result)
}
log.Debugf("Replaced tx=%v with new tx=%v", oldTx.TxHash(),
sweepCtx.tx.TxHash())
// Otherwise, it's a successful RBF, set the event and return.
result.Event = TxReplaced
return fn.Some(*result)
}
// isUnknownSpent checks whether the inputs of the tx have already been spent
// by a tx not known to us. When a tx is not confirmed yet its inputs have
// been spent, they must have been spent by a tx other than the sweeping tx
// here.
func (t *TxPublisher) isUnknownSpent(r *monitorRecord,
spends map[wire.OutPoint]*wire.MsgTx) bool {
txid := r.tx.TxHash()
// Iterate all the spending txns and check if they match the sweeping
// tx.
for op, spendingTx := range spends {
spendingTxID := spendingTx.TxHash()
// If the spending tx is the same as the sweeping tx then we are
// good.
if spendingTxID == txid {
continue
}
log.Warnf("Detected unknown spend of input=%v in tx=%v", op,
spendingTx.TxHash())
return true
}
return false
}
// getSpentInputs performs a non-blocking read on the spending subscriptions to
// see whether any of the monitored inputs have been spent. A map of inputs to
// their spending txns is returned if found.
func (t *TxPublisher) getSpentInputs(
r *monitorRecord) map[wire.OutPoint]*wire.MsgTx {
// Create a map to record the spent inputs.
spentInputs := make(map[wire.OutPoint]*wire.MsgTx, len(r.req.Inputs))
// Iterate all the inputs and check if they have been spent already.
for _, inp := range r.req.Inputs {
op := inp.OutPoint()
// For wallet utxos, the height hint is not set - we don't need
// to monitor them for third party spend.
//
// TODO(yy): We need to properly lock wallet utxos before
// skipping this check as the same wallet utxo can be used by
// different sweeping txns.
heightHint := inp.HeightHint()
if heightHint == 0 {
heightHint = uint32(t.currentHeight.Load())
log.Debugf("Checking wallet input %v using heightHint "+
"%v", op, heightHint)
}
// If the input has already been spent after the height hint, a
// spend event is sent back immediately.
spendEvent, err := t.cfg.Notifier.RegisterSpendNtfn(
&op, inp.SignDesc().Output.PkScript, heightHint,
)
if err != nil {
log.Criticalf("Failed to register spend ntfn for "+
"input=%v: %v", op, err)
return nil
}
// Remove the subscription on exit.
defer spendEvent.Cancel()
// Do a non-blocking read to see if the output has been spent.
select {
case spend, ok := <-spendEvent.Spend:
if !ok {
log.Debugf("Spend ntfn for %v canceled", op)
continue
}
spendingTx := spend.SpendingTx
log.Debugf("Detected spent of input=%v in tx=%v", op,
spendingTx.TxHash())
spentInputs[op] = spendingTx
// Move to the next input.
default:
log.Tracef("Input %v not spent yet", op)
}
}
return spentInputs
}
// calcCurrentConfTarget calculates the current confirmation target based on
// the deadline height. The conf target is capped at 0 if the deadline has
// already passed.
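//
// For example, with currentHeight=800_000 and deadline=800_003, the
// conf target is 3; with deadline=799_999, the deadline has passed and
// the conf target is capped at 0.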
func calcCurrentConfTarget(currentHeight, deadline int32) uint32 {
var confTarget uint32
// Calculate how many blocks left until the deadline.
deadlineDelta := deadline - currentHeight
// If we are already past the deadline, we will set the conf target to
// be 0.
if deadlineDelta < 0 {
log.Warnf("Deadline is %d blocks behind current height %v",
-deadlineDelta, currentHeight)
confTarget = 0
} else {
confTarget = uint32(deadlineDelta)
}
return confTarget
}
// sweepTxCtx houses a sweep transaction with additional context.
type sweepTxCtx struct {
tx *wire.MsgTx
fee btcutil.Amount
extraTxOut fn.Option[SweepOutput]
// outpointToTxIndex maps the outpoint of the inputs to their index in
// the sweep transaction.
outpointToTxIndex map[wire.OutPoint]int
}
// createSweepTx creates a sweeping tx based on the given inputs, change
// address and fee rate.
func (t *TxPublisher) createSweepTx(inputs []input.Input,
changePkScript lnwallet.AddrWithKey,
feeRate chainfee.SatPerKWeight) (*sweepTxCtx, error) {
// Validate and calculate the fee and change amount.
txFee, changeOutputsOpt, locktimeOpt, err := prepareSweepTx(
inputs, changePkScript, feeRate, t.currentHeight.Load(),
t.cfg.AuxSweeper,
)
if err != nil {
return nil, err
}
var (
// Create the sweep transaction that we will be building. We
// use version 2 as it is required for CSV.
sweepTx = wire.NewMsgTx(2)
// We'll add the inputs as we go so we know the final ordering
// of inputs to sign.
idxs []input.Input
)
// We start by adding all inputs that commit to an output. We do this
// since the input and output index must stay the same for the
// signatures to be valid.
outpointToTxIndex := make(map[wire.OutPoint]int)
for _, o := range inputs {
if o.RequiredTxOut() == nil {
continue
}
idxs = append(idxs, o)
sweepTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: o.OutPoint(),
Sequence: o.BlocksToMaturity(),
})
sweepTx.AddTxOut(o.RequiredTxOut())
outpointToTxIndex[o.OutPoint()] = len(sweepTx.TxOut) - 1
}
// Sum up the value contained in the remaining inputs, and add them to
// the sweep transaction.
for _, o := range inputs {
if o.RequiredTxOut() != nil {
continue
}
idxs = append(idxs, o)
sweepTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: o.OutPoint(),
Sequence: o.BlocksToMaturity(),
})
}
// If we have change outputs to add, then add them to the sweep
// transaction here.
changeOutputsOpt.WhenSome(func(changeOuts []SweepOutput) {
for i := range changeOuts {
sweepTx.AddTxOut(&changeOuts[i].TxOut)
}
})
// We'll default to using the current block height as locktime, if none
// of the inputs commits to a different locktime.
sweepTx.LockTime = uint32(locktimeOpt.UnwrapOr(t.currentHeight.Load()))
prevInputFetcher, err := input.MultiPrevOutFetcher(inputs)
if err != nil {
return nil, fmt.Errorf("error creating prev input fetcher "+
"for hash cache: %v", err)
}
hashCache := txscript.NewTxSigHashes(sweepTx, prevInputFetcher)
// With all the inputs in place, use each output's unique input script
// function to generate the final witness required for spending.
addInputScript := func(idx int, tso input.Input) error {
inputScript, err := tso.CraftInputScript(
t.cfg.Signer, sweepTx, hashCache, prevInputFetcher, idx,
)
if err != nil {
return err
}
sweepTx.TxIn[idx].Witness = inputScript.Witness
if len(inputScript.SigScript) == 0 {
return nil
}
sweepTx.TxIn[idx].SignatureScript = inputScript.SigScript
return nil
}
for idx, inp := range idxs {
if err := addInputScript(idx, inp); err != nil {
return nil, err
}
}
log.Debugf("Created sweep tx %v for inputs:\n%v", sweepTx.TxHash(),
inputTypeSummary(inputs))
// Try to locate the extra change output, though there might be None.
extraTxOut := fn.MapOption(
func(sweepOuts []SweepOutput) fn.Option[SweepOutput] {
for _, sweepOut := range sweepOuts {
if !sweepOut.IsExtra {
continue
}
// If we sweep outputs of a custom channel, the
// custom leaves in those outputs will be merged
// into a single output, even if we sweep
// multiple outputs (e.g. to_remote and breached
// to_local of a breached channel) at the same
// time. So there will only ever be one extra
// output.
log.Debugf("Sweep produced extra_sweep_out=%v",
lnutils.SpewLogClosure(sweepOut))
return fn.Some(sweepOut)
}
return fn.None[SweepOutput]()
},
)(changeOutputsOpt)
return &sweepTxCtx{
tx: sweepTx,
fee: txFee,
extraTxOut: fn.FlattenOption(extraTxOut),
outpointToTxIndex: outpointToTxIndex,
}, nil
}
// prepareSweepTx returns the tx fee, a set of optional change outputs and an
// optional locktime after a series of validations:
// 1. check the locktime has been reached.
// 2. check the locktimes are the same.
// 3. check the inputs cover the outputs.
//
// NOTE: if the change amount is below dust, it will be added to the tx fee.
func prepareSweepTx(inputs []input.Input, changePkScript lnwallet.AddrWithKey,
feeRate chainfee.SatPerKWeight, currentHeight int32,
auxSweeper fn.Option[AuxSweeper]) (
btcutil.Amount, fn.Option[[]SweepOutput], fn.Option[int32], error) {
noChange := fn.None[[]SweepOutput]()
noLocktime := fn.None[int32]()
// Given the set of inputs we have, if we have an aux sweeper, then
// we'll attempt to see if we have any other change outputs we'll need
// to add to the sweep transaction.
changePkScripts := [][]byte{changePkScript.DeliveryAddress}
var extraChangeOut fn.Option[SweepOutput]
err := fn.MapOptionZ(
auxSweeper, func(aux AuxSweeper) error {
extraOut := aux.DeriveSweepAddr(inputs, changePkScript)
if err := extraOut.Err(); err != nil {
return err
}
extraChangeOut = extraOut.LeftToSome()
return nil
},
)
if err != nil {
return 0, noChange, noLocktime, err
}
// Creating a weight estimator with nil outputs and zero max fee rate.
// We don't allow adding customized outputs in the sweeping tx, and the
// fee rate is already being managed before we get here.
inputs, estimator, err := getWeightEstimate(
inputs, nil, feeRate, 0, changePkScripts,
)
if err != nil {
return 0, noChange, noLocktime, err
}
txFee := estimator.fee()
var (
// Track whether any of the inputs require a certain locktime.
locktime = int32(-1)
// We keep track of total input amount, and required output
// amount to use for calculating the change amount below.
totalInput btcutil.Amount
requiredOutput btcutil.Amount
)
// If we have an extra change output, then we'll add it as a required
// output amt.
extraChangeOut.WhenSome(func(o SweepOutput) {
requiredOutput += btcutil.Amount(o.Value)
})
// Go through each input and check if the required lock times have been
// reached and are the same.
for _, o := range inputs {
// If the input has a required output, we'll add it to the
// required output amount.
if o.RequiredTxOut() != nil {
requiredOutput += btcutil.Amount(
o.RequiredTxOut().Value,
)
}
// Update the total input amount.
totalInput += btcutil.Amount(o.SignDesc().Output.Value)
lt, ok := o.RequiredLockTime()
// Skip if the input doesn't require a lock time.
if !ok {
continue
}
// Check if the lock time has been reached.
if lt > uint32(currentHeight) {
return 0, noChange, noLocktime,
fmt.Errorf("%w: current height is %v, "+
"locktime is %v", ErrLocktimeImmature,
currentHeight, lt)
}
// If another input commits to a different locktime, they
// cannot be combined in the same transaction.
if locktime != -1 && locktime != int32(lt) {
return 0, noChange, noLocktime, ErrLocktimeConflict
}
// Update the locktime for next iteration.
locktime = int32(lt)
}
// Make sure total output amount is less than total input amount.
if requiredOutput+txFee > totalInput {
log.Errorf("Insufficient input to create sweep tx: "+
"input_sum=%v, output_sum=%v", totalInput,
requiredOutput+txFee)
return 0, noChange, noLocktime, ErrNotEnoughInputs
}
// The value remaining after the required output and fees is the
// change output.
changeAmt := totalInput - requiredOutput - txFee
changeOuts := make([]SweepOutput, 0, 2)
extraChangeOut.WhenSome(func(o SweepOutput) {
changeOuts = append(changeOuts, o)
})
// We'll calculate the dust limit for the given changePkScript since it
// is variable.
changeFloor := lnwallet.DustLimitForSize(
len(changePkScript.DeliveryAddress),
)
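// For illustration - an assumption about typical values, not something
// computed at this call site: a 22-byte P2WPKH delivery script yields a
// dust limit of roughly 294 sats, so any change below that is folded
// into the tx fee in the switch statement below.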
switch {
// If the change amount is dust, we'll move it into the fees, and
// ignore it.
case changeAmt < changeFloor:
log.Infof("Change amt %v below dustlimit %v, not adding "+
"change output", changeAmt, changeFloor)
// If there's no required output, and the change output is
// dust, it means we are creating a tx without any outputs. In
// this case we'll return an error. This could happen when
// creating a tx that has an anchor as the only input.
if requiredOutput == 0 {
return 0, noChange, noLocktime, ErrTxNoOutput
}
// The dust amount is added to the fee.
txFee += changeAmt
// Otherwise, we'll actually recognize it as a change output.
default:
changeOuts = append(changeOuts, SweepOutput{
TxOut: wire.TxOut{
Value: int64(changeAmt),
PkScript: changePkScript.DeliveryAddress,
},
IsExtra: false,
InternalKey: changePkScript.InternalKey,
})
}
// Optionally set the locktime.
locktimeOpt := fn.Some(locktime)
if locktime == -1 {
locktimeOpt = noLocktime
}
var changeOutsOpt fn.Option[[]SweepOutput]
if len(changeOuts) > 0 {
changeOutsOpt = fn.Some(changeOuts)
}
log.Debugf("Creating sweep tx for %v inputs (%s) using %v, "+
"tx_weight=%v, tx_fee=%v, locktime=%v, parents_count=%v, "+
"parents_fee=%v, parents_weight=%v, current_height=%v",
len(inputs), inputTypeSummary(inputs), feeRate,
estimator.weight(), txFee, locktimeOpt, len(estimator.parents),
estimator.parentsFee, estimator.parentsWeight, currentHeight)
return txFee, changeOutsOpt, locktimeOpt, nil
}
// handleReplacementTxError handles the error returned from creating the
// replacement tx. It returns a BumpResult that should be notified to the
// sweeper.
func (t *TxPublisher) handleReplacementTxError(r *monitorRecord,
oldTx *wire.MsgTx, err error) fn.Option[BumpResult] {
// If the error is fee related, we will return no error and let the fee
// bumper retry it at next block.
//
// NOTE: we can check the RBF error here and ask the fee function to
// recalculate the fee rate. However, this would defeat the purpose of
// using a deadline based fee function:
// - if the deadline is far away, there's no rush to RBF the tx.
// - if the deadline is close, we expect the fee function to give us a
// higher fee rate. If the fee rate cannot satisfy the RBF rules, it
// means the budget is not enough.
if errors.Is(err, chain.ErrInsufficientFee) ||
errors.Is(err, lnwallet.ErrMempoolFee) {
log.Debugf("Failed to bump tx %v: %v", oldTx.TxHash(), err)
return fn.None[BumpResult]()
}
// At least one of the inputs is missing, which means it has already
// been spent by another tx and confirmed. In this case we will handle
// it by returning a TxUnknownSpend bump result.
if errors.Is(err, ErrInputMissing) {
log.Warnf("Fail to fee bump tx %v: %v", oldTx.TxHash(), err)
bumpResult := t.handleMissingInputs(r)
return fn.Some(*bumpResult)
}
// Return a failed event to retry the sweep.
event := TxFailed
// Calculate the next fee rate for the retry.
feeRate, ferr := t.calculateRetryFeeRate(r)
if ferr != nil {
// If there's an error with the fee calculation, we need to
// abort the sweep.
event = TxFatal
}
// If the error is not fee related, we will return a `TxFailed` event so
// this input can be retried.
result := fn.Some(BumpResult{
Event: event,
Tx: oldTx,
Err: err,
requestID: r.requestID,
FeeRate: feeRate,
})
// If the tx doesn't have enough budget, or if the input amounts
// are not sufficient to cover the budget, we will return a result so
// the sweeper can handle it by re-clustering the utxos.
if errors.Is(err, ErrNotEnoughBudget) ||
errors.Is(err, ErrNotEnoughInputs) {
log.Warnf("Fail to fee bump tx %v: %v", oldTx.TxHash(), err)
return result
}
// Otherwise, an unexpected error occurred, we will log an error and let
// the sweeper retry the whole process.
log.Errorf("Failed to bump tx %v: %v", oldTx.TxHash(), err)
return result
}
// calculateRetryFeeRate calculates a new fee rate to be used as the starting
// fee rate for the next sweep attempt if the inputs are to be retried. When the
// fee function is nil it will be created here, and an error is returned if the
// fee func cannot be initialized.
func (t *TxPublisher) calculateRetryFeeRate(
r *monitorRecord) (chainfee.SatPerKWeight, error) {
// Get the fee function, which will be used to decide the next fee rate
// to use if the sweeper decides to retry sweeping this input.
feeFunc := r.feeFunction
// When the record is failed before the initial broadcast is attempted,
// it will have a nil fee func. In this case, we'll create the fee func
// here.
//
// NOTE: Since the current record has failed and will be deleted, we
// don't need to update the record on this fee function. We only need
// the fee rate data so the sweeper can pick up where we left off.
if feeFunc == nil {
f, err := t.initializeFeeFunction(r.req)
// TODO(yy): The only error we would receive here is when the
// pkScript is not recognized by the weightEstimator. What we
// should do instead is to check the pkScript immediately after
// receiving a sweep request so we don't need to check it again,
// which will also save us from error checking from several
// callsites.
if err != nil {
log.Errorf("Failed to create fee func for record %v: "+
"%v", r.requestID, err)
return 0, err
}
feeFunc = f
}
// Since we failed to sweep the inputs, either because the sweeping tx
// has been replaced by another party's tx, or because the current
// output values cannot cover the budget, we missed this block window to
// increase the fee rate. To make sure the fee rate stays on its initial
// trajectory, we now ask the fee function to give us the next fee rate
// as if the sweeping tx were RBFed. This new fee rate will be used as
// the starting fee rate if the calling system decides to continue
// sweeping the rest of the inputs.
_, err := feeFunc.Increment()
if err != nil {
// The fee function has reached its max position - nothing we
// can do here other than letting the user increase the budget.
log.Errorf("Failed to calculate the next fee rate for "+
"Record(%v): %v", r.requestID, err)
}
return feeFunc.FeeRate(), nil
}
package sweep
import (
"errors"
"fmt"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrMaxPosition is returned when trying to increase the position of
// the fee function while it's already at its max.
ErrMaxPosition = errors.New("position already at max")
// ErrZeroFeeRateDelta is returned when the fee rate delta is zero.
ErrZeroFeeRateDelta = errors.New("fee rate delta is zero")
)
// mSatPerKWeight represents a fee rate in msat/kw.
//
// TODO(yy): unify all the units to be virtual bytes.
type mSatPerKWeight lnwire.MilliSatoshi
// String returns a human-readable string of the fee rate.
func (m mSatPerKWeight) String() string {
s := lnwire.MilliSatoshi(m)
return fmt.Sprintf("%v/kw", s)
}
// FeeFunction defines an interface that is used to calculate fee rates for
// transactions. It's expected that implementations use three params - the
// starting fee rate, the ending fee rate, and the number of blocks till the
// deadline block height - to build an algorithm to calculate the fee rate
// based on the current block height.
type FeeFunction interface {
// FeeRate returns the current fee rate calculated by the fee function.
FeeRate() chainfee.SatPerKWeight
// Increment increases the fee rate by one step. The definition of one
// step is up to the implementation. After calling this method, it's
// expected to change the state of the fee function such that calling
// `FeeRate` again will return the increased value.
//
// It returns a boolean to indicate whether the fee rate is increased,
// as fee bump should not be attempted if the increased fee rate is not
// greater than the current fee rate, which may happen if the algorithm
// gives the same fee rates at two positions.
//
// An error is returned when the max fee rate is reached.
//
// NOTE: we intentionally don't return the new fee rate here, so both
// the implementation and the caller are aware of the state change.
Increment() (bool, error)
// IncreaseFeeRate increases the fee rate to the new position
// calculated using (width - confTarget). It returns a boolean to
// indicate whether the fee rate is increased, and an error if the
// position is greater than the width.
//
// NOTE: this method is provided to allow the caller to increase the
// fee rate based on a conf target without taking care of the fee
// function's current state (position).
IncreaseFeeRate(confTarget uint32) (bool, error)
}
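// A minimal sketch of driving a fee function from block arrivals; the
// maxFeeRate, confTarget and estimator values here are assumptions
// supplied by the caller:
//
//	f, err := NewLinearFeeFunction(
//		maxFeeRate, confTarget, estimator,
//		fn.None[chainfee.SatPerKWeight](),
//	)
//	if err != nil {
//		return err
//	}
//	// On each new block, advance the function one step; only
//	// rebroadcast when the rate actually increased.
//	increased, err := f.Increment()
//	if err == nil && increased {
//		newRate := f.FeeRate()
//		_ = newRate // Use newRate to RBF the sweeping tx.
//	}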
// LinearFeeFunction implements the FeeFunction interface with a linear
// function:
//
// feeRate = startingFeeRate + position * delta.
// - width: deadlineBlockHeight - startingBlockHeight
// - delta: (endingFeeRate - startingFeeRate) / width
// - position: currentBlockHeight - startingBlockHeight
//
// The fee rate will be capped at endingFeeRate.
//
// TODO(yy): implement more functions specified here:
// - https://github.com/lightningnetwork/lnd/issues/4215
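//
// For example (arbitrary numbers): with startingFeeRate=1000 sat/kw,
// endingFeeRate=5000 sat/kw and width=4, delta is 1000 sat/kw per
// block, so the fee rate at positions 0 through 4 is 1000, 2000, 3000,
// 4000 and 5000 sat/kw.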
type LinearFeeFunction struct {
// startingFeeRate specifies the initial fee rate to begin with.
startingFeeRate chainfee.SatPerKWeight
// endingFeeRate specifies the max allowed fee rate.
endingFeeRate chainfee.SatPerKWeight
// currentFeeRate specifies the current calculated fee rate.
currentFeeRate chainfee.SatPerKWeight
// width is the number of blocks between the starting block height
// and the deadline block height minus one.
//
// NOTE: We subtract one from the conf target here because we want to
// max out the budget before the deadline height is reached.
width uint32
// position is the fee function's current position, given a width of w,
// a valid position should lie in range [0, w].
position uint32
// deltaFeeRate is the fee rate (msat/kw) increase per block.
//
// NOTE: this is used to increase precision.
deltaFeeRate mSatPerKWeight
// estimator is the fee estimator used to estimate the fee rate. We use
// it to get the initial fee rate, and use it as a benchmark to decide
// whether we want to use the estimated fee rate or the calculated fee
// rate based on different strategies.
estimator chainfee.Estimator
}
// Compile-time check to ensure LinearFeeFunction satisfies the FeeFunction.
var _ FeeFunction = (*LinearFeeFunction)(nil)
// NewLinearFeeFunction creates a new linear fee function and initializes it
// with a starting fee rate which is an estimated value returned from the fee
// estimator using the initial conf target.
func NewLinearFeeFunction(maxFeeRate chainfee.SatPerKWeight,
confTarget uint32, estimator chainfee.Estimator,
startingFeeRate fn.Option[chainfee.SatPerKWeight]) (
*LinearFeeFunction, error) {
// If the deadline is one block away or has already been reached,
// there's nothing the fee function can do. In this case, we'll use the
// max fee rate immediately.
if confTarget <= 1 {
return &LinearFeeFunction{
startingFeeRate: maxFeeRate,
endingFeeRate: maxFeeRate,
currentFeeRate: maxFeeRate,
}, nil
}
l := &LinearFeeFunction{
endingFeeRate: maxFeeRate,
width: confTarget - 1,
estimator: estimator,
}
// If the caller specifies the starting fee rate, we'll use it instead
// of estimating it based on the deadline.
start, err := startingFeeRate.UnwrapOrFuncErr(
func() (chainfee.SatPerKWeight, error) {
// Estimate the initial fee rate.
//
// NOTE: estimateFeeRate guarantees the returned fee
// rate is capped by the ending fee rate, so we don't
// need to worry about overpay.
return l.estimateFeeRate(confTarget)
})
if err != nil {
return nil, fmt.Errorf("estimate initial fee rate: %w", err)
}
// Calculate how much fee rate should be increased per block.
end := l.endingFeeRate
// The starting and ending fee rates are in sat/kw, so we need to
// convert them to msat/kw by multiplying by 1000.
delta := btcutil.Amount(end - start).MulF64(1000 / float64(l.width))
l.deltaFeeRate = mSatPerKWeight(delta)
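// For example (arbitrary numbers): start=253 sat/kw, end=1000 sat/kw
// and width=3 gives delta = (1000-253) * 1000 / 3 = 249,000 msat/kw,
// i.e. the fee rate grows by 249 sat/kw per block.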
// We only allow the delta to be zero if the width is one - when the
// delta is zero, it means the starting and ending fee rates are the
// same, which means there's nothing to increase, so any width greater
// than 1 doesn't provide any utility. This could happen when the
// budget is too small.
if l.deltaFeeRate == 0 && l.width != 1 {
log.Errorf("Failed to init fee function: startingFeeRate=%v, "+
"endingFeeRate=%v, width=%v, delta=%v", start, end,
l.width, l.deltaFeeRate)
return nil, ErrZeroFeeRateDelta
}
// Attach the calculated values to the fee function.
l.startingFeeRate = start
l.currentFeeRate = start
log.Debugf("Linear fee function initialized with startingFeeRate=%v, "+
"endingFeeRate=%v, width=%v, delta=%v", start, end,
l.width, l.deltaFeeRate)
return l, nil
}
// FeeRate returns the current fee rate.
//
// NOTE: part of the FeeFunction interface.
func (l *LinearFeeFunction) FeeRate() chainfee.SatPerKWeight {
return l.currentFeeRate
}
// Increment increases the fee rate by one position, returns a boolean to
// indicate whether the fee rate was increased, and an error if the position is
// greater than the width. The increased fee rate will be set as the current
// fee rate, and the internal position will be incremented.
//
// NOTE: this method will change the state of the fee function as it increases
// its current fee rate.
//
// NOTE: part of the FeeFunction interface.
func (l *LinearFeeFunction) Increment() (bool, error) {
return l.increaseFeeRate(l.position + 1)
}
// IncreaseFeeRate calculates a new position using the given conf target, and
// increases the fee rate to the new position by calling increaseFeeRate.
//
// NOTE: this method will change the state of the fee function as it increases
// its current fee rate.
//
// NOTE: part of the FeeFunction interface.
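//
// For example (arbitrary numbers): with width=9, i.e. an initial conf
// target of 10, a supplied conf target of 4 maps to position
// 9+1-4 = 6; a conf target of 10 or more leaves the position, and thus
// the fee rate, unchanged.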
func (l *LinearFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) {
newPosition := uint32(0)
// Only calculate the new position when the conf target is less than
// the function's width - the width is the initial conf target minus one,
// and we expect the current conf target to decrease over time. However, we
// still allow the supplied conf target to be greater than the width,
// and we won't increase the fee rate in that case.
if confTarget < l.width+1 {
newPosition = l.width + 1 - confTarget
log.Tracef("Increasing position from %v to %v", l.position,
newPosition)
}
if newPosition <= l.position {
log.Tracef("Skipped increase feerate: position=%v, "+
"newPosition=%v ", l.position, newPosition)
return false, nil
}
return l.increaseFeeRate(newPosition)
}
// increaseFeeRate increases the fee rate by the specified position, returns a
// boolean to indicate whether the fee rate was increased, and an error if the
// position is greater than the width. The increased fee rate will be set as
// the current fee rate, and the internal position will be set to the specified
// position.
//
// NOTE: this method will change the state of the fee function as it increases
// its current fee rate.
func (l *LinearFeeFunction) increaseFeeRate(position uint32) (bool, error) {
// If the current position is already at the end, we return an error.
if l.position >= l.width {
return false, ErrMaxPosition
}
// Get the old fee rate.
oldFeeRate := l.currentFeeRate
// Update its internal state.
l.position = position
l.currentFeeRate = l.feeRateAtPosition(position)
log.Tracef("Fee rate increased from %v to %v at position %v",
oldFeeRate, l.currentFeeRate, l.position)
return l.currentFeeRate > oldFeeRate, nil
}
// feeRateAtPosition calculates the fee rate at a given position and caps it at
// the ending fee rate.
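//
// For example (continuing the arbitrary numbers above): with
// startingFeeRate=253 sat/kw and deltaFeeRate=249,000 msat/kw,
// position 2 yields 253 + 249*2 = 751 sat/kw.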
func (l *LinearFeeFunction) feeRateAtPosition(p uint32) chainfee.SatPerKWeight {
if p >= l.width {
return l.endingFeeRate
}
// deltaFeeRate is in msat/kw, so we need to divide by 1000 to get the
// fee rate in sat/kw.
feeRateDelta := btcutil.Amount(l.deltaFeeRate).MulF64(float64(p) / 1000)
feeRate := l.startingFeeRate + chainfee.SatPerKWeight(feeRateDelta)
if feeRate > l.endingFeeRate {
return l.endingFeeRate
}
return feeRate
}
// estimateFeeRate asks the fee estimator to estimate the fee rate based on its
// conf target.
func (l *LinearFeeFunction) estimateFeeRate(
confTarget uint32) (chainfee.SatPerKWeight, error) {
fee := FeeEstimateInfo{
ConfTarget: confTarget,
}
// If the conf target is greater than or equal to the max allowed value
// (1008), we will use the min relay fee instead.
if confTarget >= chainfee.MaxBlockTarget {
minFeeRate := l.estimator.RelayFeePerKW()
log.Infof("Conf target %v is greater than max block target, "+
"using min relay fee rate %v", confTarget, minFeeRate)
return minFeeRate, nil
}
// endingFeeRate comes from budget/txWeight, which means the returned
// fee rate will always be capped by this value, hence we don't need to
// worry about overpay.
estimatedFeeRate, err := fee.Estimate(l.estimator, l.endingFeeRate)
if err != nil {
return 0, err
}
return estimatedFeeRate, nil
}
package sweep
import (
"github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/build"
)
// log is a logger that is initialized with no output filters. This means the
// package will not perform any logging by default until the caller requests
// it.
var log btclog.Logger
// The default amount of logging is none.
func init() {
UseLogger(build.NewSubLogger("SWPR", nil))
}
// DisableLog disables all library log output. Logging output is disabled by
// default until UseLogger is called.
func DisableLog() {
UseLogger(btclog.Disabled)
}
// UseLogger uses a specified Logger to output package logging info. This
// should be used in preference to SetLogWriter if the caller is also using
// btclog.
func UseLogger(logger btclog.Logger) {
log = logger
}
package sweep
import (
"bytes"
"encoding/binary"
"errors"
"io"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/tlv"
)
var (
// txHashesBucketKey is the key that points to a bucket containing the
// hashes of all sweep txes that were published successfully.
//
// maps: txHash -> TxRecord
txHashesBucketKey = []byte("sweeper-tx-hashes")
// utxnChainPrefix is the bucket prefix for nursery buckets.
utxnChainPrefix = []byte("utxn")
// utxnHeightIndexKey is the sub bucket where the nursery stores the
// height index.
utxnHeightIndexKey = []byte("height-index")
// utxnFinalizedKndrTxnKey is a static key that can be used to locate
// the nursery finalized kindergarten sweep txn.
utxnFinalizedKndrTxnKey = []byte("finalized-kndr-txn")
byteOrder = binary.BigEndian
errNoTxHashesBucket = errors.New("tx hashes bucket does not exist")
// ErrTxNotFound is returned when querying using a txid that's not
// found in our db.
ErrTxNotFound = errors.New("tx not found")
)
// TxRecord specifies a record of a tx that's stored in the database.
type TxRecord struct {
// Txid is the sweeping tx's txid, which is used as the key to store
// the following values.
Txid chainhash.Hash
// FeeRate is the fee rate of the sweeping tx, unit is sats/kw.
FeeRate uint64
// Fee is the fee of the sweeping tx, unit is sat.
Fee uint64
// Published indicates whether the tx has been published.
Published bool
}
// toTlvStream converts TxRecord into a tlv representation.
func (t *TxRecord) toTlvStream() (*tlv.Stream, error) {
const (
// A set of tlv type definitions used to serialize TxRecord.
// We define it here instead of the head of the file to avoid
// naming conflicts.
//
// NOTE: A migration should be added whenever the existing type
// changes.
//
// NOTE: Txid is stored as the key, so it's not included here.
feeRateType tlv.Type = 0
feeType tlv.Type = 1
boolType tlv.Type = 2
)
return tlv.NewStream(
tlv.MakeBigSizeRecord(feeRateType, &t.FeeRate),
tlv.MakeBigSizeRecord(feeType, &t.Fee),
tlv.MakePrimitiveRecord(boolType, &t.Published),
)
}
// serializeTxRecord serializes a TxRecord based on tlv format.
func serializeTxRecord(w io.Writer, tx *TxRecord) error {
// Create the tlv stream.
tlvStream, err := tx.toTlvStream()
if err != nil {
return err
}
// Encode the tlv stream.
var buf bytes.Buffer
if err := tlvStream.Encode(&buf); err != nil {
return err
}
// Write the tlv stream.
if _, err = w.Write(buf.Bytes()); err != nil {
return err
}
return nil
}
// deserializeTxRecord deserializes a TxRecord based on tlv format.
func deserializeTxRecord(r io.Reader) (*TxRecord, error) {
var tx TxRecord
// Create the tlv stream.
tlvStream, err := tx.toTlvStream()
if err != nil {
return nil, err
}
if err := tlvStream.Decode(r); err != nil {
return nil, err
}
return &tx, nil
}
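// serializeTxRecord and deserializeTxRecord are symmetric. A minimal
// illustrative round-trip (values are arbitrary):
//
//	rec := &TxRecord{FeeRate: 253, Fee: 1200, Published: true}
//	var buf bytes.Buffer
//	if err := serializeTxRecord(&buf, rec); err != nil {
//		return err
//	}
//	got, err := deserializeTxRecord(&buf)
//	if err != nil {
//		return err
//	}
//	// got matches rec except for Txid, which is stored as the bucket
//	// key and attached by the caller.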
// SweeperStore stores published txes.
type SweeperStore interface {
// IsOurTx determines whether a tx is published by us, based on its
// hash.
IsOurTx(hash chainhash.Hash) bool
// StoreTx stores a tx hash we are about to publish.
StoreTx(*TxRecord) error
// ListSweeps lists all the sweeps we have successfully published.
ListSweeps() ([]chainhash.Hash, error)
// GetTx queries the database to find the tx that matches the given
// txid. Returns ErrTxNotFound if it cannot be found.
GetTx(hash chainhash.Hash) (*TxRecord, error)
// DeleteTx removes a tx specified by the hash from the store.
DeleteTx(hash chainhash.Hash) error
}
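// A minimal sketch of using the store; db, chainHash and sweepTx are
// assumptions provided by the caller's environment:
//
//	store, err := NewSweeperStore(db, chainHash)
//	if err != nil {
//		return err
//	}
//	err = store.StoreTx(&TxRecord{
//		Txid:      sweepTx.TxHash(),
//		Published: true,
//	})
//	if err != nil {
//		return err
//	}
//	// Later, e.g. on a spend notification:
//	ours := store.IsOurTx(sweepTx.TxHash()) // true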
type sweeperStore struct {
db kvdb.Backend
}
// NewSweeperStore returns a new store instance.
func NewSweeperStore(db kvdb.Backend, chainHash *chainhash.Hash) (
SweeperStore, error) {
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
if tx.ReadWriteBucket(txHashesBucketKey) != nil {
return nil
}
txHashesBucket, err := tx.CreateTopLevelBucket(
txHashesBucketKey,
)
if err != nil {
return err
}
// Use non-existence of tx hashes bucket as a signal to migrate
// nursery finalized txes.
err = migrateTxHashes(tx, txHashesBucket, chainHash)
return err
}, func() {})
if err != nil {
return nil, err
}
return &sweeperStore{
db: db,
}, nil
}
// migrateTxHashes migrates nursery finalized txes to the tx hashes bucket. This
// is not implemented as a database migration, to keep the downgrade path open.
//
// TODO(yy): delete this function once nursery is removed.
func migrateTxHashes(tx kvdb.RwTx, txHashesBucket kvdb.RwBucket,
chainHash *chainhash.Hash) error {
log.Infof("Migrating UTXO nursery finalized TXIDs")
// Compose chain bucket key.
var b bytes.Buffer
if _, err := b.Write(utxnChainPrefix); err != nil {
return err
}
if _, err := b.Write(chainHash[:]); err != nil {
return err
}
// Get chain bucket if exists.
chainBucket := tx.ReadWriteBucket(b.Bytes())
if chainBucket == nil {
return nil
}
// Retrieve the existing height index.
hghtIndex := chainBucket.NestedReadWriteBucket(utxnHeightIndexKey)
if hghtIndex == nil {
return nil
}
// Retrieve all heights.
err := hghtIndex.ForEach(func(k, v []byte) error {
heightBucket := hghtIndex.NestedReadWriteBucket(k)
if heightBucket == nil {
return nil
}
// Get finalized tx for height.
txBytes := heightBucket.Get(utxnFinalizedKndrTxnKey)
if txBytes == nil {
return nil
}
// Deserialize and skip tx if it cannot be deserialized.
tx := &wire.MsgTx{}
err := tx.Deserialize(bytes.NewReader(txBytes))
if err != nil {
log.Warnf("Cannot deserialize utxn tx")
return nil
}
// Calculate hash.
hash := tx.TxHash()
// Insert utxn tx hash in hashes bucket.
log.Debugf("Inserting nursery tx %v in hash list "+
"(height=%v)", hash, byteOrder.Uint32(k))
// Create the transaction record. Since this is an old record,
// we can assume it's already been published. Although it's
// possible to calculate the fees and fee rate used here, we
// skip it as it's unlikely we'd perform RBF on these old
// sweeping transactions.
tr := &TxRecord{
Txid: hash,
Published: true,
}
// Serialize tx record.
var b bytes.Buffer
err = serializeTxRecord(&b, tr)
if err != nil {
return err
}
return txHashesBucket.Put(tr.Txid[:], b.Bytes())
})
if err != nil {
return err
}
return nil
}
// StoreTx stores that we are about to publish a tx.
func (s *sweeperStore) StoreTx(tr *TxRecord) error {
return kvdb.Update(s.db, func(tx kvdb.RwTx) error {
txHashesBucket := tx.ReadWriteBucket(txHashesBucketKey)
if txHashesBucket == nil {
return errNoTxHashesBucket
}
// Serialize tx record.
var b bytes.Buffer
err := serializeTxRecord(&b, tr)
if err != nil {
return err
}
return txHashesBucket.Put(tr.Txid[:], b.Bytes())
}, func() {})
}
// IsOurTx determines whether a tx is published by us, based on its hash.
func (s *sweeperStore) IsOurTx(hash chainhash.Hash) bool {
var ours bool
err := kvdb.View(s.db, func(tx kvdb.RTx) error {
txHashesBucket := tx.ReadBucket(txHashesBucketKey)
// If the root bucket cannot be found, we consider the tx to be
// not found in our db.
if txHashesBucket == nil {
log.Error("Tx hashes bucket not found in sweeper store")
return nil
}
ours = txHashesBucket.Get(hash[:]) != nil
return nil
}, func() {
ours = false
})
if err != nil {
return false
}
return ours
}
// ListSweeps lists all the sweep transactions we have in the sweeper store.
func (s *sweeperStore) ListSweeps() ([]chainhash.Hash, error) {
var sweepTxns []chainhash.Hash
if err := kvdb.View(s.db, func(tx kvdb.RTx) error {
txHashesBucket := tx.ReadBucket(txHashesBucketKey)
if txHashesBucket == nil {
return errNoTxHashesBucket
}
return txHashesBucket.ForEach(func(resKey, _ []byte) error {
txid, err := chainhash.NewHash(resKey)
if err != nil {
return err
}
sweepTxns = append(sweepTxns, *txid)
return nil
})
}, func() {
sweepTxns = nil
}); err != nil {
return nil, err
}
return sweepTxns, nil
}
// GetTx queries the database to find the tx that matches the given txid.
// Returns ErrTxNotFound if it cannot be found.
func (s *sweeperStore) GetTx(txid chainhash.Hash) (*TxRecord, error) {
// Create a record.
tr := &TxRecord{}
var err error
err = kvdb.View(s.db, func(tx kvdb.RTx) error {
txHashesBucket := tx.ReadBucket(txHashesBucketKey)
if txHashesBucket == nil {
return errNoTxHashesBucket
}
txBytes := txHashesBucket.Get(txid[:])
if txBytes == nil {
return ErrTxNotFound
}
// For old records, we'd get an empty byte slice here. We can
// assume it's already been published. Although it's possible
// to calculate the fees and fee rate used here, we skip it as
// it's unlikely we'd perform RBF on these old sweeping
// transactions.
//
// TODO(yy): remove this check once migration is added.
if len(txBytes) == 0 {
tr.Published = true
return nil
}
tr, err = deserializeTxRecord(bytes.NewReader(txBytes))
if err != nil {
return err
}
return nil
}, func() {
tr = &TxRecord{}
})
if err != nil {
return nil, err
}
// Attach the txid to the record.
tr.Txid = txid
return tr, nil
}
// DeleteTx removes the given tx from db.
func (s *sweeperStore) DeleteTx(txid chainhash.Hash) error {
return kvdb.Update(s.db, func(tx kvdb.RwTx) error {
txHashesBucket := tx.ReadWriteBucket(txHashesBucketKey)
if txHashesBucket == nil {
return errNoTxHashesBucket
}
return txHashesBucket.Delete(txid[:])
}, func() {})
}
// Compile-time constraint to ensure sweeperStore implements SweeperStore.
var _ SweeperStore = (*sweeperStore)(nil)
package sweep
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainio"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
var (
// ErrRemoteSpend is returned in case an output that we try to sweep is
// confirmed in a tx of the remote party.
ErrRemoteSpend = errors.New("remote party swept utxo")
// ErrFeePreferenceTooLow is returned when the fee preference gives a
// fee rate that's below the relay fee rate.
ErrFeePreferenceTooLow = errors.New("fee preference too low")
// ErrExclusiveGroupSpend is returned in case a different input of the
// same exclusive group was spent.
ErrExclusiveGroupSpend = errors.New("other member of exclusive group " +
"was spent")
// ErrSweeperShuttingDown is an error returned when a client attempts to
// make a request to the UtxoSweeper, but it is unable to handle it as
// it is/has already been stopped.
ErrSweeperShuttingDown = errors.New("utxo sweeper shutting down")
// DefaultDeadlineDelta defines a default deadline delta (1 week) to be
// used when sweeping inputs with no deadline pressure.
DefaultDeadlineDelta = int32(1008)
)
// Params contains the parameters that control the sweeping process.
type Params struct {
// ExclusiveGroup is an identifier that, if set, prevents other inputs
// with the same identifier from being batched together.
ExclusiveGroup *uint64
// DeadlineHeight specifies an absolute block height that this input
// should be confirmed by. This value is used by the fee bumper to
// decide the urgency and adjust the fee rate used.
DeadlineHeight fn.Option[int32]
// Budget specifies the maximum amount of satoshis that can be spent on
// fees for this sweep.
Budget btcutil.Amount
// Immediate indicates that the input should be swept immediately
// without waiting for a new block to trigger the sweep.
Immediate bool
// StartingFeeRate is an optional parameter that can be used to specify
// the initial fee rate to use for the fee function.
StartingFeeRate fn.Option[chainfee.SatPerKWeight]
}
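// For illustration, a sweep willing to spend at most 100k sats on fees
// and requiring confirmation by block 800_000 (arbitrary values) could
// be requested with:
//
//	params := Params{
//		Budget:         100_000,
//		DeadlineHeight: fn.Some(int32(800_000)),
//		Immediate:      true,
//	}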
// String returns a human readable interpretation of the sweep parameters.
func (p Params) String() string {
deadline := "none"
p.DeadlineHeight.WhenSome(func(d int32) {
deadline = fmt.Sprintf("%d", d)
})
exclusiveGroup := "none"
if p.ExclusiveGroup != nil {
exclusiveGroup = fmt.Sprintf("%d", *p.ExclusiveGroup)
}
return fmt.Sprintf("startingFeeRate=%v, immediate=%v, "+
"exclusive_group=%v, budget=%v, deadline=%v", p.StartingFeeRate,
p.Immediate, exclusiveGroup, p.Budget, deadline)
}
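// As an illustrative sketch, a caller requesting a time-sensitive sweep
// might fill in Params like so (all values here are hypothetical, not
// defaults):
//
//	params := Params{
//		Budget:          btcutil.Amount(50_000),
//		DeadlineHeight:  fn.Some(int32(800_100)),
//		Immediate:       true,
//		StartingFeeRate: fn.Some(chainfee.SatPerKWeight(2500)),
//	}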
// SweepState represents the current state of a pending input.
//
//nolint:revive
type SweepState uint8
const (
// Init is the initial state of a pending input. This is set when a new
// sweeping request for a given input is made.
Init SweepState = iota
// PendingPublish specifies an input's state where it's already been
// included in a sweeping tx but the tx is not published yet. Inputs
// in this state should not be used for grouping again.
PendingPublish
// Published is the state where the input's sweeping tx has
// successfully been published. Inputs in this state can only be
// updated via RBF.
Published
// PublishFailed is the state when an error is returned from publishing
// the sweeping tx. Inputs in this state can be regrouped into a new
// sweeping tx.
PublishFailed
// Swept is the final state of a pending input. This is set when the
// input has been successfully swept.
Swept
// Excluded is the state of a pending input that has been excluded and
// can no longer be swept. For instance, when one of the three anchor
// sweeping transactions confirms, the remaining two will be excluded.
Excluded
// Fatal is the final state of a pending input. Inputs ending in this
// state won't be retried. This could happen when:
// - a pending input has too many failed publish attempts;
// - the input has been spent by another party;
// - an unknown broadcast error is returned.
Fatal
)
// String gives a human readable text for the sweep states.
func (s SweepState) String() string {
switch s {
case Init:
return "Init"
case PendingPublish:
return "PendingPublish"
case Published:
return "Published"
case PublishFailed:
return "PublishFailed"
case Swept:
return "Swept"
case Excluded:
return "Excluded"
case Fatal:
return "Fatal"
default:
return "Unknown"
}
}
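// To summarize, a typical input moves through the following transitions
// (a rough sketch inferred from the state definitions above, not an
// exhaustive diagram):
//
//	Init -> PendingPublish -> Published -> Swept
//	Init -> PendingPublish -> PublishFailed -> PendingPublish (retry)
//
// with Excluded and Fatal as alternative terminal states.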
// RBFInfo stores the information required to perform a RBF bump on a pending
// sweeping tx.
type RBFInfo struct {
// Txid is the txid of the sweeping tx.
Txid chainhash.Hash
// FeeRate is the fee rate of the sweeping tx.
FeeRate chainfee.SatPerKWeight
// Fee is the total fee of the sweeping tx.
Fee btcutil.Amount
}
// SweeperInput is created when an input reaches the main loop for the first
// time. It wraps the input and tracks all relevant state that is needed for
// sweeping.
type SweeperInput struct {
input.Input
// state tracks the current state of the input.
state SweepState
// listeners is a list of channels over which the final outcome of the
// sweep needs to be broadcasted.
listeners []chan Result
// ntfnRegCancel is populated with a function that cancels the chain
// notifier spend registration.
ntfnRegCancel func()
// publishAttempts records the number of attempts that have already been
// made to sweep this tx.
publishAttempts int
// params contains the parameters that control the sweeping process.
params Params
// lastFeeRate is the most recent fee rate used for this input within a
// transaction broadcast to the network.
lastFeeRate chainfee.SatPerKWeight
// rbf records the RBF constraints.
rbf fn.Option[RBFInfo]
// DeadlineHeight is the deadline height for this input. This is
// different from the DeadlineHeight in its params as it's an actual
// value rather than an option.
DeadlineHeight int32
}
// String returns a human readable interpretation of the pending input.
func (p *SweeperInput) String() string {
return fmt.Sprintf("%v (%v)", p.Input.OutPoint(), p.Input.WitnessType())
}
// terminated returns a boolean indicating whether the input has reached a
// final state.
func (p *SweeperInput) terminated() bool {
switch p.state {
// If the input has reached a final state, that is, it's either
// been swept, failed, or excluded, it will be removed from
// our sweeper.
case Fatal, Swept, Excluded:
return true
default:
return false
}
}
// isMature returns a boolean indicating whether the input's timelock has been
// reached. The relevant locktime is also returned.
func (p *SweeperInput) isMature(currentHeight uint32) (bool, uint32) {
locktime, _ := p.RequiredLockTime()
if currentHeight < locktime {
log.Debugf("Input %v has locktime=%v, current height is %v",
p, locktime, currentHeight)
return false, locktime
}
// If the input has a CSV that's not yet reached, we will skip
// this input and wait for the expiry.
//
// NOTE: We need to consider whether this input can be included in the
// next block or not, which means the CSV will be checked against the
// currentHeight plus one.
locktime = p.BlocksToMaturity() + p.HeightHint()
if currentHeight+1 < locktime {
log.Debugf("Input %v has CSV expiry=%v, current height is %v, "+
"skipped sweeping", p, locktime, currentHeight)
return false, locktime
}
return true, locktime
}
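// As a worked example with hypothetical numbers (assuming no absolute
// locktime): an input with BlocksToMaturity()=144 and
// HeightHint()=800_000 matures at height 800_144. At
// currentHeight=800_142 the check 800_143 < 800_144 holds and the input
// is reported immature; at currentHeight=800_143 it can be included in
// the next block and is considered mature.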
// InputsMap is a type alias for a set of pending inputs.
type InputsMap = map[wire.OutPoint]*SweeperInput
// inputsMapToString returns a human readable interpretation of the pending
// inputs.
func inputsMapToString(inputs InputsMap) string {
if len(inputs) == 0 {
return ""
}
inps := make([]input.Input, 0, len(inputs))
for _, in := range inputs {
inps = append(inps, in)
}
return "\n" + inputTypeSummary(inps)
}
// pendingSweepsReq is an internal message we'll use to represent an external
// caller's intent to retrieve all of the pending inputs the UtxoSweeper is
// attempting to sweep.
type pendingSweepsReq struct {
respChan chan map[wire.OutPoint]*PendingInputResponse
errChan chan error
}
// PendingInputResponse contains information about an input that is currently
// being swept by the UtxoSweeper.
type PendingInputResponse struct {
// OutPoint is the identifying outpoint of the input being swept.
OutPoint wire.OutPoint
// WitnessType is the witness type of the input being swept.
WitnessType input.WitnessType
// Amount is the amount of the input being swept.
Amount btcutil.Amount
// LastFeeRate is the most recent fee rate used for the input being
// swept within a transaction broadcast to the network.
LastFeeRate chainfee.SatPerKWeight
// BroadcastAttempts is the number of attempts we've made to sweep the
// input.
BroadcastAttempts int
// Params contains the sweep parameters for this pending request.
Params Params
// DeadlineHeight records the deadline height of this input.
DeadlineHeight uint32
}
// updateReq is an internal message we'll use to represent an external caller's
// intent to update the sweep parameters of a given input.
type updateReq struct {
input wire.OutPoint
params Params
responseChan chan *updateResp
}
// updateResp is an internal message we'll use to hand off the response of an
// updateReq from the UtxoSweeper's main event loop back to the caller.
type updateResp struct {
resultChan chan Result
err error
}
// UtxoSweeper is responsible for sweeping outputs back into the wallet.
type UtxoSweeper struct {
started uint32 // To be used atomically.
stopped uint32 // To be used atomically.
// Embed the blockbeat consumer struct to get access to the method
// `NotifyBlockProcessed` and the `BlockbeatChan`.
chainio.BeatConsumer
cfg *UtxoSweeperConfig
newInputs chan *sweepInputMessage
spendChan chan *chainntnfs.SpendDetail
// pendingSweepsReqs is a channel that will be sent requests by
// external callers in order to retrieve the set of pending inputs the
// UtxoSweeper is attempting to sweep.
pendingSweepsReqs chan *pendingSweepsReq
// updateReqs is a channel that will be sent requests by external
// callers who wish to bump the fee rate of a given input.
updateReqs chan *updateReq
// inputs is the total set of inputs the UtxoSweeper has been requested
// to sweep.
inputs InputsMap
currentOutputScript fn.Option[lnwallet.AddrWithKey]
relayFeeRate chainfee.SatPerKWeight
quit chan struct{}
wg sync.WaitGroup
// currentHeight is the best known height of the main chain. This is
// updated whenever a new block epoch is received.
currentHeight int32
// bumpRespChan is a channel that receives broadcast results from the
// TxPublisher.
bumpRespChan chan *bumpResp
}
// Compile-time check for the chainio.Consumer interface.
var _ chainio.Consumer = (*UtxoSweeper)(nil)
// UtxoSweeperConfig contains dependencies of UtxoSweeper.
type UtxoSweeperConfig struct {
// GenSweepScript generates a P2WKH script belonging to the wallet where
// funds can be swept.
GenSweepScript func() fn.Result[lnwallet.AddrWithKey]
// FeeEstimator is used when crafting sweep transactions to estimate
// the necessary fee relative to the expected size of the sweep
// transaction.
FeeEstimator chainfee.Estimator
// Wallet contains the wallet functions that sweeper requires.
Wallet Wallet
// Notifier is an instance of a chain notifier we'll use to watch for
// certain on-chain events.
Notifier chainntnfs.ChainNotifier
// Mempool is the mempool watcher that will be used to query whether a
// given input is already being spent by a transaction in the mempool.
Mempool chainntnfs.MempoolWatcher
// Store stores the published sweeper txes.
Store SweeperStore
// Signer is used by the sweeper to generate valid witnesses at the
// time the incubated outputs need to be spent.
Signer input.Signer
// MaxInputsPerTx specifies the default maximum number of inputs allowed
// in a single sweep tx. If more need to be swept, multiple txes are
// created and published.
MaxInputsPerTx uint32
// MaxFeeRate is the maximum fee rate allowed within the UtxoSweeper.
MaxFeeRate chainfee.SatPerVByte
// Aggregator is used to group inputs into clusters based on its
// implementation-specific strategy.
Aggregator UtxoAggregator
// Publisher is used to publish the sweep tx crafted here and monitors
// it for potential fee bumps.
Publisher Bumper
// NoDeadlineConfTarget is the conf target to use when sweeping
// non-time-sensitive outputs.
NoDeadlineConfTarget uint32
}
// Result is the struct that is pushed through the result channel. Callers can
// use this to be informed of the final sweep result. In case of a remote
// spend, Err will be ErrRemoteSpend.
type Result struct {
// Err is the final result of the sweep. It is nil when the input is
// swept successfully by us. ErrRemoteSpend is returned when another
// party took the input.
Err error
// Tx is the transaction that spent the input.
Tx *wire.MsgTx
}
// sweepInputMessage structs are used in the internal channel between the
// SweepInput call and the sweeper main loop.
type sweepInputMessage struct {
input input.Input
params Params
resultChan chan Result
}
// New returns a new Sweeper instance.
func New(cfg *UtxoSweeperConfig) *UtxoSweeper {
s := &UtxoSweeper{
cfg: cfg,
newInputs: make(chan *sweepInputMessage),
spendChan: make(chan *chainntnfs.SpendDetail),
updateReqs: make(chan *updateReq),
pendingSweepsReqs: make(chan *pendingSweepsReq),
quit: make(chan struct{}),
inputs: make(InputsMap),
bumpRespChan: make(chan *bumpResp, 100),
}
// Mount the block consumer.
s.BeatConsumer = chainio.NewBeatConsumer(s.quit, s.Name())
return s
}
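// A minimal wiring sketch (the dependencies and the blockbeat are assumed
// to be constructed elsewhere; this is illustrative only):
//
//	s := New(&UtxoSweeperConfig{
//		FeeEstimator: estimator,
//		Wallet:       wallet,
//		Notifier:     notifier,
//		Store:        store,
//		Signer:       signer,
//		Aggregator:   aggregator,
//		Publisher:    publisher,
//	})
//	if err := s.Start(beat); err != nil {
//		// Handle the startup failure.
//	}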
// Start starts the process of constructing and publishing sweep txes.
func (s *UtxoSweeper) Start(beat chainio.Blockbeat) error {
if !atomic.CompareAndSwapUint32(&s.started, 0, 1) {
return nil
}
log.Info("Sweeper starting")
// Retrieve relay fee for dust limit calculation. Assume that this will
// not change from here on.
s.relayFeeRate = s.cfg.FeeEstimator.RelayFeePerKW()
// Set the current height.
s.currentHeight = beat.Height()
// Start sweeper main loop.
s.wg.Add(1)
go s.collector()
return nil
}
// RelayFeePerKW returns the minimum fee rate required for transactions to be
// relayed.
func (s *UtxoSweeper) RelayFeePerKW() chainfee.SatPerKWeight {
return s.relayFeeRate
}
// Stop stops sweeper from listening to block epochs and constructing sweep
// txes.
func (s *UtxoSweeper) Stop() error {
if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) {
return nil
}
log.Info("Sweeper shutting down...")
defer log.Debug("Sweeper shutdown complete")
close(s.quit)
s.wg.Wait()
return nil
}
// Name returns the unique name of the sweeper.
//
// NOTE: part of the `chainio.Consumer` interface.
func (s *UtxoSweeper) Name() string {
return "UtxoSweeper"
}
// SweepInput sweeps inputs back into the wallet. The inputs will be batched and
// swept after the batch time window ends. A custom fee preference can be
// provided to determine what fee rate should be used for the input. Note that
// the input may not always be swept with this exact value, as it's possible for
// it to be batched under the same transaction with other similar fee rate
// inputs.
//
// NOTE: Extreme care needs to be taken that input isn't changed externally.
// Because it is an interface and we don't know what is exactly behind it, we
// cannot make a local copy in sweeper.
//
// TODO(yy): make sure the caller is using the Result chan.
func (s *UtxoSweeper) SweepInput(inp input.Input,
params Params) (chan Result, error) {
if inp == nil || inp.OutPoint() == input.EmptyOutPoint ||
inp.SignDesc() == nil {
return nil, errors.New("nil input received")
}
absoluteTimeLock, _ := inp.RequiredLockTime()
log.Debugf("Sweep request received: out_point=%v, witness_type=%v, "+
"relative_time_lock=%v, absolute_time_lock=%v, amount=%v, "+
"parent=(%v), params=(%v)", inp.OutPoint(), inp.WitnessType(),
inp.BlocksToMaturity(), absoluteTimeLock,
btcutil.Amount(inp.SignDesc().Output.Value),
inp.UnconfParent(), params)
sweeperInput := &sweepInputMessage{
input: inp,
params: params,
resultChan: make(chan Result, 1),
}
// Deliver input to the main event loop.
select {
case s.newInputs <- sweeperInput:
case <-s.quit:
return nil, ErrSweeperShuttingDown
}
return sweeperInput.resultChan, nil
}
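// A sketch of how a caller might consume the result channel (the goroutine
// and error handling here are illustrative):
//
//	resultChan, err := s.SweepInput(inp, params)
//	if err != nil {
//		return err
//	}
//	go func() {
//		result := <-resultChan
//		if result.Err != nil {
//			log.Errorf("Sweep failed: %v", result.Err)
//			return
//		}
//		log.Infof("Swept in tx %v", result.Tx.TxHash())
//	}()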
// removeConflictSweepDescendants removes any transactions from the wallet that
// spend outputs included in the passed outpoint set. This needs to be done in
// cases where we're not the only ones that can sweep an output, but there may
// exist unconfirmed spends that spend outputs created by a sweep transaction.
// The most common case for this is when someone sweeps our anchor outputs
// after 16 blocks. Moreover, this is also needed for wallets which use
// neutrino as a backend when a channel is force closed and anchor cpfp txns
// are created to bump the initial commitment transaction. In this case an
// anchor cpfp is broadcast for up to 3 commitment transactions (local,
// remote-dangling, remote). Using neutrino, all of those transactions will be
// accepted (the commitment tx will be different in all of those cases) and
// have to be removed as soon as one of them confirms (they do have the same
// ExclusiveGroup). For neutrino backends the corresponding BIP 157 serving
// full nodes do not signal invalid transactions anymore.
func (s *UtxoSweeper) removeConflictSweepDescendants(
outpoints map[wire.OutPoint]struct{}) error {
// Obtain all the past sweeps that we've done so far. We'll need these
// to ensure that if the spendingTx spends any of the same inputs, then
// we remove any transaction that may be spending those inputs from the
// wallet.
//
// TODO(roasbeef): can be last sweep here if we remove anything confirmed
// from the store?
pastSweepHashes, err := s.cfg.Store.ListSweeps()
if err != nil {
return err
}
// We'll now go through each past transaction we published during this
// epoch and cross-reference the spent inputs. If there are any inputs
// in common with the inputs the spendingTx spent, then we'll remove
// those.
//
// TODO(roasbeef): need to start to remove all transaction hashes after
// every N blocks (assumed point of no return)
for _, sweepHash := range pastSweepHashes {
sweepTx, err := s.cfg.Wallet.FetchTx(sweepHash)
if err != nil {
return err
}
// Transaction wasn't found in the wallet, may have already
// been replaced/removed.
if sweepTx == nil {
// If it was removed, then we'll play it safe and mark
// it as no longer needing to be rebroadcast.
s.cfg.Wallet.CancelRebroadcast(sweepHash)
continue
}
// Check to see if this past sweep transaction spent any of the
// same inputs as spendingTx.
var isConflicting bool
for _, txIn := range sweepTx.TxIn {
if _, ok := outpoints[txIn.PreviousOutPoint]; ok {
isConflicting = true
break
}
}
if !isConflicting {
continue
}
// If it is conflicting, then we'll signal the wallet to remove
// all the transactions that are descendants of outputs created
// by the sweepTx and the sweepTx itself.
log.Debugf("Removing sweep txid=%v from wallet: %v",
sweepTx.TxHash(), spew.Sdump(sweepTx))
err = s.cfg.Wallet.RemoveDescendants(sweepTx)
if err != nil {
log.Warnf("Unable to remove descendants: %v", err)
}
// If this transaction was conflicting, then we'll stop
// rebroadcasting it in the background.
s.cfg.Wallet.CancelRebroadcast(sweepHash)
}
return nil
}
// collector is the sweeper main loop. It processes new inputs, spend
// notifications and counts down to publication of the sweep tx.
func (s *UtxoSweeper) collector() {
defer s.wg.Done()
for {
// Clean inputs, which will remove inputs that are swept,
// failed, or excluded from the sweeper, and return inputs that
// are either new or have failed to be published, which will be
// retried here.
s.updateSweeperInputs()
select {
// A new input is offered to the sweeper. We check to see if
// we are already trying to sweep this input and if not, set up
// a listener to spend and schedule a sweep.
case input := <-s.newInputs:
err := s.handleNewInput(input)
if err != nil {
log.Criticalf("Unable to handle new input: %v",
err)
return
}
// If this input is forced, we perform a sweep
// immediately.
//
// TODO(ziggie): Make sure when `immediate` is selected
// as a parameter that we only trigger the sweeping of
// this specific input rather than triggering the sweeps
// of all current pending inputs registered with the
// sweeper.
if input.params.Immediate {
inputs := s.updateSweeperInputs()
s.sweepPendingInputs(inputs)
}
// A spend of one of our inputs is detected. Signal sweep
// results to the caller(s).
case spend := <-s.spendChan:
s.handleInputSpent(spend)
// A new external request has been received to retrieve all of
// the inputs we're currently attempting to sweep.
case req := <-s.pendingSweepsReqs:
s.handlePendingSweepsReq(req)
// A new external request has been received to bump the fee rate
// of a given input.
case req := <-s.updateReqs:
resultChan, err := s.handleUpdateReq(req)
req.responseChan <- &updateResp{
resultChan: resultChan,
err: err,
}
// Perform a sweep immediately if asked.
if req.params.Immediate {
inputs := s.updateSweeperInputs()
s.sweepPendingInputs(inputs)
}
case resp := <-s.bumpRespChan:
// Handle the bump event.
err := s.handleBumpEvent(resp)
if err != nil {
log.Errorf("Failed to handle bump event: %v",
err)
}
// A new block comes in, update the bestHeight, perform a check
// over all pending inputs and publish sweeping txns if needed.
case beat := <-s.BlockbeatChan:
// Update the sweeper to the best height.
s.currentHeight = beat.Height()
// Update the inputs with the latest height.
inputs := s.updateSweeperInputs()
log.Debugf("Received new block: height=%v, attempt "+
"sweeping %d inputs:%s", s.currentHeight,
len(inputs),
lnutils.NewLogClosure(func() string {
return inputsMapToString(inputs)
}))
// Attempt to sweep any pending inputs.
s.sweepPendingInputs(inputs)
// Notify we've processed the block.
s.NotifyBlockProcessed(beat, nil)
case <-s.quit:
return
}
}
}
// removeExclusiveGroup removes all inputs in the given exclusive group. This
// function is called when one of the exclusive group inputs has been spent. The
// other inputs won't ever be spendable and can be removed. This also prevents
// them from being part of future sweep transactions that would fail. In
// addition, sweep transactions of those inputs will be removed from the wallet.
func (s *UtxoSweeper) removeExclusiveGroup(group uint64) {
for outpoint, input := range s.inputs {
outpoint := outpoint
// Skip inputs that aren't exclusive.
if input.params.ExclusiveGroup == nil {
continue
}
// Skip inputs from other exclusive groups.
if *input.params.ExclusiveGroup != group {
continue
}
// Skip inputs that are already terminated.
if input.terminated() {
log.Tracef("Skipped sending error result for "+
"input %v, state=%v", outpoint, input.state)
continue
}
// Signal result channels.
s.signalResult(input, Result{
Err: ErrExclusiveGroupSpend,
})
// Update the input's state as it can no longer be swept.
input.state = Excluded
// Remove all unconfirmed transactions from the wallet which
// spend the passed outpoint of the same exclusive group.
outpoints := map[wire.OutPoint]struct{}{
outpoint: {},
}
err := s.removeConflictSweepDescendants(outpoints)
if err != nil {
log.Warnf("Unable to remove conflicting sweep tx from "+
"wallet for outpoint %v : %v", outpoint, err)
}
}
}
// signalResult notifies the listeners of the final result of the input sweep.
// It also cancels any pending spend notification.
func (s *UtxoSweeper) signalResult(pi *SweeperInput, result Result) {
op := pi.OutPoint()
listeners := pi.listeners
if result.Err == nil {
log.Tracef("Dispatching sweep success for %v to %v listeners",
op, len(listeners),
)
} else {
log.Tracef("Dispatching sweep error for %v to %v listeners: %v",
op, len(listeners), result.Err,
)
}
// Signal all listeners. The channels are buffered. Because we only
// send once on every channel, it should never block.
for _, resultChan := range listeners {
resultChan <- result
}
// Cancel spend notification with chain notifier. This is not strictly
// necessary in case of a success, except that a reorg could still happen.
if pi.ntfnRegCancel != nil {
log.Debugf("Canceling spend ntfn for %v", op)
pi.ntfnRegCancel()
}
}
// sweep takes a set of preselected inputs, creates a sweep tx and publishes
// the tx. The output address is only marked as used if the publish succeeds.
func (s *UtxoSweeper) sweep(set InputSet) error {
// Generate an output script if there isn't an unused script available.
if s.currentOutputScript.IsNone() {
addr, err := s.cfg.GenSweepScript().Unpack()
if err != nil {
return fmt.Errorf("gen sweep script: %w", err)
}
s.currentOutputScript = fn.Some(addr)
log.Debugf("Created sweep DeliveryAddress %x",
addr.DeliveryAddress)
}
sweepAddr, err := s.currentOutputScript.UnwrapOrErr(
fmt.Errorf("none sweep script"),
)
if err != nil {
return err
}
// Create a fee bump request and ask the publisher to broadcast it. The
// publisher will then take over and start monitoring the tx for
// potential fee bump.
req := &BumpRequest{
Inputs: set.Inputs(),
Budget: set.Budget(),
DeadlineHeight: set.DeadlineHeight(),
DeliveryAddress: sweepAddr,
MaxFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(),
StartingFeeRate: set.StartingFeeRate(),
Immediate: set.Immediate(),
// TODO(yy): pass the strategy here.
}
// Reschedule the inputs that we just tried to sweep. This is done so
// that, in case the following publish fails, we can update the inputs'
// publish attempts and rescue them in the next sweep.
s.markInputsPendingPublish(set)
// Broadcast will return a read-only chan that we will listen to for
// this publish result and future RBF attempt.
resp := s.cfg.Publisher.Broadcast(req)
// Successfully sent the broadcast attempt, we now handle the result by
// subscribing to the result chan and listen for future updates about
// this tx.
s.wg.Add(1)
go s.monitorFeeBumpResult(set, resp)
return nil
}
// markInputsPendingPublish updates the pending inputs with the given tx
// inputs. It also increments the `publishAttempts`.
func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) {
// Reschedule sweep.
for _, input := range set.Inputs() {
op := input.OutPoint()
pi, ok := s.inputs[op]
if !ok {
// It could be that this input is an additional wallet
// input that was attached. In that case there also
// isn't a pending input to update.
log.Tracef("Skipped marking input as pending "+
"published: %v not found in pending inputs", op)
continue
}
// If this input has already terminated, there's clearly
// something wrong as it would have been removed. In this case
// we log an error and skip marking this input as pending
// publish.
if pi.terminated() {
log.Errorf("Expect input %v to not have terminated "+
"state, instead it has %v", op, pi.state)
continue
}
// Update the input's state.
pi.state = PendingPublish
// Record another publish attempt.
pi.publishAttempts++
}
}
// markInputsPublished updates the sweeping tx in db and marks the list of
// inputs as published.
func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, set InputSet) error {
// Mark this tx in db once successfully published.
//
// NOTE: this will behave as an overwrite, which is fine as the record
// is small.
tr.Published = true
err := s.cfg.Store.StoreTx(tr)
if err != nil {
return fmt.Errorf("store tx: %w", err)
}
// Reschedule sweep.
for _, input := range set.Inputs() {
op := input.OutPoint()
pi, ok := s.inputs[op]
if !ok {
// It could be that this input is an additional wallet
// input that was attached. In that case there also
// isn't a pending input to update.
log.Tracef("Skipped marking input as published: %v "+
"not found in pending inputs", op)
continue
}
// Validate that the input is in an expected state.
if pi.state != PendingPublish {
// We may get a Published if this is a replacement tx.
log.Debugf("Expect input %v to have %v, instead it "+
"has %v", op, PendingPublish, pi.state)
continue
}
// Update the input's state.
pi.state = Published
// Update the input's latest fee rate.
pi.lastFeeRate = chainfee.SatPerKWeight(tr.FeeRate)
}
return nil
}
// markInputsPublishFailed marks the list of inputs as failed to be published.
func (s *UtxoSweeper) markInputsPublishFailed(set InputSet,
feeRate chainfee.SatPerKWeight) {
// Reschedule sweep.
for _, inp := range set.Inputs() {
op := inp.OutPoint()
pi, ok := s.inputs[op]
if !ok {
// It could be that this input is an additional wallet
// input that was attached. In that case there also
// isn't a pending input to update.
log.Tracef("Skipped marking input as publish failed: "+
"%v not found in pending inputs", op)
continue
}
// Validate that the input is in an expected state.
if pi.state != PendingPublish && pi.state != Published {
log.Debugf("Expect input %v to have %v, instead it "+
"has %v", op, PendingPublish, pi.state)
continue
}
log.Warnf("Failed to publish input %v", op)
// Update the input's state.
pi.state = PublishFailed
log.Debugf("Input(%v): updating params: starting fee rate "+
"[%v -> %v]", op, pi.params.StartingFeeRate,
feeRate)
// Update the input using the fee rate specified from the
// BumpResult, which should be the starting fee rate to use for
// the next sweeping attempt.
pi.params.StartingFeeRate = fn.Some(feeRate)
}
}
// monitorSpend registers a spend notification with the chain notifier. It
// returns a cancel function that can be used to cancel the registration.
func (s *UtxoSweeper) monitorSpend(outpoint wire.OutPoint,
script []byte, heightHint uint32) (func(), error) {
log.Tracef("Wait for spend of %v at heightHint=%v",
outpoint, heightHint)
spendEvent, err := s.cfg.Notifier.RegisterSpendNtfn(
&outpoint, script, heightHint,
)
if err != nil {
return nil, fmt.Errorf("register spend ntfn: %w", err)
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
select {
case spend, ok := <-spendEvent.Spend:
if !ok {
log.Debugf("Spend ntfn for %v canceled",
outpoint)
return
}
log.Debugf("Delivering spend ntfn for %v", outpoint)
select {
case s.spendChan <- spend:
log.Debugf("Delivered spend ntfn for %v",
outpoint)
case <-s.quit:
}
case <-s.quit:
}
}()
return spendEvent.Cancel, nil
}
// PendingInputs returns the set of inputs that the UtxoSweeper is currently
// attempting to sweep.
func (s *UtxoSweeper) PendingInputs() (
map[wire.OutPoint]*PendingInputResponse, error) {
respChan := make(chan map[wire.OutPoint]*PendingInputResponse, 1)
errChan := make(chan error, 1)
select {
case s.pendingSweepsReqs <- &pendingSweepsReq{
respChan: respChan,
errChan: errChan,
}:
case <-s.quit:
return nil, ErrSweeperShuttingDown
}
select {
case pendingSweeps := <-respChan:
return pendingSweeps, nil
case err := <-errChan:
return nil, err
case <-s.quit:
return nil, ErrSweeperShuttingDown
}
}
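// For example, a caller could inspect the pending set like this (a sketch;
// the fields mirror PendingInputResponse above):
//
//	pending, err := s.PendingInputs()
//	if err != nil {
//		return err
//	}
//	for op, resp := range pending {
//		log.Debugf("input=%v attempts=%d deadline=%d",
//			op, resp.BroadcastAttempts, resp.DeadlineHeight)
//	}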
// handlePendingSweepsReq handles a request to retrieve all pending inputs the
// UtxoSweeper is attempting to sweep.
func (s *UtxoSweeper) handlePendingSweepsReq(
req *pendingSweepsReq) map[wire.OutPoint]*PendingInputResponse {
resps := make(map[wire.OutPoint]*PendingInputResponse, len(s.inputs))
for _, inp := range s.inputs {
// Skip immature inputs for compatibility.
mature, _ := inp.isMature(uint32(s.currentHeight))
if !mature {
continue
}
// Only the exported fields are set, as we expect the response
// to only be consumed externally.
op := inp.OutPoint()
resps[op] = &PendingInputResponse{
OutPoint: op,
WitnessType: inp.WitnessType(),
Amount: btcutil.Amount(
inp.SignDesc().Output.Value,
),
LastFeeRate: inp.lastFeeRate,
BroadcastAttempts: inp.publishAttempts,
Params: inp.params,
DeadlineHeight: uint32(inp.DeadlineHeight),
}
}
select {
case req.respChan <- resps:
case <-s.quit:
log.Debug("Skipped sending pending sweep response due to " +
"UtxoSweeper shutting down")
}
return resps
}
// UpdateParams allows updating the sweep parameters of a pending input in the
// UtxoSweeper. This function can be used to provide an updated fee preference
// and force flag that will be used for a new sweep transaction of the input
// that will act as a replacement transaction (RBF) of the original sweeping
// transaction, if any. The exclusive group is left unchanged.
//
// NOTE: This currently doesn't do any fee rate validation to ensure that a bump
// is actually successful. The responsibility of doing so should be handled by
// the caller.
func (s *UtxoSweeper) UpdateParams(input wire.OutPoint,
params Params) (chan Result, error) {
responseChan := make(chan *updateResp, 1)
select {
case s.updateReqs <- &updateReq{
input: input,
params: params,
responseChan: responseChan,
}:
case <-s.quit:
return nil, ErrSweeperShuttingDown
}
select {
case response := <-responseChan:
return response.resultChan, response.err
case <-s.quit:
return nil, ErrSweeperShuttingDown
}
}
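// A sketch of bumping an input's fee via RBF (hypothetical values; since
// the exclusive group is preserved internally, it can be left unset):
//
//	resultChan, err := s.UpdateParams(op, Params{
//		StartingFeeRate: fn.Some(chainfee.SatPerKWeight(5000)),
//		Budget:          btcutil.Amount(100_000),
//		Immediate:       true,
//	})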
// handleUpdateReq handles an update request by simply updating the sweep
// parameters of the pending input. Currently, no validation is done on the new
// fee preference to ensure it will properly create a replacement transaction.
//
// TODO(wilmer):
// - Validate fee preference to ensure we'll create a valid replacement
// transaction to allow the new fee rate to propagate throughout the
// network.
// - Ensure we don't combine this input with any other unconfirmed inputs that
// did not exist in the original sweep transaction, resulting in an invalid
// replacement transaction.
func (s *UtxoSweeper) handleUpdateReq(req *updateReq) (
chan Result, error) {
// If the UtxoSweeper is already trying to sweep this input, then we can
// simply just increase its fee rate. This will allow the input to be
// batched with others which also have a similar fee rate, creating a
// higher fee rate transaction that replaces the original input's
// sweeping transaction.
sweeperInput, ok := s.inputs[req.input]
if !ok {
return nil, lnwallet.ErrNotMine
}
// Create the updated parameters struct. Leave the exclusive group
// unchanged.
newParams := Params{
StartingFeeRate: req.params.StartingFeeRate,
Immediate: req.params.Immediate,
Budget: req.params.Budget,
DeadlineHeight: req.params.DeadlineHeight,
ExclusiveGroup: sweeperInput.params.ExclusiveGroup,
}
log.Debugf("Updating parameters for %v(state=%v) from (%v) to (%v)",
req.input, sweeperInput.state, sweeperInput.params, newParams)
sweeperInput.params = newParams
// We need to reset the state so this input will be attempted again by
// our sweeper.
//
// TODO(yy): a dedicated state?
sweeperInput.state = Init
// If the new input specifies a deadline, update the deadline height.
sweeperInput.DeadlineHeight = req.params.DeadlineHeight.UnwrapOr(
sweeperInput.DeadlineHeight,
)
resultChan := make(chan Result, 1)
sweeperInput.listeners = append(sweeperInput.listeners, resultChan)
return resultChan, nil
}
// ListSweeps returns a list of the sweeps recorded by the sweep store.
func (s *UtxoSweeper) ListSweeps() ([]chainhash.Hash, error) {
return s.cfg.Store.ListSweeps()
}
// mempoolLookup takes an input's outpoint and queries the mempool to see
// whether it's already been spent in a transaction found in the mempool.
// Returns the transaction if found.
func (s *UtxoSweeper) mempoolLookup(op wire.OutPoint) fn.Option[wire.MsgTx] {
// For neutrino backend, there's no mempool available, so we exit
// early.
if s.cfg.Mempool == nil {
log.Debugf("Skipping mempool lookup for %v, no mempool ", op)
return fn.None[wire.MsgTx]()
}
// Query this input in the mempool. If this outpoint is already spent
// in mempool, we should get a spending event back immediately.
return s.cfg.Mempool.LookupInputMempoolSpend(op)
}
// calculateDefaultDeadline calculates the default deadline height for a sweep
// request that has no deadline height specified.
func (s *UtxoSweeper) calculateDefaultDeadline(pi *SweeperInput) int32 {
// Create a default deadline height, which will be used when there's no
// DeadlineHeight specified for a given input.
defaultDeadline := s.currentHeight + int32(s.cfg.NoDeadlineConfTarget)
// If the input is immature and has a locktime, we'll use the locktime
// height as the starting height.
matured, locktime := pi.isMature(uint32(s.currentHeight))
if !matured {
defaultDeadline = int32(locktime + s.cfg.NoDeadlineConfTarget)
log.Debugf("Input %v is immature, using locktime=%v instead "+
"of current height=%d as starting height",
pi.OutPoint(), locktime, s.currentHeight)
}
return defaultDeadline
}
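// Worked example with hypothetical values: with currentHeight=800_000 and
// NoDeadlineConfTarget=1008, the default deadline is 801_008. If the input
// is instead immature with locktime=800_500, the deadline becomes
// 800_500+1008=801_508.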
// handleNewInput processes a new input by registering spend notification and
// scheduling sweeping for it.
func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) error {
outpoint := input.input.OutPoint()
pi, pending := s.inputs[outpoint]
if pending {
log.Infof("Already has pending input %v received, old params: "+
"%v, new params %v", outpoint, pi.params, input.params)
s.handleExistingInput(input, pi)
return nil
}
// This is a new input, and we want to query the mempool to see if this
// input has already been spent. If so, we'll start the input with the
// RBFInfo.
rbfInfo := s.decideRBFInfo(input.input.OutPoint())
// Create a new pendingInput and initialize the listeners slice with
// the passed in result channel. If this input is offered for sweep
// again, the result channel will be appended to this slice.
pi = &SweeperInput{
state: Init,
listeners: []chan Result{input.resultChan},
Input: input.input,
params: input.params,
rbf: rbfInfo,
}
// Set the starting fee rate if a previous sweeping tx is found.
rbfInfo.WhenSome(func(info RBFInfo) {
pi.params.StartingFeeRate = fn.Some(info.FeeRate)
})
// Set the actual deadline height.
pi.DeadlineHeight = input.params.DeadlineHeight.UnwrapOr(
s.calculateDefaultDeadline(pi),
)
s.inputs[outpoint] = pi
log.Tracef("input %v, state=%v, added to inputs", outpoint, pi.state)
log.Infof("Registered sweep request at block %d: out_point=%v, "+
"witness_type=%v, amount=%v, deadline=%d, state=%v, "+
"params=(%v)", s.currentHeight, pi.OutPoint(), pi.WitnessType(),
btcutil.Amount(pi.SignDesc().Output.Value), pi.DeadlineHeight,
pi.state, pi.params)
// Start watching for spend of this input, either by us or the remote
// party.
cancel, err := s.monitorSpend(
outpoint, input.input.SignDesc().Output.PkScript,
input.input.HeightHint(),
)
if err != nil {
err := fmt.Errorf("wait for spend: %w", err)
s.markInputFatal(pi, nil, err)
return err
}
pi.ntfnRegCancel = cancel
return nil
}
// decideRBFInfo queries the mempool to see whether the given input has already
// been spent. When spent, it will query the sweeper store to fetch the fee info
// of the spending transaction, and construct an RBFInfo based on it. If an
// error occurs, fn.None is returned.
func (s *UtxoSweeper) decideRBFInfo(
op wire.OutPoint) fn.Option[RBFInfo] {
// Check if we can find the spending tx of this input in mempool.
txOption := s.mempoolLookup(op)
// Extract the spending tx from the option.
var tx *wire.MsgTx
txOption.WhenSome(func(t wire.MsgTx) {
tx = &t
})
// Exit early if it's not found.
//
// NOTE: this is not accurate for backends that don't support mempool
// lookup:
// - for neutrino we don't have a mempool.
// - for btcd below v0.24.1 we don't have `gettxspendingprevout`.
if tx == nil {
return fn.None[RBFInfo]()
}
// Otherwise the input is already spent in the mempool.
//
// We also need to update the RBF info for this input. If the sweeping
// transaction is broadcast by us, we can find the fee info in the
// sweeper store.
txid := tx.TxHash()
tr, err := s.cfg.Store.GetTx(txid)
log.Debugf("Found spending tx %v in mempool for input %v", tx.TxHash(),
op)
// If the tx is not found in the store, it means it's not broadcast by
// us, hence we can't find the fee info. This is fine as, later on when
// this tx is confirmed, we will remove the input from our inputs.
if errors.Is(err, ErrTxNotFound) {
log.Warnf("Spending tx %v not found in sweeper store", txid)
return fn.None[RBFInfo]()
}
// Exit if we get a db error.
if err != nil {
log.Errorf("Unable to get tx %v from sweeper store: %v",
txid, err)
return fn.None[RBFInfo]()
}
// Prepare the fee info and return it.
rbf := fn.Some(RBFInfo{
Txid: txid,
Fee: btcutil.Amount(tr.Fee),
FeeRate: chainfee.SatPerKWeight(tr.FeeRate),
})
return rbf
}
// handleExistingInput processes an input that is already known to the sweeper.
// It will overwrite the params of the old input with the new ones.
func (s *UtxoSweeper) handleExistingInput(input *sweepInputMessage,
oldInput *SweeperInput) {
// Before updating the input details, check if an exclusive group was
// set. In case the same input is registered again without an exclusive
// group set, the previous input and its sweep parameters are outdated
// hence need to be replaced. This scenario currently only happens for
// anchor outputs. When a channel is force closed, in the worst case 3
// different sweeps with the same exclusive group are registered with
// the sweeper to bump the closing transaction (cpfp) when it's time
// critical. Receiving an input which was already registered with the
// sweeper but now without an exclusive group means none of the previous
// inputs were used as CPFP, so we need to make sure we update the
// sweep parameters but also remove all inputs with the same exclusive
// group because they are outdated too.
var prevExclGroup *uint64
if oldInput.params.ExclusiveGroup != nil &&
input.params.ExclusiveGroup == nil {
prevExclGroup = new(uint64)
*prevExclGroup = *oldInput.params.ExclusiveGroup
}
// Update input details and sweep parameters. The re-offered input
// details may contain a change to the unconfirmed parent tx info.
oldInput.params = input.params
oldInput.Input = input.input
// If the new input specifies a deadline, update the deadline height.
oldInput.DeadlineHeight = input.params.DeadlineHeight.UnwrapOr(
oldInput.DeadlineHeight,
)
// Add additional result channel to signal spend of this input.
oldInput.listeners = append(oldInput.listeners, input.resultChan)
if prevExclGroup != nil {
s.removeExclusiveGroup(*prevExclGroup)
}
}
// handleInputSpent takes a spend event of our input and updates the sweeper's
// internal state to remove the input.
func (s *UtxoSweeper) handleInputSpent(spend *chainntnfs.SpendDetail) {
// Query store to find out if we ever published this tx.
spendHash := *spend.SpenderTxHash
isOurTx := s.cfg.Store.IsOurTx(spendHash)
// If this isn't our transaction, it means someone else swept outputs
// that we were attempting to sweep. This can happen for anchor outputs
// as well as justice transactions. In this case, we'll notify the
// wallet to remove any spends that descend from this output.
if !isOurTx {
// Construct a map of the inputs this transaction spends.
spendingTx := spend.SpendingTx
inputsSpent := make(
map[wire.OutPoint]struct{}, len(spendingTx.TxIn),
)
for _, txIn := range spendingTx.TxIn {
inputsSpent[txIn.PreviousOutPoint] = struct{}{}
}
log.Debugf("Attempting to remove descendant txns invalidated "+
"by (txid=%v): %v", spendingTx.TxHash(),
spew.Sdump(spendingTx))
err := s.removeConflictSweepDescendants(inputsSpent)
if err != nil {
log.Warnf("unable to remove descendant transactions "+
"due to tx %v: ", spendHash)
}
log.Debugf("Detected third party spend related to in flight "+
"inputs (is_ours=%v): %v", isOurTx,
lnutils.SpewLogClosure(spend.SpendingTx))
}
// We now use the spending tx to update the state of the inputs.
s.markInputsSwept(spend.SpendingTx, isOurTx)
}
// markInputsSwept marks all inputs swept by the spending transaction as swept.
// It will also notify all the subscribers of this input.
func (s *UtxoSweeper) markInputsSwept(tx *wire.MsgTx, isOurTx bool) {
for _, txIn := range tx.TxIn {
outpoint := txIn.PreviousOutPoint
// Check if this input is known to us. It could be unknown if we
// canceled the registration and deleted it from the inputs map
// while the ntfn was already in flight, or it may simply not be
// one of our inputs.
input, ok := s.inputs[outpoint]
if !ok {
// It's very likely that a spending tx contains inputs
// that we don't know.
log.Tracef("Skipped marking input as swept: %v not "+
"found in pending inputs", outpoint)
continue
}
// This input may have already been marked as swept by a previous
// spend notification, which is likely to happen as one sweep
// transaction usually sweeps multiple inputs.
if input.terminated() {
log.Debugf("Skipped marking input as swept: %v "+
"state=%v", outpoint, input.state)
continue
}
input.state = Swept
// Return either a nil or a remote spend result.
var err error
if !isOurTx {
log.Warnf("Input=%v was spent by remote or third "+
"party in tx=%v", outpoint, tx.TxHash())
err = ErrRemoteSpend
}
// Signal result channels.
s.signalResult(input, Result{
Tx: tx,
Err: err,
})
// Remove all other inputs in this exclusive group.
if input.params.ExclusiveGroup != nil {
s.removeExclusiveGroup(*input.params.ExclusiveGroup)
}
}
}
// markInputFatal marks the given input as fatal, meaning it won't be
// retried. It will also notify all the subscribers of this input.
func (s *UtxoSweeper) markInputFatal(pi *SweeperInput, tx *wire.MsgTx,
err error) {
log.Errorf("Failed to sweep input: %v, error: %v", pi, err)
pi.state = Fatal
s.signalResult(pi, Result{
Tx: tx,
Err: err,
})
}
// updateSweeperInputs updates the sweeper's internal state and returns a map
// of inputs to be swept. It will remove the inputs that are in final states,
// and return a map of inputs that have either state Init or PublishFailed.
func (s *UtxoSweeper) updateSweeperInputs() InputsMap {
// Create a map of inputs to be swept.
inputs := make(InputsMap)
// Iterate the pending inputs and update the sweeper's state.
//
// TODO(yy): sweeper is made to communicate via go channels, so no
// locks are needed to access the map. However, it'd be safer if we
// turn this inputs map into a SyncMap in case we wanna add concurrent
// access to the map in the future.
for op, input := range s.inputs {
log.Tracef("Checking input: %s, state=%v", input, input.state)
// If the input has reached a final state, that is, it's either
// been swept, failed, or excluded, we will remove it from
// our sweeper.
if input.terminated() {
log.Debugf("Removing input(State=%v) %v from sweeper",
input.state, op)
delete(s.inputs, op)
continue
}
// If this input has been included in a sweep tx that's not
// published yet, we'd skip this input and wait for the sweep
// tx to be published.
if input.state == PendingPublish {
continue
}
// If this input has already been published, we will need to
// check the RBF condition before attempting another sweeping.
if input.state == Published {
continue
}
// If the input has a locktime that's not yet reached, we will
// skip this input and wait for the locktime to be reached.
mature, _ := input.isMature(uint32(s.currentHeight))
if !mature {
continue
}
// If this input is new or has failed to be published,
// we'd retry it. The assumption here is that when an error is
// returned from `PublishTransaction`, it means the tx has
// failed to meet the policy, hence it's not in the mempool.
inputs[op] = input
}
return inputs
}
// sweepPendingInputs creates clusters from the given inputs and attempts to
// create and publish the sweeping transactions.
func (s *UtxoSweeper) sweepPendingInputs(inputs InputsMap) {
log.Debugf("Sweeping %v inputs", len(inputs))
// Cluster all of our inputs based on the specific Aggregator.
sets := s.cfg.Aggregator.ClusterInputs(inputs)
// sweepWithLock is a helper closure that executes the sweep within a
// coin select lock to prevent the coins being selected for other
// transactions like funding of a channel.
sweepWithLock := func(set InputSet) error {
return s.cfg.Wallet.WithCoinSelectLock(func() error {
// Try to add inputs from our wallet.
err := set.AddWalletInputs(s.cfg.Wallet)
if err != nil {
return err
}
// Create sweeping transaction for each set.
err = s.sweep(set)
if err != nil {
return err
}
return nil
})
}
for _, set := range sets {
var err error
if set.NeedWalletInput() {
// Sweep the set of inputs that need the wallet inputs.
err = sweepWithLock(set)
} else {
// Sweep the set of inputs that don't need the wallet
// inputs.
err = s.sweep(set)
}
if err != nil {
log.Errorf("Failed to sweep %v: %v", set, err)
}
}
}
// bumpResp wraps the result of a bump attempt returned from the fee bumper and
// the inputs being used.
type bumpResp struct {
// result is the result of the bump attempt returned from the fee
// bumper.
result *BumpResult
// set is the input set that was used in the bump attempt.
set InputSet
}
// monitorFeeBumpResult subscribes to the passed result chan to listen for
// future updates about the sweeping tx.
//
// NOTE: must run as a goroutine.
func (s *UtxoSweeper) monitorFeeBumpResult(set InputSet,
resultChan <-chan *BumpResult) {
defer s.wg.Done()
for {
select {
case r := <-resultChan:
// Validate the result is valid.
if err := r.Validate(); err != nil {
log.Errorf("Received invalid result: %v", err)
continue
}
resp := &bumpResp{
result: r,
set: set,
}
// Send the result back to the main event loop.
select {
case s.bumpRespChan <- resp:
case <-s.quit:
log.Debug("Sweeper shutting down, skip " +
"sending bump result")
return
}
// The sweeping tx has been confirmed, we can exit the
// monitor now.
//
// TODO(yy): can instead remove the spend subscription
// in sweeper and rely solely on this event to mark
// inputs as Swept?
if r.Event == TxConfirmed || r.Event == TxFailed {
// Exit if the tx failed to be created.
if r.Tx == nil {
log.Debugf("Received %v for nil tx, "+
"exit monitor", r.Event)
return
}
log.Debugf("Received %v for sweep tx %v, exit "+
"fee bump monitor", r.Event,
r.Tx.TxHash())
// Cancel the rebroadcasting of the failed tx.
s.cfg.Wallet.CancelRebroadcast(r.Tx.TxHash())
return
}
case <-s.quit:
log.Debugf("Sweeper shutting down, exit fee " +
"bump handler")
return
}
}
}
// handleBumpEventTxFailed handles the case where the tx failed to be
// published.
func (s *UtxoSweeper) handleBumpEventTxFailed(resp *bumpResp) {
r := resp.result
tx, err := r.Tx, r.Err
if tx != nil {
log.Warnf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(),
err)
}
// NOTE: When marking the inputs as failed, we are using the input set
// instead of the inputs found in the tx. This is fine for current
// version of the sweeper because we always create a tx using ALL of
// the inputs specified by the set.
//
// TODO(yy): should we also remove the failed tx from db?
s.markInputsPublishFailed(resp.set, resp.result.FeeRate)
}
// handleBumpEventTxReplaced handles the case where the sweeping tx has been
// replaced by a new one.
func (s *UtxoSweeper) handleBumpEventTxReplaced(resp *bumpResp) error {
r := resp.result
oldTx := r.ReplacedTx
newTx := r.Tx
// Prepare a new record to replace the old one.
tr := &TxRecord{
Txid: newTx.TxHash(),
FeeRate: uint64(r.FeeRate),
Fee: uint64(r.Fee),
}
// Get the old record for logging purposes.
oldTxid := oldTx.TxHash()
record, err := s.cfg.Store.GetTx(oldTxid)
if err != nil {
log.Errorf("Fetch tx record for %v: %v", oldTxid, err)
return err
}
// Cancel the rebroadcasting of the replaced tx.
s.cfg.Wallet.CancelRebroadcast(oldTxid)
log.Infof("RBFed tx=%v(fee=%v sats, feerate=%v sats/kw) with new "+
"tx=%v(fee=%v sats, feerate=%v sats/kw)", record.Txid,
record.Fee, record.FeeRate, tr.Txid, tr.Fee, tr.FeeRate)
// The old sweeping tx has been replaced by a new one, we will update
// the tx record in the sweeper db.
//
// TODO(yy): we may also need to update the inputs in this tx to a new
// state. Suppose a replacing tx only spends a subset of the inputs
// here, we'd end up with the rest being marked as `Published` and
// won't be aggregated in the next sweep. Atm it's fine as we always
// RBF the same input set.
if err := s.cfg.Store.DeleteTx(oldTxid); err != nil {
log.Errorf("Delete tx record for %v: %v", oldTxid, err)
return err
}
// Mark the inputs as published using the replacing tx.
return s.markInputsPublished(tr, resp.set)
}
// handleBumpEventTxPublished handles the case where the sweeping tx has been
// successfully published.
func (s *UtxoSweeper) handleBumpEventTxPublished(resp *bumpResp) error {
r := resp.result
tx := r.Tx
tr := &TxRecord{
Txid: tx.TxHash(),
FeeRate: uint64(r.FeeRate),
Fee: uint64(r.Fee),
}
// Inputs have been successfully published so we update their
// states.
err := s.markInputsPublished(tr, resp.set)
if err != nil {
return err
}
log.Debugf("Published sweep tx %v, num_inputs=%v, height=%v",
tx.TxHash(), len(tx.TxIn), s.currentHeight)
// If there's no error, remove the output script. Otherwise it is kept
// so that it can be reused for the next transaction, avoiding address
// inflation.
s.currentOutputScript = fn.None[lnwallet.AddrWithKey]()
return nil
}
// handleBumpEventTxFatal handles the case where there's an unexpected error
// when creating or publishing the sweeping tx. In this case, the tx will be
// removed from the sweeper store and the inputs will be marked as `Fatal`,
// which means they will not be retried.
func (s *UtxoSweeper) handleBumpEventTxFatal(resp *bumpResp) error {
r := resp.result
// Remove the tx from the sweeper store if there is one. Since this is
// a broadcast error, it's likely there isn't a tx here.
if r.Tx != nil {
txid := r.Tx.TxHash()
log.Infof("Tx=%v failed with unexpected error: %v", txid, r.Err)
// Remove the tx from the sweeper db if it exists.
if err := s.cfg.Store.DeleteTx(txid); err != nil {
return fmt.Errorf("delete tx record for %v: %w", txid,
err)
}
}
// Mark the inputs as fatal.
s.markInputsFatal(resp.set, r.Err)
return nil
}
// markInputsFatal marks all inputs in the input set as fatal. It will also
// notify all the subscribers of these inputs.
func (s *UtxoSweeper) markInputsFatal(set InputSet, err error) {
for _, inp := range set.Inputs() {
outpoint := inp.OutPoint()
input, ok := s.inputs[outpoint]
if !ok {
// It's very likely that a spending tx contains inputs
// that we don't know.
log.Tracef("Skipped marking input as failed: %v not "+
"found in pending inputs", outpoint)
continue
}
// If the input is already in a terminal state, we don't want
// to rewrite it, which also indicates an error as we only get
// an error event during the initial broadcast.
if input.terminated() {
log.Errorf("Skipped marking input=%v as failed due to "+
"unexpected state=%v", outpoint, input.state)
continue
}
s.markInputFatal(input, nil, err)
}
}
// handleBumpEvent handles the result sent from the bumper based on its event
// type.
//
// NOTE: The TxConfirmed event is not handled here; since we already
// subscribe to the input's spending event, we don't need to do anything.
func (s *UtxoSweeper) handleBumpEvent(r *bumpResp) error {
log.Debugf("Received bump result %v", r.result)
switch r.result.Event {
// The tx has been published, we update the inputs' state and create a
// record to be stored in the sweeper db.
case TxPublished:
return s.handleBumpEventTxPublished(r)
// The tx has failed, we update the inputs' state.
case TxFailed:
s.handleBumpEventTxFailed(r)
return nil
// The tx has been replaced, we will remove the old tx and replace it
// with the new one.
case TxReplaced:
return s.handleBumpEventTxReplaced(r)
// There are inputs being spent in a tx which the fee bumper doesn't
// understand. We will remove the tx from the sweeper db and mark the
// inputs as swept.
case TxUnknownSpend:
s.handleBumpEventTxUnknownSpend(r)
// There's a fatal error in creating the tx, we will remove the tx from
// the sweeper db and mark the inputs as failed.
case TxFatal:
return s.handleBumpEventTxFatal(r)
}
return nil
}
// IsSweeperOutpoint determines whether the outpoint was created by the sweeper.
//
// NOTE: It is enough to check the txid because the sweeper will create
// outpoints which solely belong to the internal LND wallet.
func (s *UtxoSweeper) IsSweeperOutpoint(op wire.OutPoint) bool {
return s.cfg.Store.IsOurTx(op.Hash)
}
// markInputSwept marks the given input as swept by the tx. It will also notify
// all the subscribers of this input.
func (s *UtxoSweeper) markInputSwept(inp *SweeperInput, tx *wire.MsgTx) {
log.Debugf("Marking input as swept: %v from state=%v", inp.OutPoint(),
inp.state)
inp.state = Swept
// Signal result channels.
s.signalResult(inp, Result{
Tx: tx,
})
// Remove all other inputs in this exclusive group.
if inp.params.ExclusiveGroup != nil {
s.removeExclusiveGroup(*inp.params.ExclusiveGroup)
}
}
// handleUnknownSpendTx takes an input and its spending tx. If the spending tx
// cannot be found in the sweeper store, the input will be marked as fatal,
// otherwise it will be marked as swept.
func (s *UtxoSweeper) handleUnknownSpendTx(inp *SweeperInput, tx *wire.MsgTx) {
op := inp.OutPoint()
txid := tx.TxHash()
isOurTx := s.cfg.Store.IsOurTx(txid)
// If this is our tx, it means it's a previous sweeping tx that got
// confirmed, which could happen when a restart happens during the
// sweeping process.
if isOurTx {
log.Debugf("Found our sweeping tx %v, marking input %v as "+
"swept", txid, op)
// We now use the spending tx to update the state of the inputs.
s.markInputSwept(inp, tx)
return
}
// Since the input is spent by others, we now mark it as fatal; it won't
// be retried.
s.markInputFatal(inp, tx, ErrRemoteSpend)
log.Debugf("Removing descendant txns invalidated by (txid=%v): %v",
txid, lnutils.SpewLogClosure(tx))
// Construct a map of the inputs this transaction spends.
spentInputs := make(map[wire.OutPoint]struct{}, len(tx.TxIn))
for _, txIn := range tx.TxIn {
spentInputs[txIn.PreviousOutPoint] = struct{}{}
}
err := s.removeConflictSweepDescendants(spentInputs)
if err != nil {
log.Warnf("unable to remove descendant transactions "+
"due to tx %v: ", txid)
}
}
// handleBumpEventTxUnknownSpend handles the case where the confirmed tx is
// unknown to the fee bumper, e.g., when the sweeping tx has been replaced
// by another party and their tx confirmed. It will retry sweeping the
// "good" inputs once the "bad" ones are kicked out.
func (s *UtxoSweeper) handleBumpEventTxUnknownSpend(r *bumpResp) {
// Mark the inputs as publish failed, which means they will be retried
// later.
s.markInputsPublishFailed(r.set, r.result.FeeRate)
// Get all the inputs that were spent by txns unknown to us.
spentInputs := r.result.SpentInputs
// Create a slice to track inputs to be retried.
inputsToRetry := make([]input.Input, 0, len(r.set.Inputs()))
// Iterate all the inputs found in this bump and mark the ones spent by
// the third party as failed. The rest of the inputs will then be updated
// with a new fee rate and be retried immediately.
for _, inp := range r.set.Inputs() {
op := inp.OutPoint()
input, ok := s.inputs[op]
// Wallet inputs are not tracked so we will not find them from
// the inputs map.
if !ok {
log.Debugf("Skipped marking input: %v not found in "+
"pending inputs", op)
continue
}
// Check whether this input has been spent, if so we mark it as
// fatal or swept based on whether this is one of our previous
// sweeping txns, then move to the next.
tx, spent := spentInputs[op]
if spent {
s.handleUnknownSpendTx(input, tx)
continue
}
log.Debugf("Input(%v): updating params: immediate [%v -> true]",
op, r.result.FeeRate, input.params.Immediate)
input.params.Immediate = true
inputsToRetry = append(inputsToRetry, input)
}
// Exit early if there are no inputs to be retried.
if len(inputsToRetry) == 0 {
return
}
log.Debugf("Retry sweeping inputs with updated params: %v",
inputTypeSummary(inputsToRetry))
// Get the latest inputs, which should put the PublishFailed inputs back
// to the sweeping queue.
inputs := s.updateSweeperInputs()
// Immediately sweep the remaining inputs - they should now be swept
// with the updated StartingFeeRate. We may also include more inputs in
// the new sweeping tx if new ones with the same deadline are offered.
s.sweepPendingInputs(inputs)
}
package sweep
import (
"fmt"
"sync"
"testing"
"time"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
)
var (
defaultTestTimeout = 5 * time.Second
processingDelay = 1 * time.Second
mockChainHash, _ = chainhash.NewHashFromStr("00aabbccddeeff")
mockChainHeight = int32(100)
)
// MockNotifier simulates the chain notifier for test purposes. This type is
// exported because it is used in nursery tests.
type MockNotifier struct {
confChannel map[chainhash.Hash]chan *chainntnfs.TxConfirmation
epochChan map[chan *chainntnfs.BlockEpoch]int32
spendChan map[wire.OutPoint][]chan *chainntnfs.SpendDetail
spends map[wire.OutPoint]*wire.MsgTx
mutex sync.RWMutex
t *testing.T
}
// NewMockNotifier instantiates a new mock notifier.
func NewMockNotifier(t *testing.T) *MockNotifier {
return &MockNotifier{
confChannel: make(map[chainhash.Hash]chan *chainntnfs.TxConfirmation),
epochChan: make(map[chan *chainntnfs.BlockEpoch]int32),
spendChan: make(map[wire.OutPoint][]chan *chainntnfs.SpendDetail),
spends: make(map[wire.OutPoint]*wire.MsgTx),
t: t,
}
}
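// A rough usage sketch (hypothetical test wiring, not part of this file):
// once the subsystem under test has registered via RegisterBlockEpochNtfn and
// is consuming the epoch channel, a test can drive the chain forward with:
//
//	notifier := NewMockNotifier(t)
//	// ... hand the notifier to the subsystem under test ...
//	notifier.NotifyEpoch(mockChainHeight + 1)
//
// NotifyEpoch fails the test if the event isn't consumed within
// defaultTestTimeout, while NotifyEpochNonBlocking silently drops it.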
// NotifyEpochNonBlocking simulates a new epoch arriving without blocking when
// the epochChan is not read.
func (m *MockNotifier) NotifyEpochNonBlocking(height int32) {
m.t.Helper()
for epochChan, chanHeight := range m.epochChan {
// Only send notifications if the height is greater than the
// height the caller passed into the register call.
if chanHeight >= height {
continue
}
log.Debugf("Notifying height %v to listener", height)
select {
case epochChan <- &chainntnfs.BlockEpoch{Height: height}:
default:
}
}
}
// NotifyEpoch simulates a new epoch arriving.
func (m *MockNotifier) NotifyEpoch(height int32) {
m.t.Helper()
for epochChan, chanHeight := range m.epochChan {
// Only send notifications if the height is greater than the
// height the caller passed into the register call.
if chanHeight >= height {
continue
}
log.Debugf("Notifying height %v to listener", height)
select {
case epochChan <- &chainntnfs.BlockEpoch{
Height: height,
}:
case <-time.After(defaultTestTimeout):
m.t.Fatal("epoch event not consumed")
}
}
}
// ConfirmTx simulates a tx confirming.
func (m *MockNotifier) ConfirmTx(txid *chainhash.Hash, height uint32) error {
confirm := &chainntnfs.TxConfirmation{
BlockHeight: height,
}
select {
case m.getConfChannel(txid) <- confirm:
case <-time.After(defaultTestTimeout):
return fmt.Errorf("confirmation not consumed")
}
return nil
}
// SpendOutpoint simulates a utxo being spent.
func (m *MockNotifier) SpendOutpoint(outpoint wire.OutPoint,
spendingTx wire.MsgTx) {
log.Debugf("Spending outpoint %v", outpoint)
m.mutex.Lock()
defer m.mutex.Unlock()
channels, ok := m.spendChan[outpoint]
if ok {
for _, channel := range channels {
m.sendSpend(channel, &outpoint, &spendingTx)
}
}
m.spends[outpoint] = &spendingTx
}
func (m *MockNotifier) sendSpend(channel chan *chainntnfs.SpendDetail,
outpoint *wire.OutPoint,
spendingTx *wire.MsgTx) {
log.Debugf("Notifying spend of outpoint %v", outpoint)
spenderTxHash := spendingTx.TxHash()
channel <- &chainntnfs.SpendDetail{
SpenderTxHash: &spenderTxHash,
SpendingTx: spendingTx,
SpentOutPoint: outpoint,
}
}
// RegisterConfirmationsNtfn registers for tx confirm notifications.
func (m *MockNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash,
_ []byte, numConfs, heightHint uint32,
opt ...chainntnfs.NotifierOption) (*chainntnfs.ConfirmationEvent, error) {
return &chainntnfs.ConfirmationEvent{
Confirmed: m.getConfChannel(txid),
}, nil
}
func (m *MockNotifier) getConfChannel(
txid *chainhash.Hash) chan *chainntnfs.TxConfirmation {
m.mutex.Lock()
defer m.mutex.Unlock()
channel, ok := m.confChannel[*txid]
if ok {
return channel
}
channel = make(chan *chainntnfs.TxConfirmation)
m.confChannel[*txid] = channel
return channel
}
// RegisterBlockEpochNtfn registers a block notification.
func (m *MockNotifier) RegisterBlockEpochNtfn(
bestBlock *chainntnfs.BlockEpoch) (*chainntnfs.BlockEpochEvent, error) {
log.Tracef("Mock block ntfn registered")
m.mutex.Lock()
epochChan := make(chan *chainntnfs.BlockEpoch, 1)
// The real notifier returns a notification with the current block hash
// and height immediately if no best block hash or height is specified
// in the request. We want to emulate this behaviour as well for the
// mock.
switch {
case bestBlock == nil:
epochChan <- &chainntnfs.BlockEpoch{
Hash: mockChainHash,
Height: mockChainHeight,
}
m.epochChan[epochChan] = mockChainHeight
default:
m.epochChan[epochChan] = bestBlock.Height
}
m.mutex.Unlock()
return &chainntnfs.BlockEpochEvent{
Epochs: epochChan,
Cancel: func() {
log.Tracef("Mock block ntfn canceled")
m.mutex.Lock()
delete(m.epochChan, epochChan)
m.mutex.Unlock()
},
}, nil
}
// Start the notifier.
func (m *MockNotifier) Start() error {
return nil
}
// Started checks if started.
func (m *MockNotifier) Started() bool {
return true
}
// Stop the notifier.
func (m *MockNotifier) Stop() error {
return nil
}
// RegisterSpendNtfn registers for spend notifications.
func (m *MockNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
_ []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) {
log.Debugf("RegisterSpendNtfn for outpoint %v", outpoint)
// Add channel to global spend ntfn map.
m.mutex.Lock()
channels, ok := m.spendChan[*outpoint]
if !ok {
channels = make([]chan *chainntnfs.SpendDetail, 0)
}
channel := make(chan *chainntnfs.SpendDetail, 1)
channels = append(channels, channel)
m.spendChan[*outpoint] = channels
// Check if this output has already been spent.
spendingTx, spent := m.spends[*outpoint]
m.mutex.Unlock()
// If output has been spent already, signal now. Do this outside the
// lock to prevent a deadlock.
if spent {
m.sendSpend(channel, outpoint, spendingTx)
}
return &chainntnfs.SpendEvent{
Spend: channel,
Cancel: func() {
log.Infof("Cancelling RegisterSpendNtfn for %v",
outpoint)
m.mutex.Lock()
defer m.mutex.Unlock()
channels := m.spendChan[*outpoint]
for i, c := range channels {
if c == channel {
channels[i] = channels[len(channels)-1]
m.spendChan[*outpoint] =
channels[:len(channels)-1]
}
}
close(channel)
log.Infof("Spend ntfn channel closed for %v",
outpoint)
},
}, nil
}
package sweep
import (
"fmt"
"math"
"sort"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
var (
// ErrNotEnoughInputs is returned when there are not enough wallet
// inputs to construct a non-dust change output for an input set.
ErrNotEnoughInputs = fmt.Errorf("not enough inputs")
// ErrDeadlinesMismatch is returned when the deadlines of the input
// sets do not match.
ErrDeadlinesMismatch = fmt.Errorf("deadlines mismatch")
// ErrDustOutput is returned when the output value is below the dust
// limit.
ErrDustOutput = fmt.Errorf("dust output")
)
// InputSet defines an interface that's responsible for filtering a set of
// inputs that can be swept economically.
type InputSet interface {
// Inputs returns the set of inputs that should be used to create a tx.
Inputs() []input.Input
// AddWalletInputs adds wallet inputs to the set until a non-dust
// change output can be made. Return an error if there are not enough
// wallet inputs.
AddWalletInputs(wallet Wallet) error
// NeedWalletInput returns true if the input set needs more wallet
// inputs.
NeedWalletInput() bool
// DeadlineHeight returns an absolute block height to express the
// time-sensitivity of the input set. The outputs from a force close tx
// have different time preferences:
// - to_local: no time pressure as it can only be swept by us.
// - first level outgoing HTLC: must be swept before its corresponding
// incoming HTLC's CLTV is reached.
// - first level incoming HTLC: must be swept before its CLTV is
// reached.
// - second level HTLCs: no time pressure.
// - anchor: for CPFP-purpose anchor, it must be swept before any of
// the above CLTVs is reached. For non-CPFP purpose anchor, there's
// no time pressure.
DeadlineHeight() int32
// Budget gives the total amount that can be used as fees by this
// input set.
Budget() btcutil.Amount
// StartingFeeRate returns the max starting fee rate found in the
// inputs.
StartingFeeRate() fn.Option[chainfee.SatPerKWeight]
// Immediate returns a boolean to indicate whether the tx made from
// this input set should be published immediately.
//
// TODO(yy): create a new method `Params` to combine the informational
// methods DeadlineHeight, Budget, StartingFeeRate and Immediate.
Immediate() bool
}
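// To make the deadline semantics above concrete with hypothetical numbers: an
// implementation sweeping a first-level incoming HTLC whose CLTV expires at
// block 800_000 would report DeadlineHeight() == 800_000, whereas a to_local
// output with no time pressure would typically fall back to some configured
// default deadline instead.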
// createWalletTxInput converts a wallet utxo into an object that can be added
// to the other inputs to sweep.
func createWalletTxInput(utxo *lnwallet.Utxo) (input.Input, error) {
signDesc := &input.SignDescriptor{
Output: &wire.TxOut{
PkScript: utxo.PkScript,
Value: int64(utxo.Value),
},
HashType: txscript.SigHashAll,
}
var witnessType input.WitnessType
switch utxo.AddressType {
case lnwallet.WitnessPubKey:
witnessType = input.WitnessKeyHash
case lnwallet.NestedWitnessPubKey:
witnessType = input.NestedWitnessKeyHash
case lnwallet.TaprootPubkey:
witnessType = input.TaprootPubKeySpend
signDesc.HashType = txscript.SigHashDefault
default:
return nil, fmt.Errorf("unknown address type %v",
utxo.AddressType)
}
// A height hint doesn't need to be set, because we don't monitor these
// inputs for spend.
heightHint := uint32(0)
return input.NewBaseInput(
&utxo.OutPoint, witnessType, signDesc, heightHint,
), nil
}
// BudgetInputSet implements the interface `InputSet`. It takes a list of
// pending inputs which share the same deadline height and groups them into a
// set conditionally based on their economical values.
type BudgetInputSet struct {
// inputs is the set of inputs that have been added to the set after
// considering their economical contribution.
inputs []*SweeperInput
// deadlineHeight is the height which the inputs in this set must be
// confirmed by.
deadlineHeight int32
// extraBudget is a value that should be allocated to sweep the given
// set of inputs. This can be used to add extra funds to the sweep
// transaction, for example to cover fees for additional outputs of
// custom channels.
extraBudget btcutil.Amount
}
// Compile-time constraint to ensure BudgetInputSet implements InputSet.
var _ InputSet = (*BudgetInputSet)(nil)
// errEmptyInputs is returned when the input slice is empty.
var errEmptyInputs = fmt.Errorf("inputs slice is empty")
// validateInputs is used when creating new BudgetInputSet to ensure there are
// no duplicate inputs and they all share the same deadline heights, if set.
func validateInputs(inputs []SweeperInput, deadlineHeight int32) error {
// Sanity check the input slice to ensure it's non-empty.
if len(inputs) == 0 {
return errEmptyInputs
}
// inputDeadline tracks the input's deadline height. It will be updated
// if the input has a different deadline than the specified
// deadlineHeight.
inputDeadline := deadlineHeight
// dedupInputs is a set used to track unique outpoints of the inputs.
dedupInputs := fn.NewSet(
// Iterate all the inputs and map the function.
fn.Map(inputs, func(inp SweeperInput) wire.OutPoint {
// If the input has a deadline height, we'll check if
// it's the same as the specified.
inp.params.DeadlineHeight.WhenSome(func(h int32) {
// Exit early if the deadlines matched.
if h == deadlineHeight {
return
}
// Update the deadline height if it's
// different.
inputDeadline = h
})
return inp.OutPoint()
})...,
)
// Make sure the inputs share the same deadline height when there is
// one.
if inputDeadline != deadlineHeight {
return fmt.Errorf("input deadline height not matched: want "+
"%d, got %d", deadlineHeight, inputDeadline)
}
// Provide a defensive check to ensure that we don't have any duplicate
// inputs within the set.
if len(dedupInputs) != len(inputs) {
return fmt.Errorf("duplicate inputs")
}
return nil
}
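// For illustration (hypothetical values): with deadlineHeight = 800, inputs
// whose params carry fn.Some[int32](800) or no deadline at all pass
// validation, while any input carrying fn.Some[int32](805) trips the "input
// deadline height not matched" error, and two inputs sharing the same
// outpoint trip the "duplicate inputs" error.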
// NewBudgetInputSet creates a new BudgetInputSet.
func NewBudgetInputSet(inputs []SweeperInput, deadlineHeight int32,
auxSweeper fn.Option[AuxSweeper]) (*BudgetInputSet, error) {
// Validate the supplied inputs.
if err := validateInputs(inputs, deadlineHeight); err != nil {
return nil, err
}
bi := &BudgetInputSet{
deadlineHeight: deadlineHeight,
inputs: make([]*SweeperInput, 0, len(inputs)),
}
for _, input := range inputs {
bi.addInput(input)
}
log.Tracef("Created %v", bi.String())
// Attach an optional budget. This will be a no-op if the auxSweeper
// is not set.
if err := bi.attachExtraBudget(auxSweeper); err != nil {
return nil, err
}
return bi, nil
}
// attachExtraBudget attaches an extra budget to the input set, if the passed
// aux sweeper is set.
func (b *BudgetInputSet) attachExtraBudget(s fn.Option[AuxSweeper]) error {
extraBudget, err := fn.MapOptionZ(
s, func(aux AuxSweeper) fn.Result[btcutil.Amount] {
return aux.ExtraBudgetForInputs(b.Inputs())
},
).Unpack()
if err != nil {
return err
}
b.extraBudget = extraBudget
return nil
}
// String returns a human-readable description of the input set.
func (b *BudgetInputSet) String() string {
inputsDesc := ""
for _, input := range b.inputs {
inputsDesc += fmt.Sprintf("\n%v", input)
}
return fmt.Sprintf("BudgetInputSet(budget=%v, deadline=%v, "+
"inputs=[%v])", b.Budget(), b.DeadlineHeight(), inputsDesc)
}
// addInput adds an input to the input set.
func (b *BudgetInputSet) addInput(input SweeperInput) {
b.inputs = append(b.inputs, &input)
}
// addWalletInput takes a wallet UTXO and adds it as an input to be used as
// budget for the input set.
func (b *BudgetInputSet) addWalletInput(utxo *lnwallet.Utxo) error {
input, err := createWalletTxInput(utxo)
if err != nil {
return err
}
pi := SweeperInput{
Input: input,
params: Params{
DeadlineHeight: fn.Some(b.deadlineHeight),
},
}
b.addInput(pi)
log.Debugf("Added wallet input to input set: op=%v, amt=%v",
pi.OutPoint(), utxo.Value)
return nil
}
// NeedWalletInput returns true if the input set needs more wallet inputs.
//
// A set may need wallet inputs when it has a required output or its total
// value cannot cover its total budget.
func (b *BudgetInputSet) NeedWalletInput() bool {
var (
// budgetNeeded is the amount that needs to be covered from
// other inputs. We start at the value of the extra budget,
// which might be needed for custom channels that add extra
// outputs.
budgetNeeded = b.extraBudget
// budgetBorrowable is the amount that can be borrowed from
// other inputs.
budgetBorrowable btcutil.Amount
)
for _, inp := range b.inputs {
// If this input has a required output, we can assume it's a
// second-level htlc txns input. Although this input must have
// a value that can cover its budget, it cannot be used to pay
// fees. Instead, we need to borrow budget from other inputs to
// make the sweep happen. Once swept, the input value will be
// credited to the wallet.
if inp.RequiredTxOut() != nil {
budgetNeeded += inp.params.Budget
continue
}
// Get the amount left after covering the input's own budget.
// This amount can then be lent to the above input. For a wallet
// input, its `Budget` is set to zero, which means the whole
// input can be borrowed to cover the budget.
budget := inp.params.Budget
output := btcutil.Amount(inp.SignDesc().Output.Value)
budgetBorrowable += output - budget
// If the input's budget is not even covered by itself, we need
// to borrow outputs from other inputs.
if budgetBorrowable < 0 {
log.Tracef("Input %v specified a budget that exceeds "+
"its output value: %v > %v", inp, budget,
output)
}
}
log.Debugf("NeedWalletInput: budgetNeeded=%v, budgetBorrowable=%v",
budgetNeeded, budgetBorrowable)
// If we don't have enough extra budget to borrow, we need wallet
// inputs.
return budgetBorrowable < budgetNeeded
}
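// A worked example of the accounting above (hypothetical values, assuming no
// extraBudget): a set with one second-level HTLC input (required output,
// budget 1_000 sats) and one to_local input (budget 2_000 sats, output value
// 50_000 sats) yields budgetNeeded = 1_000 and budgetBorrowable =
// 50_000 - 2_000 = 48_000, so no wallet input is needed. If the to_local
// output were only 2_500 sats, budgetBorrowable would be 500 < 1_000 and
// wallet inputs would be required.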
// hasNormalInput returns a bool to indicate whether there exists an input
// that doesn't have a required TxOut. When an input has no required output,
// it's either a wallet input, or an input we want to sweep.
func (b *BudgetInputSet) hasNormalInput() bool {
for _, inp := range b.inputs {
if inp.RequiredTxOut() != nil {
continue
}
return true
}
return false
}
// AddWalletInputs adds wallet inputs to the set until the specified budget is
// met. When sweeping inputs with required outputs, although there's budget
// specified, it cannot be directly spent from these required outputs. Instead,
// we need to borrow budget from other inputs to make the sweep happen.
// There are two sources to borrow from: 1) other inputs, 2) wallet utxos. If
// we are calling this method, it means other inputs cannot cover the specified
// budget, so we need to borrow from wallet utxos.
//
// Return an error if there are not enough wallet inputs, in which case the
// budget set is reset to its initial state by removing any wallet inputs
// added.
//
// NOTE: must be called with the wallet lock held via `WithCoinSelectLock`.
func (b *BudgetInputSet) AddWalletInputs(wallet Wallet) error {
// Retrieve wallet utxos. Only consider confirmed utxos to prevent
// problems around RBF rules for unconfirmed inputs. This currently
// ignores the configured coin selection strategy.
utxos, err := wallet.ListUnspentWitnessFromDefaultAccount(
1, math.MaxInt32,
)
if err != nil {
return fmt.Errorf("list unspent witness: %w", err)
}
// Sort the UTXOs by putting smaller values at the start of the slice
// to avoid locking large UTXOs for sweeping.
//
// TODO(yy): add more choices to CoinSelectionStrategy and use the
// configured value here.
sort.Slice(utxos, func(i, j int) bool {
return utxos[i].Value < utxos[j].Value
})
// Add wallet inputs to the set until the specified budget is covered.
for _, utxo := range utxos {
err := b.addWalletInput(utxo)
if err != nil {
return err
}
// Return if we've reached the minimum output amount.
if !b.NeedWalletInput() {
return nil
}
}
// Exit if there are no inputs that can contribute to the fees.
if !b.hasNormalInput() {
return ErrNotEnoughInputs
}
// If there's at least one input that can contribute to fees, we allow
// the sweep to continue, even though the full budget can't be met.
// Maybe later more wallet inputs will become available and we can add
// them if needed.
budget := b.Budget()
total, spendable := b.inputAmts()
log.Warnf("Not enough wallet UTXOs: need budget=%v, has spendable=%v, "+
"total=%v, missing at least %v, sweeping anyway...", budget,
spendable, total, budget-spendable)
return nil
}
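// Note on the flow above: wallet UTXOs are added smallest-first purely as
// budget, with their params.Budget left at zero, so each added UTXO's full
// value becomes borrowable in NeedWalletInput. The method also deliberately
// tolerates an underfunded budget as long as one normal input exists, since
// sweeping can still proceed, just with a smaller effective budget.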
// Budget returns the total budget of the set.
//
// NOTE: part of the InputSet interface.
func (b *BudgetInputSet) Budget() btcutil.Amount {
budget := btcutil.Amount(0)
for _, input := range b.inputs {
budget += input.params.Budget
}
// We'll also tack on the extra budget which will eventually be
// accounted for by the wallet txns when we're broadcasting.
return budget + b.extraBudget
}
// DeadlineHeight returns the deadline height of the set.
//
// NOTE: part of the InputSet interface.
func (b *BudgetInputSet) DeadlineHeight() int32 {
return b.deadlineHeight
}
// Inputs returns the inputs that should be used to create a tx.
//
// NOTE: part of the InputSet interface.
func (b *BudgetInputSet) Inputs() []input.Input {
inputs := make([]input.Input, 0, len(b.inputs))
for _, inp := range b.inputs {
inputs = append(inputs, inp.Input)
}
return inputs
}
// inputAmts returns two values for the set - the total input amount, and the
// spendable amount. Only the spendable amount can be used to pay the fees.
func (b *BudgetInputSet) inputAmts() (btcutil.Amount, btcutil.Amount) {
var (
totalAmt btcutil.Amount
spendableAmt btcutil.Amount
)
for _, inp := range b.inputs {
output := btcutil.Amount(inp.SignDesc().Output.Value)
totalAmt += output
if inp.RequiredTxOut() != nil {
continue
}
spendableAmt += output
}
return totalAmt, spendableAmt
}
// StartingFeeRate returns the max starting fee rate found in the inputs.
//
// NOTE: part of the InputSet interface.
func (b *BudgetInputSet) StartingFeeRate() fn.Option[chainfee.SatPerKWeight] {
maxFeeRate := chainfee.SatPerKWeight(0)
startingFeeRate := fn.None[chainfee.SatPerKWeight]()
for _, inp := range b.inputs {
feerate := inp.params.StartingFeeRate.UnwrapOr(0)
if feerate > maxFeeRate {
maxFeeRate = feerate
startingFeeRate = fn.Some(maxFeeRate)
}
}
return startingFeeRate
}
// Immediate returns whether the inputs should be swept immediately.
//
// NOTE: part of the InputSet interface.
func (b *BudgetInputSet) Immediate() bool {
for _, inp := range b.inputs {
// As long as one of the inputs is immediate, the whole set is
// immediate.
if inp.params.Immediate {
return true
}
}
return false
}
package sweep
import (
"errors"
"fmt"
"sort"
"strings"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
var (
// DefaultMaxInputsPerTx specifies the default maximum number of inputs
// allowed in a single sweep tx. If more need to be swept, multiple txes
// are created and published.
DefaultMaxInputsPerTx = uint32(100)
// ErrLocktimeConflict is returned when inputs with different
// transaction nLockTime values are included in the same transaction.
//
// NOTE: due to the SINGLE|ANYONECANPAY sighash flag, which is used in the
// second level success/timeout txns, only the txns sharing the same
// nLockTime can exist in the same tx.
ErrLocktimeConflict = errors.New("incompatible locktime")
)
// createSweepTx builds a signed tx spending the inputs to the given outputs,
// sending any leftover change to the change script.
func createSweepTx(inputs []input.Input, outputs []*wire.TxOut,
changePkScript []byte, currentBlockHeight uint32,
feeRate, maxFeeRate chainfee.SatPerKWeight,
signer input.Signer) (*wire.MsgTx, btcutil.Amount, error) {
inputs, estimator, err := getWeightEstimate(
inputs, outputs, feeRate, maxFeeRate, [][]byte{changePkScript},
)
if err != nil {
return nil, 0, err
}
txFee := estimator.feeWithParent()
var (
// Create the sweep transaction that we will be building. We
// use version 2 as it is required for CSV.
sweepTx = wire.NewMsgTx(2)
// Track whether any of the inputs require a certain locktime.
locktime = int32(-1)
// We keep track of total input amount, and required output
// amount to use for calculating the change amount below.
totalInput btcutil.Amount
requiredOutput btcutil.Amount
// We'll add the inputs as we go so we know the final ordering
// of inputs to sign.
idxs []input.Input
)
// We start by adding all inputs that commit to an output. We do this
// since the input and output index must stay the same for the
// signatures to be valid.
for _, o := range inputs {
if o.RequiredTxOut() == nil {
continue
}
idxs = append(idxs, o)
sweepTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: o.OutPoint(),
Sequence: o.BlocksToMaturity(),
})
sweepTx.AddTxOut(o.RequiredTxOut())
if lt, ok := o.RequiredLockTime(); ok {
// If another input commits to a different locktime,
// they cannot be combined in the same transaction.
if locktime != -1 && locktime != int32(lt) {
return nil, 0, ErrLocktimeConflict
}
locktime = int32(lt)
}
totalInput += btcutil.Amount(o.SignDesc().Output.Value)
requiredOutput += btcutil.Amount(o.RequiredTxOut().Value)
}
// Sum up the value contained in the remaining inputs, and add them to
// the sweep transaction.
for _, o := range inputs {
if o.RequiredTxOut() != nil {
continue
}
idxs = append(idxs, o)
sweepTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: o.OutPoint(),
Sequence: o.BlocksToMaturity(),
})
if lt, ok := o.RequiredLockTime(); ok {
if locktime != -1 && locktime != int32(lt) {
return nil, 0, ErrLocktimeConflict
}
locktime = int32(lt)
}
totalInput += btcutil.Amount(o.SignDesc().Output.Value)
}
// Add the outputs given, if any.
for _, o := range outputs {
sweepTx.AddTxOut(o)
requiredOutput += btcutil.Amount(o.Value)
}
if requiredOutput+txFee > totalInput {
return nil, 0, fmt.Errorf("insufficient input to create sweep "+
"tx: input_sum=%v, output_sum=%v", totalInput,
requiredOutput+txFee)
}
// The value remaining after the required output and fees goes to
// change. Note that this fee is what we would have to pay in case the
// sweep tx has a change output.
changeAmt := totalInput - requiredOutput - txFee
// We'll calculate the dust limit for the given changePkScript since it
// is variable.
changeLimit := lnwallet.DustLimitForSize(len(changePkScript))
// The txn will sweep the amount after fees to the pkscript generated
// above.
if changeAmt >= changeLimit {
sweepTx.AddTxOut(&wire.TxOut{
PkScript: changePkScript,
Value: int64(changeAmt),
})
} else {
log.Infof("Change amt %v below dustlimit %v, not adding "+
"change output", changeAmt, changeLimit)
// The dust amount is added to the fee as the miner will
// collect it.
txFee += changeAmt
}
// We'll default to using the current block height as locktime, if none
// of the inputs commits to a different locktime.
sweepTx.LockTime = currentBlockHeight
if locktime != -1 {
sweepTx.LockTime = uint32(locktime)
}
// Before signing the transaction, check to ensure that it meets some
// basic validity requirements.
//
// TODO(conner): add more control to sanity checks, allowing us to
// delay spending "problem" outputs, e.g. possibly batching with other
// classes if fees are too low.
btx := btcutil.NewTx(sweepTx)
if err := blockchain.CheckTransactionSanity(btx); err != nil {
return nil, 0, err
}
prevInputFetcher, err := input.MultiPrevOutFetcher(inputs)
if err != nil {
return nil, 0, fmt.Errorf("error creating prev input fetcher "+
"for hash cache: %v", err)
}
hashCache := txscript.NewTxSigHashes(sweepTx, prevInputFetcher)
// With all the inputs in place, use each output's unique input script
// function to generate the final witness required for spending.
addInputScript := func(idx int, tso input.Input) error {
inputScript, err := tso.CraftInputScript(
signer, sweepTx, hashCache, prevInputFetcher, idx,
)
if err != nil {
return err
}
sweepTx.TxIn[idx].Witness = inputScript.Witness
if len(inputScript.SigScript) != 0 {
sweepTx.TxIn[idx].SignatureScript =
inputScript.SigScript
}
return nil
}
for idx, inp := range idxs {
if err := addInputScript(idx, inp); err != nil {
return nil, 0, err
}
}
log.Debugf("Creating sweep transaction %v for %v inputs (%s) "+
"using %v, tx_weight=%v, tx_fee=%v, parents_count=%v, "+
"parents_fee=%v, parents_weight=%v, current_height=%v",
sweepTx.TxHash(), len(inputs),
inputTypeSummary(inputs), feeRate,
estimator.weight(), txFee,
len(estimator.parents), estimator.parentsFee,
estimator.parentsWeight, currentBlockHeight)
return sweepTx, txFee, nil
}
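// To illustrate the change computation above with hypothetical numbers: with
// totalInput = 100_000 sats, requiredOutput = 60_000 sats and txFee = 2_000
// sats, changeAmt = 38_000 sats, which is added as a change output since it
// exceeds the dust limit. Had changeAmt fallen below the dust limit, it would
// instead have been folded into txFee for the miner to collect.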
// getWeightEstimate returns a weight estimate for the given inputs.
// Additionally, it returns counts for the number of csv and cltv inputs.
func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut,
feeRate, maxFeeRate chainfee.SatPerKWeight,
outputPkScripts [][]byte) ([]input.Input, *weightEstimator, error) {
// We initialize a weight estimator so we can accurately assess the
// amount of fees we need to pay for this sweep transaction.
//
// TODO(roasbeef): can be more intelligent about buffering outputs to
// be more efficient on-chain.
weightEstimate := newWeightEstimator(feeRate, maxFeeRate)
// Our sweep transaction will always pay to the given set of outputs.
for _, o := range outputs {
weightEstimate.addOutput(o)
}
// If there is any leftover change after paying to the given outputs
// and required outputs, it will go to a single segwit p2wkh or p2tr
// address. This will be our change address, so ensure it contributes
// to our weight estimate. Note that if we have other outputs, we might
// end up creating a sweep tx without a change output. It is okay to
// add the change output to the weight estimate regardless, since the
// estimated fee will just be subtracted from this already dust output,
// and trimmed.
for _, outputPkScript := range outputPkScripts {
switch {
case txscript.IsPayToTaproot(outputPkScript):
weightEstimate.addP2TROutput()
case txscript.IsPayToWitnessScriptHash(outputPkScript):
weightEstimate.addP2WSHOutput()
case txscript.IsPayToWitnessPubKeyHash(outputPkScript):
weightEstimate.addP2WKHOutput()
case txscript.IsPayToPubKeyHash(outputPkScript):
weightEstimate.estimator.AddP2PKHOutput()
case txscript.IsPayToScriptHash(outputPkScript):
weightEstimate.estimator.AddP2SHOutput()
default:
// Unknown script type.
return nil, nil, fmt.Errorf("unknown script "+
"type: %x", outputPkScript)
}
}
// For each output, use its witness type to determine the estimate
// weight of its witness, and add it to the proper set of spendable
// outputs.
var sweepInputs []input.Input
for i := range inputs {
inp := inputs[i]
err := weightEstimate.add(inp)
if err != nil {
// TODO(yy): check if this is even possible? If so, we
// should return the error here instead of filtering!
log.Errorf("Failed to get weight estimate for "+
"input=%v, witnessType=%v: %v ", inp.OutPoint(),
inp.WitnessType(), err)
// Skip inputs for which no weight estimate can be
// given.
continue
}
// If this input comes with a committed output, add that as
// well.
if inp.RequiredTxOut() != nil {
weightEstimate.addOutput(inp.RequiredTxOut())
}
sweepInputs = append(sweepInputs, inp)
}
return sweepInputs, weightEstimate, nil
}
// inputTypeSummary returns a string containing a human-readable summary of
// the witness types of a list of inputs.
func inputTypeSummary(inputs []input.Input) string {
// Sort inputs by witness type.
sortedInputs := make([]input.Input, len(inputs))
copy(sortedInputs, inputs)
sort.Slice(sortedInputs, func(i, j int) bool {
return sortedInputs[i].WitnessType().String() <
sortedInputs[j].WitnessType().String()
})
var parts []string
for _, i := range sortedInputs {
part := fmt.Sprintf("%v (%v)", i.OutPoint(), i.WitnessType())
parts = append(parts, part)
}
return strings.Join(parts, "\n")
}
package sweep
import (
"errors"
"fmt"
"maps"
"math"
"slices"
"time"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwallet/chanfunding"
)
var (
// ErrNoFeePreference is returned when we attempt to satisfy a sweep
// request from a client who did not specify a fee preference.
ErrNoFeePreference = errors.New("no fee preference specified")
// ErrFeePreferenceConflict is returned when both a fee rate and a conf
// target are set for a fee preference.
ErrFeePreferenceConflict = errors.New("fee preference conflict")
// ErrUnknownUTXO is returned when creating a sweeping tx using a UTXO
// that's unknown to the wallet.
ErrUnknownUTXO = errors.New("unknown utxo")
)
// FeePreference defines an interface that allows the caller to specify how the
// fee rate should be handled. Depending on the implementation, the fee rate
// can either be specified directly, or via a conf target which relies on the
// chain backend (e.g. `bitcoind`) to give a fee estimation, or a customized
// fee function which handles fee calculation based on the specified
// urgency (deadline).
type FeePreference interface {
// String returns a human-readable string of the fee preference.
String() string
// Estimate takes a fee estimator and a max allowed fee rate and
// returns a fee rate for the given fee preference. It ensures that the
// fee rate respects the bounds of the relay fee and the specified max
// fee rates.
Estimate(chainfee.Estimator,
chainfee.SatPerKWeight) (chainfee.SatPerKWeight, error)
}
// FeeEstimateInfo allows callers to express their time value for inclusion of
// a transaction into a block via either a confirmation target, or a fee rate.
type FeeEstimateInfo struct {
// ConfTarget, if non-zero, signals a fee preference expressed in the
// number of desired blocks between first broadcast, and confirmation.
ConfTarget uint32
// FeeRate, if non-zero, signals a fee preference expressed as a fee
// rate in sat/kw for a particular transaction.
FeeRate chainfee.SatPerKWeight
}
// Compile-time constraint to ensure FeeEstimateInfo implements FeePreference.
var _ FeePreference = (*FeeEstimateInfo)(nil)
// String returns a human-readable string of the fee preference.
func (f FeeEstimateInfo) String() string {
if f.ConfTarget != 0 {
return fmt.Sprintf("%v blocks", f.ConfTarget)
}
return f.FeeRate.String()
}
// Estimate returns a fee rate for the given fee preference. It ensures that
// the fee rate respects the bounds of the relay fee and the max fee rates, if
// specified.
func (f FeeEstimateInfo) Estimate(estimator chainfee.Estimator,
maxFeeRate chainfee.SatPerKWeight) (chainfee.SatPerKWeight, error) {
var (
feeRate chainfee.SatPerKWeight
err error
)
switch {
// Ensure a type of fee preference is specified to prevent using a
// default below.
case f.FeeRate == 0 && f.ConfTarget == 0:
return 0, ErrNoFeePreference
// If both values are set, then we'll return an error as we require a
// strict directive.
case f.FeeRate != 0 && f.ConfTarget != 0:
return 0, ErrFeePreferenceConflict
// If the target number of confirmations is set, then we'll use that to
// consult our fee estimator for an adequate fee.
case f.ConfTarget != 0:
feeRate, err = estimator.EstimateFeePerKW(f.ConfTarget)
if err != nil {
return 0, fmt.Errorf("unable to query fee "+
"estimator: %w", err)
}
// If a manual fee rate is set, then we'll use that directly. It is
// already expressed in sat/kw, which is what we use internally.
case f.FeeRate != 0:
feeRate = f.FeeRate
// Because the user can specify 1 sat/vByte on the RPC
// interface, which corresponds to 250 sat/kw, we need to bump
// that to the minimum "safe" fee rate which is 253 sat/kw.
if feeRate == chainfee.AbsoluteFeePerKwFloor {
log.Infof("Manual fee rate input of %d sat/kw is "+
"too low, using %d sat/kw instead", feeRate,
chainfee.FeePerKwFloor)
feeRate = chainfee.FeePerKwFloor
}
}
// Get the relay fee as the min fee rate.
minFeeRate := estimator.RelayFeePerKW()
// If that bumped fee rate of at least 253 sat/kw is still lower than
// the relay fee rate, we return an error to let the user know. Note
// that "Relay fee rate" may mean slightly different things depending
// on the backend. For bitcoind, it is effectively max(relay fee, min
// mempool fee).
if feeRate < minFeeRate {
return 0, fmt.Errorf("%w: got %v, minimum is %v",
ErrFeePreferenceTooLow, feeRate, minFeeRate)
}
// If a maxFeeRate is specified and the estimated fee rate is above the
// maximum allowed fee rate, default to the max fee rate.
if maxFeeRate != 0 && feeRate > maxFeeRate {
log.Warnf("Estimated fee rate %v exceeds max allowed fee "+
"rate %v, using max fee rate instead", feeRate,
maxFeeRate)
return maxFeeRate, nil
}
return feeRate, nil
}
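// A minimal usage sketch (the estimator and the max fee rate are
// hypothetical):
//
//	info := FeeEstimateInfo{ConfTarget: 6}
//	feeRate, err := info.Estimate(estimator, chainfee.SatPerKWeight(50_000))
//
// Note the floor bump above: a caller asking for 1 sat/vbyte maps to 250
// sat/kw (chainfee.AbsoluteFeePerKwFloor), which is raised to the 253 sat/kw
// "safe" floor (chainfee.FeePerKwFloor) before the relay fee check.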
// UtxoSource is an interface that allows a caller to access a source of UTXOs
// to use when crafting sweep transactions.
type UtxoSource interface {
// ListUnspentWitnessFromDefaultAccount returns all UTXOs from the
// default wallet account that have between minConfs and maxConfs
// number of confirmations.
ListUnspentWitnessFromDefaultAccount(minConfs, maxConfs int32) (
[]*lnwallet.Utxo, error)
}
// CoinSelectionLocker is an interface that allows the caller to perform an
// operation, which is synchronized with all coin selection attempts. This can
// be used when an operation requires that all coin selection operations cease
// forward progress. Think of this as an exclusive lock on coin selection
// operations.
type CoinSelectionLocker interface {
// WithCoinSelectLock will execute the passed function closure in a
// synchronized manner preventing any coin selection operations from
// proceeding while the closure is executing. This can be seen as the
// ability to execute a function closure under an exclusive coin
// selection lock.
WithCoinSelectLock(func() error) error
}
// OutputLeaser allows a caller to lease/release an output. When leased, the
// outputs shouldn't be used for any sort of channel funding or coin selection.
// Leased outputs are expected to be persisted between restarts.
type OutputLeaser interface {
// LeaseOutput leases a target output, rendering it unusable for coin
// selection.
LeaseOutput(i wtxmgr.LockID, o wire.OutPoint, d time.Duration) (
time.Time, error)
// ReleaseOutput releases a target output, allowing it to be used for
// coin selection once again.
ReleaseOutput(i wtxmgr.LockID, o wire.OutPoint) error
}
// WalletSweepPackage is a package that gives the caller the ability to sweep
// relevant funds from a wallet in a single transaction. We also package a
// function closure that allows one to abort the operation.
type WalletSweepPackage struct {
// SweepTx is a fully signed and valid transaction that, once
// broadcast, will sweep ALL relevant confirmed coins in the wallet
// with a single transaction.
SweepTx *wire.MsgTx
// CancelSweepAttempt allows the caller to cancel the sweep attempt.
//
// NOTE: If the sweeping transaction isn't or cannot be broadcast, then
// this closure MUST be called, otherwise all selected utxos will
// remain locked and unusable.
CancelSweepAttempt func()
}
// DeliveryAddr is a pair of (address, amount) used to craft a transaction
// paying to more than one specified address.
type DeliveryAddr struct {
// Addr is the address to pay to.
Addr btcutil.Address
// Amt is the amount to pay to the given address.
Amt btcutil.Amount
}
// CraftSweepAllTx attempts to craft a WalletSweepPackage which will allow the
// caller to sweep ALL funds in ALL or SELECT outputs within the wallet to a
// list of outputs. Any leftover amount after these outputs and the transaction
// fee is sent to a single output, as specified by the change address. The sweep
// transaction will be crafted with the target fee rate, and will use the
// utxoSource and outputLeaser as sources for wallet funds.
func CraftSweepAllTx(feeRate, maxFeeRate chainfee.SatPerKWeight,
blockHeight uint32, deliveryAddrs []DeliveryAddr,
changeAddr btcutil.Address, coinSelectLocker CoinSelectionLocker,
utxoSource UtxoSource, outputLeaser OutputLeaser,
signer input.Signer, minConfs int32,
selectUtxos fn.Set[wire.OutPoint]) (*WalletSweepPackage, error) {
// TODO(roasbeef): turn off ATPL as well when available?
var outputsForSweep []*lnwallet.Utxo
// We'll make a function closure up front that allows us to unlock all
// selected outputs to ensure that they become available again in the
// case of an error after the outputs have been locked, but before we
// can actually craft a sweeping transaction.
unlockOutputs := func() {
for _, utxo := range outputsForSweep {
// Log the error but continue since we're already
// handling an error.
err := outputLeaser.ReleaseOutput(
chanfunding.LndInternalLockID, utxo.OutPoint,
)
if err != nil {
log.Warnf("Failed to release UTXO %s (%v))",
utxo.OutPoint, err)
}
}
}
// Next, we'll use the coinSelectLocker to ensure that no coin
// selection takes place while we fetch and lock outputs in the
// wallet. Otherwise, it may be possible for a new funding flow to lock
// an output while we fetch the set of unspent witnesses.
err := coinSelectLocker.WithCoinSelectLock(func() error {
log.Trace("[WithCoinSelectLock] entered the lock")
// Now that we can be sure that no other coin selection
// operations are going on, we can grab a clean snapshot of the
// current UTXO state of the wallet.
utxos, err := utxoSource.ListUnspentWitnessFromDefaultAccount(
minConfs, math.MaxInt32,
)
if err != nil {
return err
}
log.Trace("[WithCoinSelectLock] finished fetching UTXOs")
// Use select utxos, if provided.
if len(selectUtxos) > 0 {
utxos, err = fetchUtxosFromOutpoints(
utxos, selectUtxos.ToSlice(),
)
if err != nil {
return err
}
}
// We'll now lock each UTXO to ensure that other callers don't
// attempt to use these UTXOs in transactions while we're
// crafting our sweep all transaction.
for _, utxo := range utxos {
log.Tracef("[WithCoinSelectLock] leasing utxo: %v",
utxo.OutPoint)
_, err = outputLeaser.LeaseOutput(
chanfunding.LndInternalLockID, utxo.OutPoint,
chanfunding.DefaultLockDuration,
)
if err != nil {
return err
}
}
log.Trace("[WithCoinSelectLock] exited the lock")
outputsForSweep = append(outputsForSweep, utxos...)
return nil
})
if err != nil {
// If we failed at all, we'll unlock any outputs selected just
// in case we had any lingering outputs.
unlockOutputs()
return nil, fmt.Errorf("unable to fetch+lock wallet utxos: %w",
err)
}
// Now that we've locked all the potential outputs to sweep, we'll
// assemble an input for each of them, so we can hand it off to the
// sweeper to generate and sign a transaction for us.
var inputsToSweep []input.Input
for _, output := range outputsForSweep {
// As we'll be signing for outputs under control of the wallet,
// we only need to populate the output value and output script.
// The rest of the items will be populated internally within
// the sweeper via the witness generation function.
signDesc := &input.SignDescriptor{
Output: &wire.TxOut{
PkScript: output.PkScript,
Value: int64(output.Value),
},
HashType: txscript.SigHashAll,
}
pkScript := output.PkScript
// Based on the output type, we'll map it to the proper witness
// type so we can generate the set of input scripts needed to
// sweep the output.
var witnessType input.WitnessType
switch output.AddressType {
// If this is a p2wkh output, then we'll assume it's a witness
// key hash witness type.
case lnwallet.WitnessPubKey:
witnessType = input.WitnessKeyHash
// If this is a p2sh output, then since it's under control of the
// wallet, we'll assume it's a nested p2sh output.
case lnwallet.NestedWitnessPubKey:
witnessType = input.NestedWitnessKeyHash
case lnwallet.TaprootPubkey:
witnessType = input.TaprootPubKeySpend
signDesc.HashType = txscript.SigHashDefault
// All other output types we count as unknown and will fail to
// sweep.
default:
unlockOutputs()
return nil, fmt.Errorf("unable to sweep coins, "+
"unknown script: %x", pkScript[:])
}
// Now that we've constructed the items required, we'll make an
// input which can be passed to the sweeper for ultimate
// sweeping.
input := input.MakeBaseInput(
&output.OutPoint, witnessType, signDesc, 0, nil,
)
inputsToSweep = append(inputsToSweep, &input)
}
// Create a list of TxOuts from the given delivery addresses.
var txOuts []*wire.TxOut
for _, d := range deliveryAddrs {
pkScript, err := txscript.PayToAddrScript(d.Addr)
if err != nil {
unlockOutputs()
return nil, err
}
txOuts = append(txOuts, &wire.TxOut{
PkScript: pkScript,
Value: int64(d.Amt),
})
}
// Next, we'll convert the change addr to a pkScript that we can use
// to create the sweep transaction.
changePkScript, err := txscript.PayToAddrScript(changeAddr)
if err != nil {
unlockOutputs()
return nil, err
}
// Finally, we'll ask the sweeper to craft a sweep transaction which
// respects our fee preference and targets all the UTXOs of the wallet.
sweepTx, _, err := createSweepTx(
inputsToSweep, txOuts, changePkScript, blockHeight,
feeRate, maxFeeRate, signer,
)
if err != nil {
unlockOutputs()
return nil, err
}
return &WalletSweepPackage{
SweepTx: sweepTx,
CancelSweepAttempt: unlockOutputs,
}, nil
}
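// A rough calling sketch (all collaborators are assumed to be constructed
// elsewhere):
//
//	pkg, err := CraftSweepAllTx(
//		feeRate, maxFeeRate, height, nil, changeAddr, locker,
//		utxoSource, leaser, signer, 1, fn.NewSet[wire.OutPoint](),
//	)
//
// On error, nothing is left leased. On success, if pkg.SweepTx is never
// broadcast, pkg.CancelSweepAttempt() MUST be called to release the leased
// UTXOs.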
// fetchUtxosFromOutpoints returns UTXOs for given outpoints. Errors if any
// outpoint is not in the passed slice of utxos.
func fetchUtxosFromOutpoints(utxos []*lnwallet.Utxo,
outpoints []wire.OutPoint) ([]*lnwallet.Utxo, error) {
lookup := fn.SliceToMap(utxos, func(utxo *lnwallet.Utxo) wire.OutPoint {
return utxo.OutPoint
}, func(utxo *lnwallet.Utxo) *lnwallet.Utxo {
return utxo
})
subMap, err := fn.NewSubMap(lookup, outpoints)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrUnknownUTXO, err.Error())
}
fetchedUtxos := slices.Collect(maps.Values(subMap))
return fetchedUtxos, nil
}
package sweep
import (
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
// weightEstimator wraps a standard weight estimator instance and adds to that
// support for child-pays-for-parent.
type weightEstimator struct {
estimator input.TxWeightEstimator
feeRate chainfee.SatPerKWeight
parents map[chainhash.Hash]struct{}
parentsFee btcutil.Amount
parentsWeight lntypes.WeightUnit
// maxFeeRate is the max allowed fee rate configured by the user.
maxFeeRate chainfee.SatPerKWeight
}
// newWeightEstimator instantiates a new sweeper weight estimator.
func newWeightEstimator(
feeRate, maxFeeRate chainfee.SatPerKWeight) *weightEstimator {
return &weightEstimator{
feeRate: feeRate,
maxFeeRate: maxFeeRate,
parents: make(map[chainhash.Hash]struct{}),
}
}
// add adds the weight of the given input to the weight estimate.
func (w *weightEstimator) add(inp input.Input) error {
// If there is a parent tx, add the parent's fee and weight.
w.tryAddParent(inp)
wt := inp.WitnessType()
return wt.AddWeightEstimation(&w.estimator)
}
// tryAddParent examines the input and updates parent tx totals if required for
// cpfp.
func (w *weightEstimator) tryAddParent(inp input.Input) {
// Get unconfirmed parent info from the input.
unconfParent := inp.UnconfParent()
// If there is no parent, there is nothing to add.
if unconfParent == nil {
return
}
// If we've already accounted for the parent tx, don't do it
// again. This can happen when two outputs of the parent tx are
// included in the same sweep tx.
parentHash := inp.OutPoint().Hash
if _, ok := w.parents[parentHash]; ok {
return
}
// Calculate parent fee rate.
parentFeeRate := chainfee.SatPerKWeight(unconfParent.Fee) * 1000 /
chainfee.SatPerKWeight(unconfParent.Weight)
// Ignore parents that already pay at least the fee rate of this
// transaction, as no child-pays-for-parent bump is needed for them.
if parentFeeRate >= w.feeRate {
return
}
// Include parent.
w.parents[parentHash] = struct{}{}
w.parentsFee += unconfParent.Fee
w.parentsWeight += unconfParent.Weight
}
// addP2WKHOutput updates the weight estimate to account for an additional
// native P2WKH output.
func (w *weightEstimator) addP2WKHOutput() {
w.estimator.AddP2WKHOutput()
}
// addP2TROutput updates the weight estimate to account for an additional native
// SegWit v1 P2TR output.
func (w *weightEstimator) addP2TROutput() {
w.estimator.AddP2TROutput()
}
// addP2WSHOutput updates the weight estimate to account for an additional
// segwit v0 P2WSH output.
func (w *weightEstimator) addP2WSHOutput() {
w.estimator.AddP2WSHOutput()
}
// addOutput updates the weight estimate to account for the known
// output given.
func (w *weightEstimator) addOutput(txOut *wire.TxOut) {
w.estimator.AddTxOutput(txOut)
}
// weight gets the estimated weight of the transaction.
func (w *weightEstimator) weight() lntypes.WeightUnit {
return w.estimator.Weight()
}
// fee returns the tx fee to use for the aggregated inputs and outputs, which
// is different from feeWithParent as it doesn't take into account unconfirmed
// parent transactions.
func (w *weightEstimator) fee() btcutil.Amount {
// Calculate the weight of the transaction.
weight := w.estimator.Weight()
// Calculate the fee.
fee := w.feeRate.FeeForWeight(weight)
return fee
}
// feeWithParent returns the tx fee to use for the aggregated inputs and
// outputs, taking into account unconfirmed parent transactions (cpfp).
func (w *weightEstimator) feeWithParent() btcutil.Amount {
// Calculate fee and weight for just this tx.
childWeight := w.estimator.Weight()
// Add combined weight of unconfirmed parent txes.
totalWeight := childWeight + w.parentsWeight
// Subtract fee already paid by parents.
fee := w.feeRate.FeeForWeight(totalWeight) - w.parentsFee
// Clamp the fee to what would be required if no parent txes were paid
// for. This is to make sure no rounding errors can get us into trouble.
childFee := w.feeRate.FeeForWeight(childWeight)
if childFee > fee {
fee = childFee
}
// Exit early if maxFeeRate is not set.
if w.maxFeeRate == 0 {
return fee
}
// Clamp the fee to the max fee rate.
maxFee := w.maxFeeRate.FeeForWeight(childWeight)
if fee > maxFee {
// Calculate the effective fee rate for logging.
childFeeRate := chainfee.SatPerKWeight(
fee * 1000 / btcutil.Amount(childWeight),
)
log.Warnf("Child fee rate %v exceeds max allowed fee rate %v, "+
"returning fee %v instead of %v", childFeeRate,
w.maxFeeRate, maxFee, fee)
fee = maxFee
}
return fee
}
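// Worked example for feeWithParent (hypothetical numbers): with a fee rate of
// 10_000 sat/kw, childWeight = 600 WU, parentsWeight = 400 WU and parentsFee
// = 1_000 sats, the combined fee is 10_000 * 1_000 / 1_000 = 10_000 sats,
// minus the 1_000 sats already paid by the parent, giving 9_000 sats. That
// exceeds the child-only fee of 6_000 sats, so 9_000 sats is used (before any
// maxFeeRate clamping).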
package zpay32
import (
"fmt"
"strconv"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// toMSat is a map from a unit to a function that converts an amount
// of that unit to millisatoshis.
toMSat = map[byte]func(uint64) (lnwire.MilliSatoshi, error){
'm': mBtcToMSat,
'u': uBtcToMSat,
'n': nBtcToMSat,
'p': pBtcToMSat,
}
// fromMSat is a map from a unit to a function that converts an amount
// in millisatoshis to an amount of that unit.
fromMSat = map[byte]func(lnwire.MilliSatoshi) (uint64, error){
'm': mSatToMBtc,
'u': mSatToUBtc,
'n': mSatToNBtc,
'p': mSatToPBtc,
}
)
// mBtcToMSat converts the given amount in milliBTC to millisatoshis.
func mBtcToMSat(m uint64) (lnwire.MilliSatoshi, error) {
return lnwire.MilliSatoshi(m) * 100000000, nil
}
// uBtcToMSat converts the given amount in microBTC to millisatoshis.
func uBtcToMSat(u uint64) (lnwire.MilliSatoshi, error) {
return lnwire.MilliSatoshi(u * 100000), nil
}
// nBtcToMSat converts the given amount in nanoBTC to millisatoshis.
func nBtcToMSat(n uint64) (lnwire.MilliSatoshi, error) {
return lnwire.MilliSatoshi(n * 100), nil
}
// pBtcToMSat converts the given amount in picoBTC to millisatoshis.
func pBtcToMSat(p uint64) (lnwire.MilliSatoshi, error) {
if p < 10 {
return 0, fmt.Errorf("minimum amount is 10p")
}
if p%10 != 0 {
return 0, fmt.Errorf("amount %d pBTC not expressible in msat",
p)
}
return lnwire.MilliSatoshi(p / 10), nil
}
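// The checks above follow from the unit sizes: 1 BTC = 10^11 msat while
// 1 BTC = 10^12 pBTC, so 1 pBTC = 0.1 msat. Only multiples of 10 pBTC land on
// a whole number of millisatoshis, and 10 pBTC (= 1 msat) is the smallest
// expressible amount.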
// mSatToMBtc converts the given amount in millisatoshis to milliBTC.
func mSatToMBtc(msat lnwire.MilliSatoshi) (uint64, error) {
if msat%100000000 != 0 {
return 0, fmt.Errorf("%d msat not expressible "+
"in mBTC", msat)
}
return uint64(msat / 100000000), nil
}
// mSatToUBtc converts the given amount in millisatoshis to microBTC.
func mSatToUBtc(msat lnwire.MilliSatoshi) (uint64, error) {
if msat%100000 != 0 {
return 0, fmt.Errorf("%d msat not expressible "+
"in uBTC", msat)
}
return uint64(msat / 100000), nil
}
// mSatToNBtc converts the given amount in millisatoshis to nanoBTC.
func mSatToNBtc(msat lnwire.MilliSatoshi) (uint64, error) {
if msat%100 != 0 {
return 0, fmt.Errorf("%d msat not expressible in nBTC", msat)
}
return uint64(msat / 100), nil
}
// mSatToPBtc converts the given amount in millisatoshis to picoBTC.
func mSatToPBtc(msat lnwire.MilliSatoshi) (uint64, error) {
return uint64(msat * 10), nil
}
// decodeAmount returns the amount encoded by the provided string in
// millisatoshi.
func decodeAmount(amount string) (lnwire.MilliSatoshi, error) {
if len(amount) < 1 {
return 0, fmt.Errorf("amount must be non-empty")
}
// If last character is a digit, then the amount can just be
// interpreted as BTC.
char := amount[len(amount)-1]
if char >= '0' && char <= '9' {
btc, err := strconv.ParseUint(amount, 10, 64)
if err != nil {
return 0, err
}
return lnwire.MilliSatoshi(btc) * mSatPerBtc, nil
}
// If not a digit, it must be part of the known units.
conv, ok := toMSat[char]
if !ok {
return 0, fmt.Errorf("unknown multiplier %c", char)
}
// Known unit.
num := amount[:len(amount)-1]
if len(num) < 1 {
return 0, fmt.Errorf("number must be non-empty")
}
am, err := strconv.ParseUint(num, 10, 64)
if err != nil {
return 0, err
}
return conv(am)
}
// encodeAmount encodes the provided millisatoshi amount using as few characters
// as possible.
func encodeAmount(msat lnwire.MilliSatoshi) (string, error) {
// If possible to express in BTC, that will always be the shortest
// representation.
if msat%mSatPerBtc == 0 {
return strconv.FormatInt(int64(msat/mSatPerBtc), 10), nil
}
// Should always be expressible in pico BTC.
pico, err := fromMSat['p'](msat)
if err != nil {
return "", fmt.Errorf("unable to express %d msat as pBTC: %w",
msat, err)
}
shortened := strconv.FormatUint(pico, 10) + "p"
for unit, conv := range fromMSat {
am, err := conv(msat)
if err != nil {
// Not expressible using this unit.
continue
}
// Save the shortest found representation.
str := strconv.FormatUint(am, 10) + string(unit)
if len(str) < len(shortened) {
shortened = str
}
}
return shortened, nil
}
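// A round-trip sketch of the two functions above: "2500u" decodes to
// 2500 * 100_000 = 250_000_000 msat, and encoding 250_000_000 msat yields the
// shortest candidate among "2500u", "2500000n" and "2500000000p", namely
// "2500u". Amounts expressible in whole BTC skip the unit suffix entirely.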
package zpay32
import (
"fmt"
"strings"
)
const charset = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
var gen = []int{0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3}
// decodeBech32 decodes a bech32 encoded string, returning the human-readable
// part and the data part excluding the checksum.
//
// NOTE: This method is a slight modification of the bech32.Decode method
// found in btcutil, allowing strings to be longer than 90 characters.
//
// NOTE: the data will be base32 encoded, that is, each element of the
// returned byte array encodes 5 bits of data. Use the ConvertBits method to
// convert this to an 8-bit representation.
func decodeBech32(bech string) (string, []byte, error) {
// The maximum allowed length for a bech32 string is 90. It must also
// be at least 8 characters, since it needs a non-empty HRP, a
// separator, and a 6 character checksum.
// NB: The 90 character check specified in BIP173 is skipped here, to
// allow strings longer than 90 characters.
if len(bech) < 8 {
return "", nil, fmt.Errorf("invalid bech32 string length %d",
len(bech))
}
// Only ASCII characters between 33 and 126 are allowed.
for i := 0; i < len(bech); i++ {
if bech[i] < 33 || bech[i] > 126 {
return "", nil, fmt.Errorf("invalid character in "+
"string: '%c'", bech[i])
}
}
// The characters must be either all lowercase or all uppercase.
lower := strings.ToLower(bech)
upper := strings.ToUpper(bech)
if bech != lower && bech != upper {
return "", nil, fmt.Errorf("string not all lowercase or all " +
"uppercase")
}
// We'll work with the lowercase string from now on.
bech = lower
// The string is invalid if the last '1' is non-existent, if it is the
// first character of the string (no human-readable part), or if it is
// one of the last 6 characters of the string (since the checksum
// cannot contain '1').
one := strings.LastIndexByte(bech, '1')
if one < 1 || one+7 > len(bech) {
return "", nil, fmt.Errorf("invalid index of 1")
}
// The human-readable part is everything before the last '1'.
hrp := bech[:one]
data := bech[one+1:]
// Each character corresponds to the byte with value of the index in
// 'charset'.
decoded, err := toBytes(data)
if err != nil {
return "", nil, fmt.Errorf("failed converting data to bytes: "+
"%v", err)
}
if !bech32VerifyChecksum(hrp, decoded) {
moreInfo := ""
checksum := bech[len(bech)-6:]
expected, err := toChars(bech32Checksum(hrp,
decoded[:len(decoded)-6]))
if err == nil {
moreInfo = fmt.Sprintf("Expected %v, got %v.",
expected, checksum)
}
return "", nil, fmt.Errorf("checksum failed. " + moreInfo)
}
// We exclude the last 6 bytes, which is the checksum.
return hrp, decoded[:len(decoded)-6], nil
}
// toBytes converts each character in the string 'chars' to the value of the
// index of the corresponding character in 'charset'.
func toBytes(chars string) ([]byte, error) {
decoded := make([]byte, 0, len(chars))
for i := 0; i < len(chars); i++ {
index := strings.IndexByte(charset, chars[i])
if index < 0 {
return nil, fmt.Errorf("invalid character not part of "+
"charset: %v", chars[i])
}
decoded = append(decoded, byte(index))
}
return decoded, nil
}
// toChars converts the byte slice 'data' to a string where each byte in 'data'
// encodes the index of a character in 'charset'.
func toChars(data []byte) (string, error) {
result := make([]byte, 0, len(data))
for _, b := range data {
if int(b) >= len(charset) {
return "", fmt.Errorf("invalid data byte: %v", b)
}
result = append(result, charset[b])
}
return string(result), nil
}
// For more details on the checksum calculation, please refer to BIP 173.
func bech32Checksum(hrp string, data []byte) []byte {
// Convert the bytes to list of integers, as this is needed for the
// checksum calculation.
integers := make([]int, len(data))
for i, b := range data {
integers[i] = int(b)
}
values := append(bech32HrpExpand(hrp), integers...)
values = append(values, []int{0, 0, 0, 0, 0, 0}...)
polymod := bech32Polymod(values) ^ 1
var res []byte
for i := 0; i < 6; i++ {
res = append(res, byte((polymod>>uint(5*(5-i)))&31))
}
return res
}
// For more details on the polymod calculation, please refer to BIP 173.
func bech32Polymod(values []int) int {
chk := 1
for _, v := range values {
b := chk >> 25
chk = (chk&0x1ffffff)<<5 ^ v
for i := 0; i < 5; i++ {
if (b>>uint(i))&1 == 1 {
chk ^= gen[i]
}
}
}
return chk
}
// bech32HrpExpand expands the human-readable part into the values used in
// the checksum calculation. For more details on HRP expansion, please
// refer to BIP 173.
func bech32HrpExpand(hrp string) []int {
v := make([]int, 0, len(hrp)*2+1)
for i := 0; i < len(hrp); i++ {
v = append(v, int(hrp[i]>>5))
}
v = append(v, 0)
for i := 0; i < len(hrp); i++ {
v = append(v, int(hrp[i]&31))
}
return v
}
// bech32VerifyChecksum verifies that the data, with the six checksum
// groups appended, checksums correctly for the given HRP. For more details
// on the checksum verification, please refer to BIP 173.
func bech32VerifyChecksum(hrp string, data []byte) bool {
integers := make([]int, len(data))
for i, b := range data {
integers[i] = int(b)
}
concat := append(bech32HrpExpand(hrp), integers...)
return bech32Polymod(concat) == 1
}
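// exampleChecksumProperty is an illustrative sketch tying the helpers
// above together: appending the six groups from bech32Checksum to the data
// makes bech32VerifyChecksum succeed, i.e. the polymod over the expanded
// HRP, data and checksum equals exactly 1.
func exampleChecksumProperty(hrp string, data []byte) bool {
withChecksum := append(append([]byte{}, data...),
bech32Checksum(hrp, data)...)
return bech32VerifyChecksum(hrp, withChecksum)
}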
package zpay32
import (
"encoding/binary"
"fmt"
"io"
"github.com/btcsuite/btcd/btcec/v2"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
const (
// maxNumHopsPerPath is the maximum number of blinded path hops that can
// be included in a single encoded blinded path. This is calculated
// based on the `data_length` limit of 639 bytes for any tagged field in
// a BOLT 11 invoice along with the estimated number of bytes required
// for encoding the most minimal blinded path hop. See the [bLIP
// proposal](https://github.com/lightning/blips/pull/39) for a detailed
// calculation.
maxNumHopsPerPath = 7
// maxCipherTextLength defines the largest cipher text size allowed.
// This is derived by using the `data_length` upper bound of 639 bytes
// and then assuming the case of a path with only a single hop (meaning
// the cipher text may be as large as possible).
maxCipherTextLength = 535
)
var (
// byteOrder defines the endian-ness we use for encoding to and from
// buffers.
byteOrder = binary.BigEndian
)
// BlindedPaymentPath holds all the information a payer needs to know about a
// blinded path to a receiver of a payment.
type BlindedPaymentPath struct {
// FeeBaseMsat is the total base fee for the path in milli-satoshis.
FeeBaseMsat uint32
// FeeRate is the total fee rate for the path in parts per million.
FeeRate uint32
// CltvExpiryDelta is the total CLTV delta to apply to the path.
CltvExpiryDelta uint16
// HTLCMinMsat is the minimum number of milli-satoshis that any hop in
// the path will route.
HTLCMinMsat uint64
// HTLCMaxMsat is the maximum number of milli-satoshis that a hop in the
// path will route.
HTLCMaxMsat uint64
// Features is the feature bit vector for the path.
Features *lnwire.FeatureVector
// FirstEphemeralBlindingPoint is the blinding point to send to the
// introduction node. It will be used by the introduction node to derive
// a shared secret with the receiver which can then be used to decode
// the encrypted payload from the receiver.
FirstEphemeralBlindingPoint *btcec.PublicKey
// Hops is the blinded path. The first hop is the introduction node and
// so the BlindedNodeID of this hop will be the real node ID.
Hops []*sphinx.BlindedHopInfo
}
// DecodeBlindedPayment attempts to parse a BlindedPaymentPath from the passed
// reader.
func DecodeBlindedPayment(r io.Reader) (*BlindedPaymentPath, error) {
var payment BlindedPaymentPath
if err := binary.Read(r, byteOrder, &payment.FeeBaseMsat); err != nil {
return nil, err
}
if err := binary.Read(r, byteOrder, &payment.FeeRate); err != nil {
return nil, err
}
err := binary.Read(r, byteOrder, &payment.CltvExpiryDelta)
if err != nil {
return nil, err
}
err = binary.Read(r, byteOrder, &payment.HTLCMinMsat)
if err != nil {
return nil, err
}
err = binary.Read(r, byteOrder, &payment.HTLCMaxMsat)
if err != nil {
return nil, err
}
// Parse the feature bit vector.
f := lnwire.EmptyFeatureVector()
err = f.Decode(r)
if err != nil {
return nil, err
}
payment.Features = f
// Parse the first ephemeral blinding point.
var blindingPointBytes [btcec.PubKeyBytesLenCompressed]byte
// Use io.ReadFull so a short read surfaces as an error instead of
// leaving the buffer partially filled.
_, err = io.ReadFull(r, blindingPointBytes[:])
if err != nil {
return nil, err
}
blinding, err := btcec.ParsePubKey(blindingPointBytes[:])
if err != nil {
return nil, err
}
payment.FirstEphemeralBlindingPoint = blinding
// Read the one byte hop number.
var numHops [1]byte
_, err = io.ReadFull(r, numHops[:])
if err != nil {
return nil, err
}
payment.Hops = make([]*sphinx.BlindedHopInfo, int(numHops[0]))
// Parse each hop.
for i := 0; i < len(payment.Hops); i++ {
hop, err := DecodeBlindedHop(r)
if err != nil {
return nil, err
}
payment.Hops[i] = hop
}
return &payment, nil
}
// Encode serialises the BlindedPaymentPath and writes the bytes to the passed
// writer.
// 1) The first 26 bytes contain the relay info:
// - Base Fee in msat: uint32 (4 bytes).
// - Proportional Fee in PPM: uint32 (4 bytes).
// - CLTV expiry delta: uint16 (2 bytes).
// - HTLC min msat: uint64 (8 bytes).
// - HTLC max msat: uint64 (8 bytes).
//
// 2) Feature bit vector length (2 bytes).
// 3) Feature bit vector (can be zero length).
// 4) First blinding point: 33 bytes.
// 5) Number of hops: 1 byte.
// 6) Encoded BlindedHops.
func (p *BlindedPaymentPath) Encode(w io.Writer) error {
if err := binary.Write(w, byteOrder, p.FeeBaseMsat); err != nil {
return err
}
if err := binary.Write(w, byteOrder, p.FeeRate); err != nil {
return err
}
if err := binary.Write(w, byteOrder, p.CltvExpiryDelta); err != nil {
return err
}
if err := binary.Write(w, byteOrder, p.HTLCMinMsat); err != nil {
return err
}
if err := binary.Write(w, byteOrder, p.HTLCMaxMsat); err != nil {
return err
}
if err := p.Features.Encode(w); err != nil {
return err
}
_, err := w.Write(p.FirstEphemeralBlindingPoint.SerializeCompressed())
if err != nil {
return err
}
numHops := len(p.Hops)
if numHops > maxNumHopsPerPath {
return fmt.Errorf("the number of hops, %d, exceeds the "+
"maximum of %d", numHops, maxNumHopsPerPath)
}
if _, err := w.Write([]byte{byte(numHops)}); err != nil {
return err
}
for _, hop := range p.Hops {
if err := EncodeBlindedHop(w, hop); err != nil {
return err
}
}
return nil
}
// DecodeBlindedHop reads a sphinx.BlindedHopInfo from the passed reader.
func DecodeBlindedHop(r io.Reader) (*sphinx.BlindedHopInfo, error) {
var nodeIDBytes [btcec.PubKeyBytesLenCompressed]byte
_, err := io.ReadFull(r, nodeIDBytes[:])
if err != nil {
return nil, err
}
nodeID, err := btcec.ParsePubKey(nodeIDBytes[:])
if err != nil {
return nil, err
}
dataLen, err := tlv.ReadVarInt(r, &[8]byte{})
if err != nil {
return nil, err
}
if dataLen > maxCipherTextLength {
return nil, fmt.Errorf("a blinded hop cipher text blob may "+
"not exceed the maximum of %d bytes",
maxCipherTextLength)
}
encryptedData := make([]byte, dataLen)
_, err = io.ReadFull(r, encryptedData)
if err != nil {
return nil, err
}
return &sphinx.BlindedHopInfo{
BlindedNodePub: nodeID,
CipherText: encryptedData,
}, nil
}
// EncodeBlindedHop writes the passed BlindedHopInfo to the given writer.
//
// 1) Blinded node pub key: 33 bytes
// 2) Cipher text length: BigSize
// 3) Cipher text.
func EncodeBlindedHop(w io.Writer, hop *sphinx.BlindedHopInfo) error {
_, err := w.Write(hop.BlindedNodePub.SerializeCompressed())
if err != nil {
return err
}
if len(hop.CipherText) > maxCipherTextLength {
return fmt.Errorf("encrypted recipient data can not exceed a "+
"length of %d bytes", maxCipherTextLength)
}
err = tlv.WriteVarInt(w, uint64(len(hop.CipherText)), &[8]byte{})
if err != nil {
return err
}
_, err = w.Write(hop.CipherText)
return err
}
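// exampleBlindedPathRoundTrip is an illustrative sketch of the wire format
// documented above: it encodes a minimal single-hop path and decodes it
// back. The field values are arbitrary, and the sketch assumes the "bytes"
// package is added to this file's imports.
func exampleBlindedPathRoundTrip() (*BlindedPaymentPath, error) {
priv, err := btcec.NewPrivateKey()
if err != nil {
return nil, err
}
path := &BlindedPaymentPath{
FeeBaseMsat: 1000,
FeeRate: 500,
CltvExpiryDelta: 144,
HTLCMinMsat: 1000,
HTLCMaxMsat: 100_000_000,
Features: lnwire.EmptyFeatureVector(),
FirstEphemeralBlindingPoint: priv.PubKey(),
Hops: []*sphinx.BlindedHopInfo{{
BlindedNodePub: priv.PubKey(),
CipherText: []byte{0x01, 0x02, 0x03},
}},
}
var buf bytes.Buffer
if err := path.Encode(&buf); err != nil {
return nil, err
}
// Decoding consumes the same 26 relay-info bytes, feature vector,
// blinding point, hop count and hops that Encode wrote.
return DecodeBlindedPayment(&buf)
}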
package zpay32
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"strings"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/ecdsa"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/bech32"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
// DecodeOption is a type that can be used to supply functional options to the
// Decode function.
type DecodeOption func(*decodeOptions)
// WithKnownFeatureBits is a functional option that overwrites the set of
// known feature bits. If not set, then LND's lnwire.Features variable will be
// used by default.
func WithKnownFeatureBits(features map[lnwire.FeatureBit]string) DecodeOption {
return func(options *decodeOptions) {
options.knownFeatureBits = features
}
}
// WithErrorOnUnknownFeatureBit is a functional option that will cause the
// Decode function to return an error if the decoded invoice contains an unknown
// feature bit.
func WithErrorOnUnknownFeatureBit() DecodeOption {
return func(options *decodeOptions) {
options.errorOnUnknownFeature = true
}
}
// decodeOptions holds the set of Decode options.
type decodeOptions struct {
knownFeatureBits map[lnwire.FeatureBit]string
errorOnUnknownFeature bool
}
// newDecodeOptions constructs the default decodeOptions struct.
func newDecodeOptions() *decodeOptions {
return &decodeOptions{
knownFeatureBits: lnwire.Features,
errorOnUnknownFeature: false,
}
}
// Decode parses the provided encoded invoice and returns a decoded Invoice if
// it is valid by BOLT-0011 and matches the provided active network.
func Decode(invoice string, net *chaincfg.Params, opts ...DecodeOption) (
*Invoice, error) {
options := newDecodeOptions()
for _, o := range opts {
o(options)
}
var decodedInvoice Invoice
// Before bech32 decoding the invoice, make sure that it is not too large.
// This is done as an anti-DoS measure since bech32 decoding is expensive.
if len(invoice) > maxInvoiceLength {
return nil, ErrInvoiceTooLarge
}
// Decode the invoice using the modified bech32 decoder.
hrp, data, err := decodeBech32(invoice)
if err != nil {
return nil, err
}
// We expect the human-readable part to at least have ln + one char
// encoding the network.
if len(hrp) < 3 {
return nil, fmt.Errorf("hrp too short")
}
// First two characters of HRP should be "ln".
if hrp[:2] != "ln" {
return nil, fmt.Errorf("prefix should be \"ln\"")
}
// The next characters should be a valid prefix for a segwit BIP173
// address that matches the active network, except for signet where we add
// an additional "s" to differentiate it from the older testnet3 (Core
// devs decided to use the same hrp for signet as for testnet3 which is
// not optimal for LN). See
// https://github.com/lightningnetwork/lightning-rfc/pull/844 for more
// information.
expectedPrefix := net.Bech32HRPSegwit
if net.Name == chaincfg.SigNetParams.Name {
expectedPrefix = "tbs"
}
if !strings.HasPrefix(hrp[2:], expectedPrefix) {
return nil, fmt.Errorf(
"invoice not for current active network '%s'", net.Name)
}
decodedInvoice.Net = net
// Optionally, if there's anything left of the HRP after ln + the segwit
// prefix, we try to decode this as the payment amount.
var netPrefixLength = len(expectedPrefix) + 2
if len(hrp) > netPrefixLength {
amount, err := decodeAmount(hrp[netPrefixLength:])
if err != nil {
return nil, err
}
decodedInvoice.MilliSat = &amount
}
// Everything except the last 520 bits of the data encodes the invoice's
// timestamp and tagged fields.
if len(data) < signatureBase32Len {
return nil, errors.New("short invoice")
}
invoiceData := data[:len(data)-signatureBase32Len]
// Parse the timestamp and tagged fields, and fill the Invoice struct.
if err := parseData(&decodedInvoice, invoiceData, net); err != nil {
return nil, err
}
// The last 520 bits (104 groups) make up the signature.
sigBase32 := data[len(data)-signatureBase32Len:]
sigBase256, err := bech32.ConvertBits(sigBase32, 5, 8, true)
if err != nil {
return nil, err
}
sig, err := lnwire.NewSigFromWireECDSA(sigBase256[:64])
if err != nil {
return nil, err
}
recoveryID := sigBase256[64]
// The signature is over the hrp + the data of the invoice, encoded in
// base 256.
taggedDataBytes, err := bech32.ConvertBits(invoiceData, 5, 8, true)
if err != nil {
return nil, err
}
toSign := append([]byte(hrp), taggedDataBytes...)
// We expect the signature to be over the single SHA-256 hash of that
// data.
hash := chainhash.HashB(toSign)
// If the destination pubkey was provided as a tagged field, use that
// to verify the signature, if not do public key recovery.
if decodedInvoice.Destination != nil {
signature, err := sig.ToSignature()
if err != nil {
return nil, fmt.Errorf("unable to deserialize "+
"signature: %v", err)
}
if !signature.Verify(hash, decodedInvoice.Destination) {
return nil, fmt.Errorf("invalid invoice signature")
}
} else {
headerByte := recoveryID + 27 + 4
compactSign := append([]byte{headerByte}, sig.RawBytes()...)
pubkey, _, err := ecdsa.RecoverCompact(compactSign, hash)
if err != nil {
return nil, err
}
decodedInvoice.Destination = pubkey
}
// If no feature vector was decoded, populate an empty one.
if decodedInvoice.Features == nil {
decodedInvoice.Features = lnwire.NewFeatureVector(
nil, options.knownFeatureBits,
)
}
// Now that we have created the invoice, make sure it has the required
// fields set.
if err := validateInvoice(&decodedInvoice); err != nil {
return nil, err
}
if options.errorOnUnknownFeature {
// Make sure that we understand all the required feature bits
// in the invoice.
unknownFeatureBits := decodedInvoice.Features.
UnknownRequiredFeatures()
if len(unknownFeatureBits) > 0 {
errStr := fmt.Sprintf("invoice contains " +
"unknown feature bits:")
for _, bit := range unknownFeatureBits {
errStr += fmt.Sprintf(" %d,", bit)
}
return nil, errors.New(strings.TrimRight(errStr, ","))
}
}
return &decodedInvoice, nil
}
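// exampleDecode is an illustrative sketch of calling Decode. The invoice
// literal is a truncated placeholder, not a real payment request.
func exampleDecode() error {
inv, err := Decode(
"lnbc...", // placeholder invoice string
&chaincfg.MainNetParams,
WithErrorOnUnknownFeatureBit(),
)
if err != nil {
return err
}
// PaymentHash is guaranteed non-nil after a successful decode since
// validateInvoice requires it.
fmt.Printf("payment hash: %x\n", inv.PaymentHash[:])
return nil
}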
// parseData parses the data part of the invoice. It expects the base32
// data returned from the bech32.Decode method, with the signature
// excluded.
func parseData(invoice *Invoice, data []byte, net *chaincfg.Params) error {
// It must contain the timestamp, encoded using 35 bits (7 groups).
if len(data) < timestampBase32Len {
return fmt.Errorf("data too short: %d", len(data))
}
t, err := parseTimestamp(data[:timestampBase32Len])
if err != nil {
return err
}
invoice.Timestamp = time.Unix(int64(t), 0)
// The rest are tagged parts.
tagData := data[timestampBase32Len:]
return parseTaggedFields(invoice, tagData, net)
}
// parseTimestamp converts a 35-bit timestamp (encoded in base32) to uint64.
func parseTimestamp(data []byte) (uint64, error) {
if len(data) != timestampBase32Len {
return 0, fmt.Errorf("timestamp must be 35 bits, was %d",
len(data)*5)
}
return base32ToUint64(data)
}
// parseTaggedFields takes the base32 encoded tagged fields of the invoice, and
// fills the Invoice struct accordingly.
func parseTaggedFields(invoice *Invoice, fields []byte, net *chaincfg.Params) error {
index := 0
for len(fields)-index > 0 {
// If there are less than 3 groups to read, there cannot be more
// interesting information, as we need the type (1 group) and
// length (2 groups).
//
// This means the last tagged field is broken.
if len(fields)-index < 3 {
return ErrBrokenTaggedField
}
typ := fields[index]
dataLength, err := parseFieldDataLength(fields[index+1 : index+3])
if err != nil {
return err
}
// If we don't have enough field data left to read this length,
// return error.
if len(fields) < index+3+int(dataLength) {
return ErrInvalidFieldLength
}
base32Data := fields[index+3 : index+3+int(dataLength)]
// Advance the index in preparation for the next iteration.
index += 3 + int(dataLength)
switch typ {
case fieldTypeP:
if invoice.PaymentHash != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.PaymentHash, err = parse32Bytes(base32Data)
case fieldTypeS:
if invoice.PaymentAddr.IsSome() {
// We skip the field if we have already seen a
// supported one.
continue
}
addr, err := parse32Bytes(base32Data)
if err != nil {
return err
}
if addr != nil {
invoice.PaymentAddr = fn.Some(*addr)
}
case fieldTypeD:
if invoice.Description != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.Description, err = parseDescription(base32Data)
case fieldTypeM:
if invoice.Metadata != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.Metadata, err = parseMetadata(base32Data)
case fieldTypeN:
if invoice.Destination != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.Destination, err = parseDestination(base32Data)
case fieldTypeH:
if invoice.DescriptionHash != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.DescriptionHash, err = parse32Bytes(base32Data)
case fieldTypeX:
if invoice.expiry != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.expiry, err = parseExpiry(base32Data)
case fieldTypeC:
if invoice.minFinalCLTVExpiry != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.minFinalCLTVExpiry, err =
parseMinFinalCLTVExpiry(base32Data)
case fieldTypeF:
if invoice.FallbackAddr != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.FallbackAddr, err = parseFallbackAddr(
base32Data, net,
)
case fieldTypeR:
// An `r` field can be included in an invoice multiple
// times, so we won't skip it if we have already seen
// one.
routeHint, err := parseRouteHint(base32Data)
if err != nil {
return err
}
invoice.RouteHints = append(
invoice.RouteHints, routeHint,
)
case fieldType9:
if invoice.Features != nil {
// We skip the field if we have already seen a
// supported one.
continue
}
invoice.Features, err = parseFeatures(base32Data)
case fieldTypeB:
blindedPaymentPath, err := parseBlindedPaymentPath(
base32Data,
)
if err != nil {
return err
}
invoice.BlindedPaymentPaths = append(
invoice.BlindedPaymentPaths, blindedPaymentPath,
)
default:
// Ignore unknown type.
}
// Check if there was an error from parsing any of the tagged
// fields and return it.
if err != nil {
return err
}
}
return nil
}
// parseFieldDataLength converts the two byte slice into a uint16.
func parseFieldDataLength(data []byte) (uint16, error) {
if len(data) != 2 {
return 0, fmt.Errorf("data length must be 2 bytes, was %d",
len(data))
}
return uint16(data[0])<<5 | uint16(data[1]), nil
}
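// exampleFieldDataLength is an illustrative check of the encoding above:
// the 52 groups of a payment hash are encoded as the two length groups
// [1, 20], since 1<<5 | 20 == 52.
func exampleFieldDataLength() bool {
length, err := parseFieldDataLength([]byte{1, 20})
return err == nil && length == 52
}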
// parse32Bytes converts a 256-bit value (encoded in base32) to *[32]byte. This
// can be used for payment hashes, description hashes, payment addresses, etc.
func parse32Bytes(data []byte) (*[32]byte, error) {
var paymentHash [32]byte
// As BOLT-11 states, a reader must skip over a 32-byte field if it does
// not have a length of 52 groups, so avoid returning an error.
if len(data) != hashBase32Len {
return nil, nil
}
hash, err := bech32.ConvertBits(data, 5, 8, false)
if err != nil {
return nil, err
}
copy(paymentHash[:], hash)
return &paymentHash, nil
}
// parseDescription converts the data (encoded in base32) into a string to use
// as the description.
func parseDescription(data []byte) (*string, error) {
base256Data, err := bech32.ConvertBits(data, 5, 8, false)
if err != nil {
return nil, err
}
description := string(base256Data)
return &description, nil
}
// parseMetadata converts the data (encoded in base32) into a byte slice to use
// as the metadata.
func parseMetadata(data []byte) ([]byte, error) {
return bech32.ConvertBits(data, 5, 8, false)
}
// parseDestination converts the data (encoded in base32) into a 33-byte public
// key of the payee node.
func parseDestination(data []byte) (*btcec.PublicKey, error) {
// As BOLT-11 states, a reader must skip over the destination field
// if it does not have a length of 53, so avoid returning an error.
if len(data) != pubKeyBase32Len {
return nil, nil
}
base256Data, err := bech32.ConvertBits(data, 5, 8, false)
if err != nil {
return nil, err
}
return btcec.ParsePubKey(base256Data)
}
// parseExpiry converts the data (encoded in base32) into the expiry time.
func parseExpiry(data []byte) (*time.Duration, error) {
expiry, err := base32ToUint64(data)
if err != nil {
return nil, err
}
duration := time.Duration(expiry) * time.Second
return &duration, nil
}
// parseMinFinalCLTVExpiry converts the data (encoded in base32) into a uint64
// to use as the minFinalCLTVExpiry.
func parseMinFinalCLTVExpiry(data []byte) (*uint64, error) {
expiry, err := base32ToUint64(data)
if err != nil {
return nil, err
}
return &expiry, nil
}
// parseFallbackAddr converts the data (encoded in base32) into a fallback
// on-chain address.
func parseFallbackAddr(data []byte, net *chaincfg.Params) (btcutil.Address, error) { // nolint:dupl
// Checks if the data is empty or contains a version without an address.
if len(data) < 2 {
return nil, fmt.Errorf("empty fallback address field")
}
var addr btcutil.Address
version := data[0]
switch version {
case 0:
witness, err := bech32.ConvertBits(data[1:], 5, 8, false)
if err != nil {
return nil, err
}
switch len(witness) {
case 20:
addr, err = btcutil.NewAddressWitnessPubKeyHash(witness, net)
case 32:
addr, err = btcutil.NewAddressWitnessScriptHash(witness, net)
default:
return nil, fmt.Errorf("unknown witness program length %d",
len(witness))
}
if err != nil {
return nil, err
}
case 17:
pubKeyHash, err := bech32.ConvertBits(data[1:], 5, 8, false)
if err != nil {
return nil, err
}
addr, err = btcutil.NewAddressPubKeyHash(pubKeyHash, net)
if err != nil {
return nil, err
}
case 18:
scriptHash, err := bech32.ConvertBits(data[1:], 5, 8, false)
if err != nil {
return nil, err
}
addr, err = btcutil.NewAddressScriptHashFromHash(scriptHash, net)
if err != nil {
return nil, err
}
default:
// Ignore unknown version.
}
return addr, nil
}
// parseRouteHint converts the data (encoded in base32) into an array containing
// one or more routing hop hints that represent a single route hint.
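// Each 51-byte hop hint is laid out as: node ID (33 bytes), channel ID
// (8 bytes), base fee (4 bytes), proportional fee rate (4 bytes) and CLTV
// expiry delta (2 bytes).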
func parseRouteHint(data []byte) ([]HopHint, error) {
base256Data, err := bech32.ConvertBits(data, 5, 8, false)
if err != nil {
return nil, err
}
// Check that base256Data is a multiple of hopHintLen.
if len(base256Data)%hopHintLen != 0 {
return nil, fmt.Errorf("expected length multiple of %d bytes, "+
"got %d", hopHintLen, len(base256Data))
}
var routeHint []HopHint
for len(base256Data) > 0 {
hopHint := HopHint{}
hopHint.NodeID, err = btcec.ParsePubKey(base256Data[:33])
if err != nil {
return nil, err
}
hopHint.ChannelID = binary.BigEndian.Uint64(base256Data[33:41])
hopHint.FeeBaseMSat = binary.BigEndian.Uint32(base256Data[41:45])
hopHint.FeeProportionalMillionths = binary.BigEndian.Uint32(base256Data[45:49])
hopHint.CLTVExpiryDelta = binary.BigEndian.Uint16(base256Data[49:51])
routeHint = append(routeHint, hopHint)
base256Data = base256Data[51:]
}
return routeHint, nil
}
// parseBlindedPaymentPath attempts to parse a BlindedPaymentPath from the given
// byte slice.
func parseBlindedPaymentPath(data []byte) (*BlindedPaymentPath, error) {
base256Data, err := bech32.ConvertBits(data, 5, 8, false)
if err != nil {
return nil, err
}
return DecodeBlindedPayment(bytes.NewReader(base256Data))
}
// parseFeatures decodes any feature bits directly from the base32
// representation.
func parseFeatures(data []byte) (*lnwire.FeatureVector, error) {
rawFeatures := lnwire.NewRawFeatureVector()
err := rawFeatures.DecodeBase32(bytes.NewReader(data), len(data))
if err != nil {
return nil, err
}
return lnwire.NewFeatureVector(rawFeatures, lnwire.Features), nil
}
// base32ToUint64 converts a base32 encoded number to uint64.
func base32ToUint64(data []byte) (uint64, error) {
// Maximum that fits in uint64 is ceil(64 / 5) = 13 groups.
if len(data) > 13 {
return 0, fmt.Errorf("cannot parse data of length %d as uint64",
len(data))
}
val := uint64(0)
for i := 0; i < len(data); i++ {
val = val<<5 | uint64(data[i])
}
return val, nil
}
package zpay32
import (
"bytes"
"encoding/binary"
"fmt"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/btcutil/bech32"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
// Encode takes the given MessageSigner and returns a string encoding this
// invoice signed by the node key of the signer.
func (invoice *Invoice) Encode(signer MessageSigner) (string, error) {
// First check that this invoice is valid before starting the encoding.
if err := validateInvoice(invoice); err != nil {
return "", err
}
// The buffer will encode the invoice data using 5-bit groups (base32).
var bufferBase32 bytes.Buffer
// The timestamp will be encoded using 35 bits, in base32.
timestampBase32 := uint64ToBase32(uint64(invoice.Timestamp.Unix()))
// The timestamp must be exactly 35 bits, which means 7 groups. If it
// can fit into fewer groups we add leading zero groups; if it is too
// big we fail early, as it is not possible to encode it.
if len(timestampBase32) > timestampBase32Len {
return "", fmt.Errorf("timestamp too big: %d",
invoice.Timestamp.Unix())
}
// Add zero bytes to the first timestampBase32Len-len(timestampBase32)
// groups, then add the non-zero groups.
zeroes := make([]byte, timestampBase32Len-len(timestampBase32))
_, err := bufferBase32.Write(zeroes)
if err != nil {
return "", fmt.Errorf("unable to write to buffer: %w", err)
}
_, err = bufferBase32.Write(timestampBase32)
if err != nil {
return "", fmt.Errorf("unable to write to buffer: %w", err)
}
// We now write the tagged fields to the buffer, which will fill the
// rest of the data part before the signature.
if err := writeTaggedFields(&bufferBase32, invoice); err != nil {
return "", err
}
// The human-readable part (hrp) is "ln" + net hrp + optional amount,
// except for signet where we add an additional "s" to differentiate it
// from the older testnet3 (Core devs decided to use the same hrp for
// signet as for testnet3 which is not optimal for LN). See
// https://github.com/lightningnetwork/lightning-rfc/pull/844 for more
// information.
hrp := "ln" + invoice.Net.Bech32HRPSegwit
if invoice.Net.Name == chaincfg.SigNetParams.Name {
hrp = "lntbs"
}
if invoice.MilliSat != nil {
// Encode the amount using the fewest possible characters.
am, err := encodeAmount(*invoice.MilliSat)
if err != nil {
return "", err
}
hrp += am
}
// The signature is over the single SHA-256 hash of the hrp + the
// tagged fields encoded in base256.
taggedFieldsBytes, err := bech32.ConvertBits(bufferBase32.Bytes(), 5, 8, true)
if err != nil {
return "", err
}
toSign := append([]byte(hrp), taggedFieldsBytes...)
// We use the compact signature format, and also encode the recovery ID
// such that a reader of the invoice can recover our pubkey from the
// signature.
sign, err := signer.SignCompact(toSign)
if err != nil {
return "", err
}
// From the header byte we can extract the recovery ID, and the last 64
// bytes encode the signature.
recoveryID := sign[0] - 27 - 4
sig, err := lnwire.NewSigFromWireECDSA(sign[1:])
if err != nil {
return "", err
}
// If the pubkey field was explicitly set, it must be set to the pubkey
// used to create the signature.
if invoice.Destination != nil {
signature, err := sig.ToSignature()
if err != nil {
return "", fmt.Errorf("unable to deserialize "+
"signature: %v", err)
}
hash := chainhash.HashB(toSign)
valid := signature.Verify(hash, invoice.Destination)
if !valid {
return "", fmt.Errorf("signature does not match " +
"provided pubkey")
}
}
// Convert the signature to base32 before writing it to the buffer.
signBase32, err := bech32.ConvertBits(
append(sig.RawBytes(), recoveryID),
8, 5, true,
)
if err != nil {
return "", err
}
bufferBase32.Write(signBase32)
// Now we can create the bech32 encoded string from the base32 buffer.
b32, err := bech32.Encode(hrp, bufferBase32.Bytes())
if err != nil {
return "", err
}
// Before returning, check that the bech32 encoded string is not greater
// than our largest supported invoice size.
if len(b32) > maxInvoiceLength {
return "", ErrInvoiceTooLarge
}
return b32, nil
}
// writeTaggedFields writes the non-nil tagged fields of the Invoice to the
// base32 buffer.
func writeTaggedFields(bufferBase32 *bytes.Buffer, invoice *Invoice) error {
if invoice.PaymentHash != nil {
err := writeBytes32(bufferBase32, fieldTypeP, *invoice.PaymentHash)
if err != nil {
return err
}
}
if invoice.Description != nil {
base32, err := bech32.ConvertBits([]byte(*invoice.Description),
8, 5, true)
if err != nil {
return err
}
err = writeTaggedField(bufferBase32, fieldTypeD, base32)
if err != nil {
return err
}
}
if invoice.DescriptionHash != nil {
err := writeBytes32(
bufferBase32, fieldTypeH, *invoice.DescriptionHash,
)
if err != nil {
return err
}
}
if invoice.Metadata != nil {
base32, err := bech32.ConvertBits(invoice.Metadata, 8, 5, true)
if err != nil {
return err
}
err = writeTaggedField(bufferBase32, fieldTypeM, base32)
if err != nil {
return err
}
}
if invoice.minFinalCLTVExpiry != nil {
finalDelta := uint64ToBase32(*invoice.minFinalCLTVExpiry)
err := writeTaggedField(bufferBase32, fieldTypeC, finalDelta)
if err != nil {
return err
}
}
if invoice.expiry != nil {
seconds := invoice.expiry.Seconds()
expiry := uint64ToBase32(uint64(seconds))
err := writeTaggedField(bufferBase32, fieldTypeX, expiry)
if err != nil {
return err
}
}
if invoice.FallbackAddr != nil {
var version byte
switch addr := invoice.FallbackAddr.(type) {
case *btcutil.AddressPubKeyHash:
version = 17
case *btcutil.AddressScriptHash:
version = 18
case *btcutil.AddressWitnessPubKeyHash:
version = addr.WitnessVersion()
case *btcutil.AddressWitnessScriptHash:
version = addr.WitnessVersion()
default:
return fmt.Errorf("unknown fallback address type")
}
base32Addr, err := bech32.ConvertBits(
invoice.FallbackAddr.ScriptAddress(), 8, 5, true)
if err != nil {
return err
}
err = writeTaggedField(bufferBase32, fieldTypeF,
append([]byte{version}, base32Addr...))
if err != nil {
return err
}
}
for _, routeHint := range invoice.RouteHints {
// Each hop hint is encoded using 51 bytes, so we'll make sure to
// allocate enough space for the whole route hint.
routeHintBase256 := make([]byte, 0, hopHintLen*len(routeHint))
for _, hopHint := range routeHint {
hopHintBase256 := make([]byte, hopHintLen)
copy(hopHintBase256[:33], hopHint.NodeID.SerializeCompressed())
binary.BigEndian.PutUint64(
hopHintBase256[33:41], hopHint.ChannelID,
)
binary.BigEndian.PutUint32(
hopHintBase256[41:45], hopHint.FeeBaseMSat,
)
binary.BigEndian.PutUint32(
hopHintBase256[45:49], hopHint.FeeProportionalMillionths,
)
binary.BigEndian.PutUint16(
hopHintBase256[49:51], hopHint.CLTVExpiryDelta,
)
routeHintBase256 = append(routeHintBase256, hopHintBase256...)
}
routeHintBase32, err := bech32.ConvertBits(
routeHintBase256, 8, 5, true,
)
if err != nil {
return err
}
err = writeTaggedField(bufferBase32, fieldTypeR, routeHintBase32)
if err != nil {
return err
}
}
for _, path := range invoice.BlindedPaymentPaths {
var buf bytes.Buffer
err := path.Encode(&buf)
if err != nil {
return err
}
blindedPathBase32, err := bech32.ConvertBits(
buf.Bytes(), 8, 5, true,
)
if err != nil {
return err
}
err = writeTaggedField(
bufferBase32, fieldTypeB, blindedPathBase32,
)
if err != nil {
return err
}
}
if invoice.Destination != nil {
// Convert 33 byte pubkey to 53 5-bit groups.
pubKeyBase32, err := bech32.ConvertBits(
invoice.Destination.SerializeCompressed(), 8, 5, true)
if err != nil {
return err
}
if len(pubKeyBase32) != pubKeyBase32Len {
return fmt.Errorf("invalid pubkey length: %d",
len(invoice.Destination.SerializeCompressed()))
}
err = writeTaggedField(bufferBase32, fieldTypeN, pubKeyBase32)
if err != nil {
return err
}
}
err := fn.MapOptionZ(invoice.PaymentAddr, func(addr [32]byte) error {
return writeBytes32(bufferBase32, fieldTypeS, addr)
})
if err != nil {
return err
}
if invoice.Features.SerializeSize32() > 0 {
var b bytes.Buffer
err := invoice.Features.RawFeatureVector.EncodeBase32(&b)
if err != nil {
return err
}
err = writeTaggedField(bufferBase32, fieldType9, b.Bytes())
if err != nil {
return err
}
}
return nil
}
// writeBytes32 encodes a 32-byte array as base32 and writes it to bufferBase32
// under the passed fieldType.
func writeBytes32(bufferBase32 *bytes.Buffer, fieldType byte, b [32]byte) error {
// Convert 32 byte hash to 52 5-bit groups.
base32, err := bech32.ConvertBits(b[:], 8, 5, true)
if err != nil {
return err
}
return writeTaggedField(bufferBase32, fieldType, base32)
}
// writeTaggedField takes the type of a tagged data field, and the data of
// the tagged field (encoded in base32), and writes the type, length and data
// to the buffer.
func writeTaggedField(bufferBase32 *bytes.Buffer, dataType byte, data []byte) error {
// Length must be exactly 10 bits, so add leading zero groups if
// needed.
lenBase32 := uint64ToBase32(uint64(len(data)))
for len(lenBase32) < 2 {
lenBase32 = append([]byte{0}, lenBase32...)
}
if len(lenBase32) != 2 {
return fmt.Errorf("data length too big to fit within 10 bits: %d",
len(data))
}
err := bufferBase32.WriteByte(dataType)
if err != nil {
return fmt.Errorf("unable to write to buffer: %w", err)
}
_, err = bufferBase32.Write(lenBase32)
if err != nil {
return fmt.Errorf("unable to write to buffer: %w", err)
}
_, err = bufferBase32.Write(data)
if err != nil {
return fmt.Errorf("unable to write to buffer: %w", err)
}
return nil
}
// uint64ToBase32 converts a uint64 to a base32 encoded integer encoded using
// as few 5-bit groups as possible.
func uint64ToBase32(num uint64) []byte {
// Return at least one group.
if num == 0 {
return []byte{0}
}
// To fit a uint64, we need at most ceil(64 / 5) = 13 groups.
arr := make([]byte, 13)
i := 13
for num > 0 {
i--
arr[i] = byte(num & uint64(31)) // 0b11111 in binary
num >>= 5
}
// We only return non-zero leading groups.
return arr[i:]
}
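// exampleBase32RoundTrip is an illustrative sketch: 97 = 0b11_00001 splits
// into the 5-bit groups [3, 1], and base32ToUint64 reverses the mapping.
func exampleBase32RoundTrip() bool {
groups := uint64ToBase32(97) // []byte{3, 1}
val, err := base32ToUint64(groups)
return err == nil && val == 97
}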
package zpay32
import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
const (
// DefaultAssumedFinalCLTVDelta is the default value to be used as the
// final CLTV delta for a route if one is unspecified in the payment
// request.
// We adhere to the recommendation in BOLT 02 for terminal payments.
// See also:
// https://github.com/lightning/bolts/blob/master/02-peer-protocol.md
DefaultAssumedFinalCLTVDelta = 18
// feeRateParts is the total number of parts used to express fee rates.
feeRateParts = 1e6
)
// HopHint is a routing hint that contains the minimum information of a channel
// required for an intermediate hop in a route to forward the payment to the
// next. This should be ideally used for private channels, since they are not
// publicly advertised to the network for routing.
type HopHint struct {
// NodeID is the public key of the node at the start of the channel.
NodeID *btcec.PublicKey
// ChannelID is the unique identifier of the channel.
ChannelID uint64
// FeeBaseMSat is the base fee of the channel in millisatoshis.
FeeBaseMSat uint32
// FeeProportionalMillionths is the fee rate, in millionths of a
// satoshi, for every satoshi sent through the channel.
FeeProportionalMillionths uint32
// CLTVExpiryDelta is the time-lock delta of the channel.
CLTVExpiryDelta uint16
}
// Copy returns a deep copy of the hop hint.
func (h HopHint) Copy() HopHint {
nodeID := *h.NodeID
return HopHint{
NodeID: &nodeID,
ChannelID: h.ChannelID,
FeeBaseMSat: h.FeeBaseMSat,
FeeProportionalMillionths: h.FeeProportionalMillionths,
CLTVExpiryDelta: h.CLTVExpiryDelta,
}
}
// HopFee calculates the fee for a given amount that is forwarded over a hop.
// The amount has to be denoted in milli satoshi. The returned fee is also
// denoted in milli satoshi.
func (h HopHint) HopFee(amt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
baseFee := lnwire.MilliSatoshi(h.FeeBaseMSat)
feeRate := lnwire.MilliSatoshi(h.FeeProportionalMillionths)
return baseFee + (amt*feeRate)/feeRateParts
}
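// exampleHopFee is an illustrative sketch: with a base fee of 1000 msat
// and a proportional fee of 100 ppm, forwarding 1_000_000 msat costs
// 1000 + 1_000_000*100/1_000_000 = 1100 msat.
func exampleHopFee() lnwire.MilliSatoshi {
h := HopHint{FeeBaseMSat: 1000, FeeProportionalMillionths: 100}
return h.HopFee(1_000_000) // 1100 msat
}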
package zpay32
import (
"errors"
"fmt"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)
const (
// mSatPerBtc is the number of millisatoshis in 1 BTC.
mSatPerBtc = 100000000000
// signatureBase32Len is the number of 5-bit groups needed to encode
// the 512 bit signature + 8 bit recovery ID.
signatureBase32Len = 104
// timestampBase32Len is the number of 5-bit groups needed to encode
// the 35-bit timestamp.
timestampBase32Len = 7
// hashBase32Len is the number of 5-bit groups needed to encode a
// 256-bit hash. Note that the last group will be padded with zeroes.
hashBase32Len = 52
// pubKeyBase32Len is the number of 5-bit groups needed to encode a
// 33-byte compressed pubkey. Note that the last group will be padded
// with zeroes.
pubKeyBase32Len = 53
// hopHintLen is the number of bytes needed to encode the hop hint of a
// single private route.
hopHintLen = 51
// The following byte values correspond to the supported field types.
// The field name is the character representing that 5-bit value in the
// bech32 string.
// fieldTypeP is the field containing the payment hash.
fieldTypeP = 1
// fieldTypeD contains a short description of the payment.
fieldTypeD = 13
// fieldTypeM contains the payment metadata.
fieldTypeM = 27
// fieldTypeN contains the pubkey of the target node.
fieldTypeN = 19
// fieldTypeH contains the hash of a description of the payment.
fieldTypeH = 23
// fieldTypeX contains the expiry in seconds of the invoice.
fieldTypeX = 6
// fieldTypeF contains a fallback on-chain address.
fieldTypeF = 9
// fieldTypeR contains extra routing information.
fieldTypeR = 3
// fieldTypeC contains an optional requested final CLTV delta.
fieldTypeC = 24
// fieldType9 contains one or more bytes for signaling features
// supported or required by the receiver.
fieldType9 = 5
// fieldTypeS contains a 32-byte payment address, which is a nonce
// included in the final hop's payload to prevent intermediaries from
// probing the recipient.
fieldTypeS = 16
// fieldTypeB contains blinded payment path information. This field may
// be repeated to include multiple blinded payment paths in the invoice.
fieldTypeB = 20
// maxInvoiceLength is the maximum total length an invoice can have.
// This is chosen to be the maximum number of bytes that can fit into a
// single QR code: https://en.wikipedia.org/wiki/QR_code#Storage
maxInvoiceLength = 7089
// DefaultInvoiceExpiry is the default expiry duration from the creation
// timestamp if expiry is set to zero.
DefaultInvoiceExpiry = time.Hour
)
var (
// ErrInvoiceTooLarge is returned when an invoice exceeds
// maxInvoiceLength.
ErrInvoiceTooLarge = errors.New("invoice is too large")
// ErrInvalidFieldLength is returned when a tagged field was specified
// with a length larger than the left over bytes of the data field.
ErrInvalidFieldLength = errors.New("invalid field length")
// ErrBrokenTaggedField is returned when the last tagged field is
// incorrectly formatted and doesn't have enough bytes to be read.
ErrBrokenTaggedField = errors.New("last tagged field is broken")
)
// MessageSigner is passed to the Encode method to provide a signature
// corresponding to the node's pubkey.
type MessageSigner struct {
// SignCompact signs the hash of the passed msg with the node's privkey.
// The returned signature should be 65 bytes, where the last 64 are the
// compact signature, and the first one is a header byte. This is the
// format returned by ecdsa.SignCompact.
SignCompact func(msg []byte) ([]byte, error)
}
// Invoice represents a decoded invoice, or to-be-encoded invoice. Some of the
// fields are optional, and will only be non-nil if the invoice this was parsed
// from contains that field. When encoding, only the non-nil fields will be
// added to the encoded invoice.
type Invoice struct {
// Net specifies what network this Lightning invoice is meant for.
Net *chaincfg.Params
// MilliSat specifies the amount of this invoice in millisatoshi.
// Optional.
MilliSat *lnwire.MilliSatoshi
// Timestamp specifies the time this invoice was created.
// Mandatory
Timestamp time.Time
// PaymentHash is the payment hash to be used for a payment to this
// invoice.
PaymentHash *[32]byte
// PaymentAddr is the payment address to be used by payments to prevent
// probing of the destination.
PaymentAddr fn.Option[[32]byte]
// Destination is the public key of the target node. This will always
// be set after decoding, and can optionally be set before encoding to
// include the pubkey as an 'n' field. If this is not set before
// encoding then the destination pubkey won't be added as an 'n' field,
// and the pubkey will be extracted from the signature during decoding.
Destination *btcec.PublicKey
// minFinalCLTVExpiry is the value that the creator of the invoice
// expects to be used for the CLTV expiry of the HTLC extended to it in
// the last hop.
//
// NOTE: This value is optional, and should be set to nil if the
// invoice creator doesn't have a strong requirement on the CLTV expiry
// of the final HTLC extended to it.
//
// This field is un-exported and can only be read by the
// MinFinalCLTVExpiry() method. By forcing callers to read via this
// method, we can easily enforce the default if not specified.
//
// NOTE: this field is ignored in the case that the invoice contains
// blinded paths since then the final minimum cltv expiry delta is
// expected to be included in the route's accumulated CLTV delta value.
minFinalCLTVExpiry *uint64
// Description is a short description of the purpose of this invoice.
// Optional. Non-nil iff DescriptionHash is nil.
Description *string
// DescriptionHash is the SHA256 hash of a description of the purpose of
// this invoice.
// Optional. Non-nil iff Description is nil.
DescriptionHash *[32]byte
// expiry specifies the timespan this invoice will be valid.
// Optional. If not set, a default expiry of 60 min will be implied.
//
// This field is unexported and can be read by the Expiry() method. This
// method makes sure the default expiry time is returned in case the
// field is not set.
expiry *time.Duration
// FallbackAddr is an on-chain address that can be used for payment in
// case the Lightning payment fails.
// Optional.
FallbackAddr btcutil.Address
// RouteHints represents one or more different route hints. Each route
// hint can be individually used to reach the destination. These usually
// represent private routes.
//
// NOTE: This is optional and should not be set at the same time as
// BlindedPaymentPaths.
RouteHints [][]HopHint
// BlindedPaymentPaths is a set of blinded payment paths that can be
// used to find the payment receiver.
//
// NOTE: This is optional and should not be set at the same time as
// RouteHints.
BlindedPaymentPaths []*BlindedPaymentPath
// Features represents an optional field used to signal optional or
// required support for features by the receiver.
Features *lnwire.FeatureVector
// Metadata is additional data that is sent along with the payment to
// the payee.
Metadata []byte
}
// Amount is a functional option that allows callers of NewInvoice to set the
// amount in millisatoshis that the Invoice should encode.
func Amount(milliSat lnwire.MilliSatoshi) func(*Invoice) {
return func(i *Invoice) {
i.MilliSat = &milliSat
}
}
// Destination is a functional option that allows callers of NewInvoice to
// explicitly set the pubkey of the Invoice's destination node.
func Destination(destination *btcec.PublicKey) func(*Invoice) {
return func(i *Invoice) {
i.Destination = destination
}
}
// Description is a functional option that allows callers of NewInvoice to set
// the payment description of the created Invoice.
//
// NOTE: Must be used if and only if DescriptionHash is not used.
func Description(description string) func(*Invoice) {
return func(i *Invoice) {
i.Description = &description
}
}
// CLTVExpiry is an optional value which allows the receiver of the payment to
// specify the delta between the current height and the HTLC extended to the
// receiver.
func CLTVExpiry(delta uint64) func(*Invoice) {
return func(i *Invoice) {
i.minFinalCLTVExpiry = &delta
}
}
// DescriptionHash is a functional option that allows callers of NewInvoice to
// set the payment description hash of the created Invoice.
//
// NOTE: Must be used if and only if Description is not used.
func DescriptionHash(descriptionHash [32]byte) func(*Invoice) {
return func(i *Invoice) {
i.DescriptionHash = &descriptionHash
}
}
// Expiry is a functional option that allows callers of NewInvoice to set the
// expiry of the created Invoice. If not set, a default expiry of 60 min will
// be implied.
func Expiry(expiry time.Duration) func(*Invoice) {
return func(i *Invoice) {
i.expiry = &expiry
}
}
// FallbackAddr is a functional option that allows callers of NewInvoice to set
// the Invoice's fallback on-chain address that can be used for payment in case
// the Lightning payment fails.
func FallbackAddr(fallbackAddr btcutil.Address) func(*Invoice) {
return func(i *Invoice) {
i.FallbackAddr = fallbackAddr
}
}
// RouteHint is a functional option that allows callers of NewInvoice to add
// one or more hop hints that represent a private route to the destination.
func RouteHint(routeHint []HopHint) func(*Invoice) {
return func(i *Invoice) {
i.RouteHints = append(i.RouteHints, routeHint)
}
}
// WithBlindedPaymentPath is a functional option that allows a caller of
// NewInvoice to attach a blinded payment path to the invoice. The option can
// be used multiple times to attach multiple paths.
func WithBlindedPaymentPath(p *BlindedPaymentPath) func(*Invoice) {
return func(i *Invoice) {
i.BlindedPaymentPaths = append(i.BlindedPaymentPaths, p)
}
}
// Features is a functional option that allows callers of NewInvoice to set the
// desired feature bits that are advertised on the invoice. If this option is
// not used, an empty feature vector will automatically be populated.
func Features(features *lnwire.FeatureVector) func(*Invoice) {
return func(i *Invoice) {
i.Features = features
}
}
// PaymentAddr is a functional option that allows callers of NewInvoice to set
// the desired payment address that is advertised on the invoice.
func PaymentAddr(addr [32]byte) func(*Invoice) {
return func(i *Invoice) {
i.PaymentAddr = fn.Some(addr)
}
}
// Metadata is a functional option that allows callers of NewInvoice to set
// the desired payment Metadata that is advertised on the invoice.
func Metadata(metadata []byte) func(*Invoice) {
return func(i *Invoice) {
i.Metadata = metadata
}
}
// NewInvoice creates a new Invoice object. The last parameter is a set of
// variadic arguments for setting optional fields of the invoice.
//
// NOTE: Either Description or DescriptionHash must be provided for the Invoice
// to be considered valid.
func NewInvoice(net *chaincfg.Params, paymentHash [32]byte,
timestamp time.Time, options ...func(*Invoice)) (*Invoice, error) {
invoice := &Invoice{
Net: net,
PaymentHash: &paymentHash,
Timestamp: timestamp,
}
for _, option := range options {
option(invoice)
}
// If no features were set, we'll populate an empty feature vector.
if invoice.Features == nil {
invoice.Features = lnwire.NewFeatureVector(
nil, lnwire.Features,
)
}
if err := validateInvoice(invoice); err != nil {
return nil, err
}
return invoice, nil
}
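// exampleNewAndEncode is an illustrative sketch tying the pieces together:
// it builds an invoice with the functional options above and signs it via
// Encode. It assumes "github.com/btcsuite/btcd/btcec/v2/ecdsa" and
// "github.com/btcsuite/btcd/chaincfg/chainhash" are added to this file's
// imports, and uses a placeholder payment hash (a real one would be the
// SHA-256 of a payment preimage).
func exampleNewAndEncode() (string, error) {
priv, err := btcec.NewPrivateKey()
if err != nil {
return "", err
}
var payHash [32]byte // placeholder payment hash
inv, err := NewInvoice(
&chaincfg.MainNetParams, payHash, time.Now(),
Description("coffee"),
Amount(lnwire.MilliSatoshi(100_000)),
)
if err != nil {
return "", err
}
// Encode hands us the raw hrp + data to sign; SignCompact must hash
// it and return the 65-byte compact signature with the header byte
// first, as ecdsa.SignCompact does.
signer := MessageSigner{
SignCompact: func(msg []byte) ([]byte, error) {
hash := chainhash.HashB(msg)
return ecdsa.SignCompact(priv, hash, true), nil
},
}
return inv.Encode(signer)
}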
// Expiry returns the expiry time for this invoice. If expiry time is not set
// explicitly, the default 3600 second expiry will be returned.
func (invoice *Invoice) Expiry() time.Duration {
if invoice.expiry != nil {
return *invoice.expiry
}
// If no expiry is set for this invoice, the default of 3600 seconds
// is used.
return DefaultInvoiceExpiry
}
// MinFinalCLTVExpiry returns the minimum final CLTV expiry delta as specified
// by the creator of the invoice. This value specifies the delta between the
// current height and the expiry height of the HTLC extended in the last hop.
func (invoice *Invoice) MinFinalCLTVExpiry() uint64 {
if invoice.minFinalCLTVExpiry != nil {
return *invoice.minFinalCLTVExpiry
}
return DefaultAssumedFinalCLTVDelta
}
// validateInvoice does a sanity check of the provided Invoice, making sure it
// has all the necessary fields set for it to be considered valid by BOLT-0011.
func validateInvoice(invoice *Invoice) error {
// The net must be set.
if invoice.Net == nil {
return fmt.Errorf("net params not set")
}
// The invoice must contain a payment hash.
if invoice.PaymentHash == nil {
return fmt.Errorf("no payment hash found")
}
if len(invoice.RouteHints) != 0 &&
len(invoice.BlindedPaymentPaths) != 0 {
return fmt.Errorf("cannot have both route hints and blinded " +
"payment paths")
}
// Either Description or DescriptionHash must be set, not both.
if invoice.Description != nil && invoice.DescriptionHash != nil {
return fmt.Errorf("both description and description hash set")
}
if invoice.Description == nil && invoice.DescriptionHash == nil {
return fmt.Errorf("neither description nor description hash set")
}
// Check that we support the field lengths.
if len(invoice.PaymentHash) != 32 {
return fmt.Errorf("unsupported payment hash length: %d",
len(invoice.PaymentHash))
}
if invoice.DescriptionHash != nil && len(invoice.DescriptionHash) != 32 {
return fmt.Errorf("unsupported description hash length: %d",
len(invoice.DescriptionHash))
}
if invoice.Destination != nil &&
len(invoice.Destination.SerializeCompressed()) != 33 {
return fmt.Errorf("unsupported pubkey length: %d",
len(invoice.Destination.SerializeCompressed()))
}
// Ensure that all invoices have feature vectors.
if invoice.Features == nil {
return fmt.Errorf("missing feature vector")
}
return nil
}